Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Module for equivariant PAO-ML based on PyTorch.
10 : !> \author Ole Schuett
11 : ! **************************************************************************************************
12 : MODULE pao_model
13 : USE atomic_kind_types, ONLY: atomic_kind_type,&
14 : get_atomic_kind
15 : USE basis_set_types, ONLY: gto_basis_set_type
16 : USE cell_types, ONLY: cell_type,&
17 : pbc
18 : USE cp_dbcsr_api, ONLY: dbcsr_iterator_blocks_left,&
19 : dbcsr_iterator_next_block,&
20 : dbcsr_iterator_start,&
21 : dbcsr_iterator_stop,&
22 : dbcsr_iterator_type
23 : USE kinds, ONLY: default_path_length,&
24 : default_string_length,&
25 : dp,&
26 : sp
27 : USE message_passing, ONLY: mp_para_env_type
28 : USE pao_types, ONLY: pao_env_type,&
29 : pao_model_type
30 : USE particle_types, ONLY: particle_type
31 : USE physcon, ONLY: angstrom
32 : USE qs_environment_types, ONLY: get_qs_env,&
33 : qs_environment_type
34 : USE qs_kind_types, ONLY: get_qs_kind,&
35 : qs_kind_type
36 : USE torch_api, ONLY: torch_dict_create,&
37 : torch_dict_get,&
38 : torch_dict_insert,&
39 : torch_dict_release,&
40 : torch_dict_type,&
41 : torch_model_eval,&
42 : torch_model_get_attr,&
43 : torch_model_load
44 : USE util, ONLY: sort
45 : #include "./base/base_uses.f90"
46 :
47 : IMPLICIT NONE
48 :
49 : PRIVATE
50 :
51 : PUBLIC :: pao_model_load, pao_model_predict, pao_model_type
52 :
53 : CONTAINS
54 :
! **************************************************************************************************
!> \brief Loads a PAO-ML model from a PyTorch file and checks it against atomic kind ikind.
!>        The model metadata is read via torch_model_get_attr. Every feature kind name of the
!>        model is mapped onto the kind number of the equally named atomic kind of the subsys
!>        (or -1 if no such kind exists). Unmatched names and kinds are only reported on pao%iw.
!>        Any incompatibility with kind ikind (version, kind name, atomic number, primary basis
!>        name/size, PAO basis size) aborts the run.
!> \param pao PAO environment, only its output unit pao%iw is used here
!> \param qs_env Quickstep environment providing the atomic and qs kind sets
!> \param ikind index of the kind the model is loaded for
!> \param pao_model_file path of the PyTorch model file
!> \param model the loaded model
! **************************************************************************************************
   SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
      TYPE(pao_env_type), INTENT(IN)                     :: pao
      TYPE(qs_environment_type), INTENT(IN)              :: qs_env
      INTEGER, INTENT(IN)                                :: ikind
      CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
      TYPE(pao_model_type), INTENT(OUT)                  :: model

      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'

      CHARACTER(LEN=default_string_length)               :: kind_name
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:)                       :: feature_kind_names
      INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(gto_basis_set_type), POINTER                  :: basis_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

      CALL timeset(routineN, handle)
      CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)

      IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
      CALL torch_model_load(model%torch_model, pao_model_file)

      ! Check the version before anything else, so that a model of an unsupported version is
      ! rejected with a clear message instead of failing later on an unexpected attribute.
      CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
      IF (model%version /= 1) &
         CPABORT("Model version not supported.")

      ! Read the remaining model attributes.
      CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
      CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
      CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
      CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
      CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
      CALL torch_model_get_attr(model%torch_model, "num_neighbors", model%num_neighbors)
      CALL torch_model_get_attr(model%torch_model, "cutoff", model%cutoff)
      CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)

      ! Freeze model after all attributes have been read.
      ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
      ! https://github.com/pytorch/pytorch/issues/96726
      ! CALL torch_model_freeze(model%torch_model)

      ! For each feature kind name look up the kind number of the atomic kind with that name.
      ! The kind number (not the position in atomic_kind_set) is stored, because that is what
      ! the missing-kind check below and the per-atom prediction compare against.
      ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
      model%feature_kinds(:) = -1
      DO jkind = 1, SIZE(feature_kind_names)
         DO kkind = 1, SIZE(atomic_kind_set)
            IF (TRIM(atomic_kind_set(kkind)%name) == TRIM(feature_kind_names(jkind))) THEN
               model%feature_kinds(jkind) = atomic_kind_set(kkind)%kind_number
               EXIT
            END IF
         END DO
         IF (model%feature_kinds(jkind) < 0) THEN
            IF (pao%iw > 0) &
               WRITE (pao%iw, '(A)') " PAO| ML-model supports feature kind '"// &
               TRIM(feature_kind_names(jkind))//"' that is not present in subsys."
         END IF
      END DO

      ! Report atomic kinds of the subsys that the model has no feature for.
      DO jkind = 1, SIZE(atomic_kind_set)
         IF (ALL(model%feature_kinds /= atomic_kind_set(jkind)%kind_number)) THEN
            IF (pao%iw > 0) &
               WRITE (pao%iw, '(A)') " PAO| ML-Model lacks feature kind '"// &
               TRIM(atomic_kind_set(jkind)%name)//"' that is present in subsys."
         END IF
      END DO

      ! Check compatibility with kind ikind.
      NULLIFY (basis_set)
      CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
      CPASSERT(ASSOCIATED(basis_set))
      CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
      IF (TRIM(model%kind_name) .NE. TRIM(kind_name)) &
         CPABORT("Kind name does not match.")
      IF (model%atomic_number /= z) &
         CPABORT("Atomic number does not match.")
      IF (TRIM(model%prim_basis_name) .NE. TRIM(basis_set%name)) &
         CPABORT("Primary basis set name does not match.")
      IF (model%prim_basis_size /= basis_set%nsgf) &
         CPABORT("Primary basis set size does not match.")
      IF (model%pao_basis_size /= pao_basis_size) &
         CPABORT("PAO basis size does not match.")

      CALL timestop(handle)

   END SUBROUTINE pao_model_load
