Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
! **************************************************************************************************
!> \brief Module for equivariant PAO-ML based on PyTorch.
!> \author Ole Schuett
! **************************************************************************************************
MODULE pao_model
   USE OMP_LIB,                         ONLY: omp_init_lock,&
                                              omp_set_lock,&
                                              omp_unset_lock
   USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                              get_atomic_kind
   USE basis_set_types,                 ONLY: gto_basis_set_type
   USE cell_types,                      ONLY: cell_type,&
                                              pbc
   USE cp_dbcsr_api,                    ONLY: dbcsr_get_info,&
                                              dbcsr_iterator_blocks_left,&
                                              dbcsr_iterator_next_block,&
                                              dbcsr_iterator_start,&
                                              dbcsr_iterator_stop,&
                                              dbcsr_iterator_type,&
                                              dbcsr_type
   USE kinds,                           ONLY: default_path_length,&
                                              default_string_length,&
                                              dp,&
                                              sp
   USE message_passing,                 ONLY: mp_para_env_type
   USE pao_types,                       ONLY: pao_env_type,&
                                              pao_model_type
   USE particle_types,                  ONLY: particle_type
   USE physcon,                         ONLY: angstrom
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              qs_kind_type
   USE torch_api,                       ONLY: &
        torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
        torch_model_forward, torch_model_get_attr, torch_model_load, torch_tensor_backward, &
        torch_tensor_data_ptr, torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
        torch_tensor_type
   USE util,                            ONLY: sort
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   PUBLIC :: pao_model_load, pao_model_predict, pao_model_forces, pao_model_type

CONTAINS

! **************************************************************************************************
!> \brief Loads a PAO-ML model from a PyTorch model file and checks it against atomic kind ikind.
!>        Reads the model attributes, maps every feature kind name of the model onto the kind
!>        number of the atomic kind with the same name, reports (non-fatally) feature kinds and
!>        subsys kinds that have no counterpart, aborts if the model is incompatible with ikind,
!>        and finally initializes the model's OpenMP lock.
!> \param pao PAO environment; log messages are written to unit pao%iw when pao%iw > 0
!> \param qs_env QS environment, provides the qs_kind_set and atomic_kind_set
!> \param ikind index of the atomic kind the model is loaded for
!> \param pao_model_file path of the PyTorch model file
!> \param model receives the torch model, its attributes, feature_kinds and the lock
! **************************************************************************************************
   SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
      TYPE(pao_env_type), INTENT(IN)                     :: pao
      TYPE(qs_environment_type), INTENT(IN)              :: qs_env
      INTEGER, INTENT(IN)                                :: ikind
      CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
      TYPE(pao_model_type), INTENT(OUT)                  :: model

      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'

      CHARACTER(LEN=default_string_length)               :: kind_name
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:)                       :: feature_kind_names
      INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(gto_basis_set_type), POINTER                  :: basis_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

      CALL timeset(routineN, handle)
      CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)

      IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
      CALL torch_model_load(model%torch_model, pao_model_file)

      ! Read model attributes.
      CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
      CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
      CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
      CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
      CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
      CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
      CALL torch_model_get_attr(model%torch_model, "num_neighbors", model%num_neighbors)
      CALL torch_model_get_attr(model%torch_model, "cutoff", model%cutoff)
      CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)

      ! Freeze model after all attributes have been read.
      ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
      ! https://github.com/pytorch/pytorch/issues/96726
      ! CALL torch_model_freeze(model%torch_model)

      ! For each feature kind name look up the kind number of the atomic kind with that name.
      ! The kind number is stored (not the array index) because predict_single_atom compares
      ! feature_kinds against particle_set(:)%atomic_kind%kind_number.
      ! Feature kinds without a matching atomic kind keep the value -1.
      ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
      model%feature_kinds(:) = -1
      DO jkind = 1, SIZE(feature_kind_names)
         DO kkind = 1, SIZE(atomic_kind_set)
            IF (TRIM(atomic_kind_set(kkind)%name) == TRIM(feature_kind_names(jkind))) THEN
               model%feature_kinds(jkind) = atomic_kind_set(kkind)%kind_number
            END IF
         END DO
         IF (model%feature_kinds(jkind) < 0) THEN
            IF (pao%iw > 0) &
               WRITE (pao%iw, '(A)') " PAO| ML-model supports feature kind '"// &
               TRIM(feature_kind_names(jkind))//"' that is not present in subsys."
         END IF
      END DO

      ! Report atomic kinds of the subsys for which the model has no feature (not fatal).
      DO jkind = 1, SIZE(atomic_kind_set)
         IF (ALL(model%feature_kinds /= atomic_kind_set(jkind)%kind_number)) THEN
            IF (pao%iw > 0) &
               WRITE (pao%iw, '(A)') " PAO| ML-Model lacks feature kind '"// &
               TRIM(atomic_kind_set(jkind)%name)//"' that is present in subsys."
         END IF
      END DO

      ! Check compatibility of the model with atomic kind ikind.
      NULLIFY (basis_set)
      CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
      CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
      CPASSERT(ASSOCIATED(basis_set))
      IF (model%version /= 1) &
         CPABORT("Model version not supported.")
      IF (TRIM(model%kind_name) /= TRIM(kind_name)) &
         CPABORT("Kind name does not match.")
      IF (model%atomic_number /= z) &
         CPABORT("Atomic number does not match.")
      IF (TRIM(model%prim_basis_name) /= TRIM(basis_set%name)) &
         CPABORT("Primary basis set name does not match.")
      IF (model%prim_basis_size /= basis_set%nsgf) &
         CPABORT("Primary basis set size does not match.")
      IF (model%pao_basis_size /= pao_basis_size) &
         CPABORT("PAO basis size does not match.")

      ! The lock serializes access to the torch model from OpenMP threads (see predict_single_atom).
      CALL omp_init_lock(model%lock)
      CALL timestop(handle)

   END SUBROUTINE pao_model_load