146 :
! **************************************************************************************************
!> \brief Overwrites the blocks of pao%matrix_X with PAO-ML model predictions.
!>        Every non-empty block must be a diagonal block (asserted). Its row index is taken as
!>        the atom index, and the block is filled by predict_single_atom. Empty blocks belong to
!>        atoms for which PAO is disabled and are skipped.
!> \param pao PAO environment whose matrix_X is filled
!> \param qs_env Quickstep environment passed on to the per-atom prediction
! **************************************************************************************************
   SUBROUTINE pao_model_predict(pao, qs_env)
      TYPE(pao_env_type), POINTER                        :: pao
      TYPE(qs_environment_type), POINTER                 :: qs_env

      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'

      INTEGER                                            :: blk_col, blk_row, handle, iatom
      REAL(dp), DIMENSION(:, :), POINTER                 :: xblock
      TYPE(dbcsr_iterator_type)                          :: blk_iter

      CALL timeset(routineN, handle)

      ! Every thread owns a private iterator over matrix_X.
      ! NOTE(review): relies on DBCSR distributing the blocks among the threads of the
      ! enclosing parallel region - confirm against the DBCSR iterator documentation.
!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(blk_iter,blk_row,blk_col,iatom,xblock)
      CALL dbcsr_iterator_start(blk_iter, pao%matrix_X)
      DO WHILE (dbcsr_iterator_blocks_left(blk_iter))
         CALL dbcsr_iterator_next_block(blk_iter, blk_row, blk_col, xblock)
         ! A zero-sized block means PAO is disabled for this atom.
         IF (SIZE(xblock) > 0) THEN
            CPASSERT(blk_row == blk_col)
            iatom = blk_row
            CALL predict_single_atom(pao, qs_env, iatom, block_X=xblock)
         END IF
      END DO
      CALL dbcsr_iterator_stop(blk_iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE pao_model_predict
178 :
! **************************************************************************************************
!> \brief Predicts the PAO parameters block_X of a single atom with the PAO-ML model of its kind.
!>        All atoms are sorted by their squared periodic distance to iatom. The model is then
!>        fed with the model%num_neighbors nearest neighbors (central atom excluded):
!>        - "neighbors_relpos": relative positions from pbc(), scaled by physcon's angstrom
!>          factor, in single precision, of shape (3, num_neighbors);
!>        - "neighbors_features": per neighbor an indicator over model%feature_kinds that is 1
!>          where the feature kind equals the neighbor's kind number.
!>        Columns beyond the available neighbors (natoms - 1 < num_neighbors) stay zero.
!>        The model output "xblock" is copied into block_X. Its size must match block_X.
!> \param pao PAO environment holding one model per kind in pao%models
!> \param qs_env Quickstep environment providing cell and particles
!> \param iatom index of the central atom
!> \param block_X block to be filled with the model prediction
! **************************************************************************************************
   SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X)
      TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
      TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
      INTEGER, INTENT(IN)                                :: iatom
      REAL(dp), DIMENSION(:, :), INTENT(OUT)             :: block_X

      INTEGER                                            :: ikind, jatom, jkind, jneighbor, natoms, &
                                                            nneighbors
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbors_index
      REAL(dp), DIMENSION(3)                             :: Ri, Rij, Rj
      REAL(dp), ALLOCATABLE, DIMENSION(:)                :: neighbors_distance
      REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: neighbors_features, neighbors_relpos
      REAL(sp), DIMENSION(:, :), POINTER                 :: predicted_xblock
      TYPE(cell_type), POINTER                           :: cell
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(pao_model_type), POINTER                      :: model
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(torch_dict_type)                              :: model_inputs, model_outputs

      CALL get_qs_env(qs_env, &
                      para_env=para_env, &
                      cell=cell, &
                      particle_set=particle_set, &
                      natom=natoms)

      CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
      model => pao%models(ikind)
      CPASSERT(model%version > 0)

      ! Find neighbors by sorting all atoms by their squared distance to iatom.
      ! TODO: this is a quadratic algorithm, use a neighbor-list instead
      ALLOCATE (neighbors_distance(natoms), neighbors_index(natoms))
      Ri = particle_set(iatom)%r
      DO jatom = 1, natoms
         Rj = particle_set(jatom)%r
         Rij = pbc(Ri, Rj, cell)
         neighbors_distance(jatom) = DOT_PRODUCT(Rij, Rij) ! using squared distances for performance
      END DO
      CALL sort(neighbors_distance, natoms, neighbors_index)
      CPASSERT(neighbors_index(1) == iatom) ! central atom should be closest to itself

      ! Compute relative positions and kind features of the nearest neighbors.
      ! Columns beyond the available neighbors remain zero.
      nneighbors = MIN(model%num_neighbors, natoms - 1)
      ALLOCATE (neighbors_relpos(3, model%num_neighbors))
      ALLOCATE (neighbors_features(SIZE(model%feature_kinds), model%num_neighbors))
      neighbors_relpos(:, :) = 0.0_sp
      neighbors_features(:, :) = 0.0_sp
      DO jneighbor = 1, nneighbors
         jatom = neighbors_index(jneighbor + 1) ! skipping central atom
         Rj = particle_set(jatom)%r
         Rij = pbc(Ri, Rj, cell)
         neighbors_relpos(:, jneighbor) = REAL(angstrom*Rij, kind=sp)
         jkind = particle_set(jatom)%atomic_kind%kind_number
         WHERE (model%feature_kinds == jkind) neighbors_features(:, jneighbor) = 1.0_sp
      END DO

      ! Inference.
      CALL torch_dict_create(model_inputs)
      CALL torch_dict_insert(model_inputs, "neighbors_relpos", neighbors_relpos)
      CALL torch_dict_insert(model_inputs, "neighbors_features", neighbors_features)
      CALL torch_dict_create(model_outputs)
      CALL torch_model_eval(model%torch_model, model_inputs, model_outputs)

      ! Copy predicted XBlock.
      NULLIFY (predicted_xblock)
      CALL torch_dict_get(model_outputs, "xblock", predicted_xblock)
      ! Without PAD, RESHAPE requires at least as many source elements as the target holds, and
      ! surplus elements would be dropped silently. Therefore insist on an exact size match.
      IF (SIZE(predicted_xblock) /= SIZE(block_X)) &
         CPABORT("Size of predicted xblock does not match.")
      block_X = REAL(RESHAPE(predicted_xblock, SHAPE(block_X)), kind=dp)

      ! Clean up.
      CALL torch_dict_release(model_inputs)
      CALL torch_dict_release(model_outputs)
      DEALLOCATE (neighbors_distance, neighbors_index)
      DEALLOCATE (predicted_xblock, neighbors_relpos, neighbors_features)

   END SUBROUTINE predict_single_atom
268 :
269 : END MODULE pao_model
|