! **************************************************************************************************
!> \brief Fills pao%matrix_X based on machine learning predictions.
!>        Iterates (OpenMP-parallel) over the local blocks of pao%matrix_X; blocks of size zero
!>        (PAO disabled for that atom) are skipped, every other block must be diagonal and is
!>        overwritten with the prediction of the model belonging to the atom's kind.
!> \param pao PAO environment holding matrix_X and the loaded models
!> \param qs_env QS environment
! **************************************************************************************************
   SUBROUTINE pao_model_predict(pao, qs_env)
      TYPE(pao_env_type), POINTER                        :: pao
      TYPE(qs_environment_type), POINTER                 :: qs_env

      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'

      INTEGER                                            :: acol, arow, handle, iatom
      REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
      TYPE(dbcsr_iterator_type)                          :: iter

      CALL timeset(routineN, handle)

!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(iter,arow,acol,iatom,block_X)
      CALL dbcsr_iterator_start(iter, pao%matrix_X)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
         IF (SIZE(block_X) == 0) CYCLE ! pao disabled for iatom
         iatom = arow; CPASSERT(arow == acol)
         CALL predict_single_atom(pao, qs_env, iatom, block_X=block_X)
      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE pao_model_predict

! **************************************************************************************************
!> \brief Calculate forces contributed by machine learning.
!>        Iterates (OpenMP-parallel) over the local blocks of matrix_G; blocks of size zero are
!>        skipped, every other block must be diagonal and is back-propagated through the model of
!>        the atom's kind. The resulting contributions are accumulated into forces.
!> \param pao PAO environment holding the loaded models
!> \param qs_env QS environment
!> \param matrix_G matrix whose diagonal blocks are the outer gradients w.r.t. the predicted blocks
!> \param forces per-atom contributions, indexed as forces(atom, 1:3); accumulated into
! **************************************************************************************************
   SUBROUTINE pao_model_forces(pao, qs_env, matrix_G, forces)
      TYPE(pao_env_type), POINTER                        :: pao
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(dbcsr_type)                                   :: matrix_G
      REAL(dp), DIMENSION(:, :), INTENT(INOUT)           :: forces

      CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_forces'

      INTEGER                                            :: acol, arow, handle, iatom
      REAL(dp), DIMENSION(:, :), POINTER                 :: block_G
      TYPE(dbcsr_iterator_type)                          :: iter

      CALL timeset(routineN, handle)

!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,matrix_G,forces) PRIVATE(iter,arow,acol,iatom,block_G)
      CALL dbcsr_iterator_start(iter, matrix_G)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, arow, acol, block_G)
         IF (SIZE(block_G) == 0) CYCLE ! pao disabled for iatom
         iatom = arow; CPASSERT(arow == acol)
         CALL predict_single_atom(pao, qs_env, iatom, block_G=block_G, forces=forces)
      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE pao_model_forces

! **************************************************************************************************
!> \brief Runs the PAO-ML model of iatom's kind for a single atom.
!>        The model inputs are built from the model%num_neighbors closest atoms (by periodic
!>        distance, central atom excluded): their relative positions in Angstrom and kind indicator
!>        features (1.0 where model%feature_kinds equals the neighbor's kind number). Unused
!>        neighbor slots are zero. If block_X is present it receives the predicted block; if
!>        block_G is present it is back-propagated and the position derivatives are added to forces.
!>        May be called concurrently from OpenMP threads.
!> \param pao PAO environment, provides the models and the block sizes of pao%matrix_Y
!> \param qs_env QS environment, provides cell, particle_set and the number of atoms
!> \param iatom index of the central atom
!> \param block_X optional output block for the prediction
!> \param block_G optional outer gradient w.r.t. the predicted block (n x m values)
!> \param forces per-atom contributions forces(atom, 1:3); must be present together with block_G
! **************************************************************************************************
   SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
      TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
      TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
      INTEGER, INTENT(IN)                                :: iatom
      REAL(dp), DIMENSION(:, :), OPTIONAL                :: block_X, block_G, forces

      INTEGER                                            :: ikind, jatom, jkind, jneighbor, m, n, &
                                                            natoms, nneighbors
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbors_index
      INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
      REAL(dp), DIMENSION(3)                             :: Ri, Rij, Rj
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: neighbors_distance
      REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: features, outer_grad, relpos
      REAL(sp), DIMENSION(:, :), POINTER                 :: predicted_xblock, relpos_grad
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(pao_model_type), POINTER                      :: model
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(torch_dict_type)                              :: model_inputs, model_outputs
      TYPE(torch_tensor_type)                            :: features_tensor, outer_grad_tensor, &
                                                            predicted_xblock_tensor, &
                                                            relpos_grad_tensor, relpos_tensor

      CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
      n = blk_sizes_pri(iatom) ! size of primary basis
      m = blk_sizes_pao(iatom) ! size of pao basis

      CALL get_qs_env(qs_env, &
                      para_env=para_env, &
                      cell=cell, &
                      particle_set=particle_set, &
                      atomic_kind_set=atomic_kind_set, &
                      qs_kind_set=qs_kind_set, &
                      natom=natoms)

      CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
      model => pao%models(ikind)
      CPASSERT(model%version > 0)

      ! Find neighbors, sorted by (squared) periodic distance to the central atom.
      ! TODO: this is a quadratic algorithm, use a neighbor-list instead
      ALLOCATE (neighbors_distance(natoms), neighbors_index(natoms))
      Ri = particle_set(iatom)%r
      DO jatom = 1, natoms
         Rj = particle_set(jatom)%r
         Rij = pbc(Ri, Rj, cell)
         neighbors_distance(jatom) = DOT_PRODUCT(Rij, Rij) ! using squared distances for performance
      END DO
      CALL sort(neighbors_distance, natoms, neighbors_index)
      CPASSERT(neighbors_index(1) == iatom) ! central atom should be closest to itself

      ! Number of actually available neighbors; remaining slots of the inputs stay zero.
      nneighbors = MIN(model%num_neighbors, natoms - 1)

      ! Compute neighbors relative positions (in Angstrom) and kind features.
      ALLOCATE (relpos(3, model%num_neighbors))
      ALLOCATE (features(SIZE(model%feature_kinds), model%num_neighbors))
      relpos(:, :) = 0.0_sp
      features(:, :) = 0.0_sp
      DO jneighbor = 1, nneighbors
         jatom = neighbors_index(jneighbor + 1) ! skipping central atom
         Rj = particle_set(jatom)%r
         Rij = pbc(Ri, Rj, cell)
         relpos(:, jneighbor) = REAL(angstrom*Rij, kind=sp)
         jkind = particle_set(jatom)%atomic_kind%kind_number
         WHERE (model%feature_kinds == jkind) features(:, jneighbor) = 1.0_sp
      END DO

      ! Everything below touches the shared torch model, hence it is serialized per kind.
      ! The neighbor search above only reads model data that is immutable after loading.
      CALL omp_set_lock(model%lock) ! TODO: might not be needed for inference.

      ! Inference.
      CALL torch_dict_create(model_inputs)

      CALL torch_tensor_from_array(relpos_tensor, relpos, requires_grad=PRESENT(block_G))
      CALL torch_dict_insert(model_inputs, "neighbors_relpos", relpos_tensor)
      CALL torch_tensor_from_array(features_tensor, features)
      CALL torch_dict_insert(model_inputs, "neighbors_features", features_tensor)

      CALL torch_dict_create(model_outputs)
      CALL torch_model_forward(model%torch_model, model_inputs, model_outputs)

      ! Copy predicted XBlock.
      NULLIFY (predicted_xblock)
      CALL torch_dict_get(model_outputs, "xblock", predicted_xblock_tensor)
      CALL torch_tensor_data_ptr(predicted_xblock_tensor, predicted_xblock)
      CPASSERT(SIZE(predicted_xblock, 1) == n .AND. SIZE(predicted_xblock, 2) == m)
      IF (PRESENT(block_X)) THEN
         block_X = RESHAPE(predicted_xblock, [n*m, 1])
      END IF

      ! TURNING POINT (if calc forces) ------------------------------------------
      IF (PRESENT(block_G)) THEN
         CPASSERT(PRESENT(forces))
         ALLOCATE (outer_grad(n, m))
         outer_grad(:, :) = REAL(RESHAPE(block_G, [n, m]), kind=sp)
         CALL torch_tensor_from_array(outer_grad_tensor, outer_grad)
         CALL torch_tensor_backward(predicted_xblock_tensor, outer_grad_tensor)
         CALL torch_tensor_grad(relpos_tensor, relpos_grad_tensor)
         NULLIFY (relpos_grad)
         CALL torch_tensor_data_ptr(relpos_grad_tensor, relpos_grad)
         CPASSERT(SIZE(relpos_grad, 1) == 3 .AND. SIZE(relpos_grad, 2) == model%num_neighbors)
         ! relpos = angstrom*pbc(Ri, Rj), so d/dR picks up a factor of angstrom (chain rule).
         ! forces is shared among OpenMP threads and jatom may also be updated by threads working
         ! on other atoms, possibly of another kind (i.e. holding a different model%lock),
         ! therefore the accumulation must be protected against concurrent updates.
!$OMP CRITICAL(pao_ml_force_accum)
         DO jneighbor = 1, nneighbors
            jatom = neighbors_index(jneighbor + 1) ! skipping central atom
            forces(iatom, :) = forces(iatom, :) + relpos_grad(:, jneighbor)*angstrom
            forces(jatom, :) = forces(jatom, :) - relpos_grad(:, jneighbor)*angstrom
         END DO
!$OMP END CRITICAL(pao_ml_force_accum)
         CALL torch_tensor_release(outer_grad_tensor)
         CALL torch_tensor_release(relpos_grad_tensor)
         DEALLOCATE (outer_grad)
      END IF

      ! Clean up. Tensors are released before the Fortran arrays they were created from.
      CALL torch_tensor_release(relpos_tensor)
      CALL torch_tensor_release(features_tensor)
      CALL torch_tensor_release(predicted_xblock_tensor)
      CALL torch_dict_release(model_inputs)
      CALL torch_dict_release(model_outputs)
      CALL omp_unset_lock(model%lock)

      DEALLOCATE (neighbors_distance, neighbors_index, relpos, features)

   END SUBROUTINE predict_single_atom

END MODULE pao_model
350 : END MODULE pao_model
|