Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \par History
10 : !> allegro implementation
11 : !> \author Gabriele Tocci
12 : ! **************************************************************************************************
13 : MODULE manybody_allegro
14 :
15 : USE atomic_kind_types, ONLY: atomic_kind_type
16 : USE cell_types, ONLY: cell_type
17 : USE fist_neighbor_list_types, ONLY: fist_neighbor_type,&
18 : neighbor_kind_pairs_type
19 : USE fist_nonbond_env_types, ONLY: allegro_data_type,&
20 : fist_nonbond_env_get,&
21 : fist_nonbond_env_set,&
22 : fist_nonbond_env_type,&
23 : pos_type
24 : USE kinds, ONLY: dp,&
25 : int_8,&
26 : sp
27 : USE message_passing, ONLY: mp_para_env_type
28 : USE pair_potential_types, ONLY: allegro_pot_type,&
29 : allegro_type,&
30 : pair_potential_pp_type,&
31 : pair_potential_single_type
32 : USE particle_types, ONLY: particle_type
33 : USE torch_api, ONLY: &
34 : torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
35 : torch_model_forward, torch_model_freeze, torch_model_load, torch_tensor_data_ptr, &
36 : torch_tensor_from_array, torch_tensor_release, torch_tensor_type
37 : USE util, ONLY: sort
38 : #include "./base/base_uses.f90"
39 :
40 : IMPLICIT NONE
41 :
42 : PRIVATE
43 : PUBLIC :: setup_allegro_arrays, destroy_allegro_arrays, &
44 : allegro_energy_store_force_virial, allegro_add_force_virial
45 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_allegro'
46 :
47 : CONTAINS
48 :
! **************************************************************************************************
!> \brief Gathers every neighbor pair that is handled by an Allegro potential into flat arrays
!>        ordered by the first atom of the pair.
!> \param nonbonded neighbor lists that are scanned for Allegro pairs
!> \param potparm pair potential table, addressed by the two kinds of a pair
!> \param glob_loc_list on exit (2, npairs): atom indices of all Allegro pairs, ordered by row 1
!> \param glob_cell_v on exit (3, npairs): cell%hmat times the cell_vector of the list of each pair
!> \param glob_loc_list_a on exit (npairs): copy of the first row of glob_loc_list
!> \param unique_list_a on exit: the distinct values of glob_loc_list_a, in the same order;
!>        has size zero when no Allegro pair exists
!> \param cell cell whose hmat converts the list cell_vector into glob_cell_v
!> \par History
!>      Implementation of the allegro potential - [gtocci] 2023
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
SUBROUTINE setup_allegro_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, &
                                unique_list_a, cell)
   TYPE(fist_neighbor_type), POINTER                  :: nonbonded
   TYPE(pair_potential_pp_type), POINTER              :: potparm
   INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
   INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a, unique_list_a
   TYPE(cell_type), POINTER                           :: cell

   CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_allegro_arrays'

   INTEGER                                            :: handle, i, iend, igrp, ikind, ilist, &
                                                         ipair, istart, jkind, nlocal, npairs, &
                                                         npairs_tot
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: temp_unique_list_a, work_list, work_list2
   INTEGER, DIMENSION(:, :), POINTER                  :: list
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: rwork_list
   REAL(KIND=dp), DIMENSION(3)                        :: cell_v, cvi
   TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
   TYPE(pair_potential_single_type), POINTER          :: pot

   ! The output arrays are allocated here, so they must come in unassociated
   CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
   CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
   CPASSERT(.NOT. ASSOCIATED(unique_list_a))
   CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
   CALL timeset(routineN, handle)

   ! First pass: count the pairs of all kind groups that carry an Allegro potential
   npairs_tot = 0
   DO ilist = 1, nonbonded%nlists
      neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      IF (neighbor_kind_pair%npairs == 0) CYCLE
      Kind_Group_Loop1: DO igrp = 1, neighbor_kind_pair%ngrp_kind
         istart = neighbor_kind_pair%grp_kind_start(igrp)
         iend = neighbor_kind_pair%grp_kind_end(igrp)
         ikind = neighbor_kind_pair%ij_kind(1, igrp)
         jkind = neighbor_kind_pair%ij_kind(2, igrp)
         pot => potparm%pot(ikind, jkind)%pot
         npairs = iend - istart + 1
         IF (pot%no_mb) CYCLE
         DO i = 1, SIZE(pot%type)
            IF (pot%type(i) == allegro_type) npairs_tot = npairs_tot + npairs
         END DO
      END DO Kind_Group_Loop1
   END DO

   ALLOCATE (work_list(npairs_tot))
   ALLOCATE (work_list2(npairs_tot))
   ALLOCATE (glob_loc_list(2, npairs_tot))
   ALLOCATE (glob_cell_v(3, npairs_tot))

   ! Second pass: copy the pairs and the cell shift of the list they belong to
   npairs_tot = 0
   DO ilist = 1, nonbonded%nlists
      neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      IF (neighbor_kind_pair%npairs == 0) CYCLE
      Kind_Group_Loop2: DO igrp = 1, neighbor_kind_pair%ngrp_kind
         istart = neighbor_kind_pair%grp_kind_start(igrp)
         iend = neighbor_kind_pair%grp_kind_end(igrp)
         ikind = neighbor_kind_pair%ij_kind(1, igrp)
         jkind = neighbor_kind_pair%ij_kind(2, igrp)
         list => neighbor_kind_pair%list
         cvi = neighbor_kind_pair%cell_vector
         pot => potparm%pot(ikind, jkind)%pot
         npairs = iend - istart + 1
         IF (pot%no_mb) CYCLE
         cell_v = MATMUL(cell%hmat, cvi)
         DO i = 1, SIZE(pot%type)
            ! ALLEGRO
            IF (pot%type(i) == allegro_type) THEN
               DO ipair = 1, npairs
                  glob_loc_list(:, npairs_tot + ipair) = list(:, istart - 1 + ipair)
                  glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
               END DO
               npairs_tot = npairs_tot + npairs
            END IF
         END DO
      END DO Kind_Group_Loop2
   END DO

   ! Order the arrays w.r.t. the first index of glob_loc_list; work_list receives the
   ! permutation, which is then applied to the second atom index and to the cell shifts
   CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
   ALLOCATE (rwork_list(3, npairs_tot))
   DO ipair = 1, npairs_tot
      work_list2(ipair) = glob_loc_list(2, work_list(ipair))
      rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
   END DO
   glob_loc_list(2, :) = work_list2
   glob_cell_v(:, :) = rwork_list
   DEALLOCATE (work_list2, rwork_list, work_list)

   ALLOCATE (glob_loc_list_a(npairs_tot))
   glob_loc_list_a(:) = glob_loc_list(1, :)

   ! Collapse runs of equal first atoms. Starting the count at zero keeps the empty case
   ! (no Allegro pair at all) inside the array bounds and yields a zero-sized unique_list_a.
   ALLOCATE (temp_unique_list_a(npairs_tot))
   nlocal = 0
   DO ipair = 1, npairs_tot
      IF (ipair > 1) THEN
         IF (glob_loc_list_a(ipair - 1) == glob_loc_list_a(ipair)) CYCLE
      END IF
      nlocal = nlocal + 1
      temp_unique_list_a(nlocal) = glob_loc_list_a(ipair)
   END DO
   ALLOCATE (unique_list_a(nlocal))
   unique_list_a(:) = temp_unique_list_a(:nlocal)
   DEALLOCATE (temp_unique_list_a)

   CALL timestop(handle)
END SUBROUTINE setup_allegro_arrays
170 :
! **************************************************************************************************
!> \brief Releases the pair arrays built by setup_allegro_arrays; pointers that are not
!>        associated are left untouched.
!> \param glob_loc_list pair index array, deallocated if associated
!> \param glob_cell_v cell shift array, deallocated if associated
!> \param glob_loc_list_a first-atom array, deallocated if associated
!> \param unique_list_a unique first-atom array, deallocated if associated
!> \par History
!>      Implementation of the allegro potential - [gtocci] 2023
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
SUBROUTINE destroy_allegro_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a, unique_list_a)
   INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
   INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a, unique_list_a

   ! Each array is released independently, so a partially built set can be passed in
   IF (ASSOCIATED(glob_loc_list)) DEALLOCATE (glob_loc_list)
   IF (ASSOCIATED(glob_loc_list_a)) DEALLOCATE (glob_loc_list_a)
   IF (ASSOCIATED(glob_cell_v)) DEALLOCATE (glob_cell_v)
   IF (ASSOCIATED(unique_list_a)) DEALLOCATE (unique_list_a)

END SUBROUTINE destroy_allegro_arrays
200 :
! **************************************************************************************************
!> \brief Builds the model inputs (positions, edges, cell shifts, cell, atom types) for all atoms
!>        that take part in an Allegro potential, runs the torch model once and stores the
!>        resulting forces (and virial) in the allegro_data of fist_nonbond_env.
!> \param nonbonded neighbor lists from which the edges are generated
!> \param particle_set all particles; only those of kinds with an Allegro potential are used
!> \param cell cell whose hmat supplies the lattice and the Cartesian cell shifts
!> \param atomic_kind_set atomic kinds; its size bounds the scan of the potential table
!> \param potparm pair potential table, addressed by the two kinds of a pair
!> \param allegro on exit: points to the first Allegro entry found in potparm, which provides
!>        the model file name, the unit factors and the precision switch
!> \param glob_loc_list_a first atoms of all Allegro pairs as built by setup_allegro_arrays
!> \param r_last_update_pbc positions used for the cutoff test and passed to the model
!> \param pot_allegro on exit: sum of the per-atom energies of the atoms in unique_list_a,
!>        scaled by unit_energy_val
!> \param fist_nonbond_env environment holding allegro_data (created on first use)
!> \param unique_list_a atoms whose energies are accumulated into pot_allegro
!> \param para_env parallel environment; the virial is divided by its num_pe
!> \param use_virial whether the virial is read from the model output and stored
!> \par History
!>      Implementation of the allegro potential - [gtocci] 2023
!>      Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
                                             potparm, allegro, glob_loc_list_a, r_last_update_pbc, &
                                             pot_allegro, fist_nonbond_env, unique_list_a, para_env, &
                                             use_virial)

   TYPE(fist_neighbor_type), POINTER                  :: nonbonded
   TYPE(particle_type), POINTER                       :: particle_set(:)
   TYPE(cell_type), POINTER                           :: cell
   TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
   TYPE(pair_potential_pp_type), POINTER              :: potparm
   TYPE(allegro_pot_type), POINTER                    :: allegro
   INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
   TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
   REAL(kind=dp)                                      :: pot_allegro
   TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
   INTEGER, DIMENSION(:), POINTER                     :: unique_list_a
   TYPE(mp_para_env_type), POINTER                    :: para_env
   LOGICAL, INTENT(IN)                                :: use_virial

   CHARACTER(LEN=*), PARAMETER :: routineN = 'allegro_energy_store_force_virial'

   INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
              ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, &
              n_atoms_use, nedges, npairs, nunique
   INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:)     :: atom_types
   INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: edge_index, t_edge_index
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: local_index, work_list
   INTEGER, DIMENSION(:, :), POINTER                  :: list, sort_list
   LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: use_atom
   REAL(kind=dp)                                      :: drij, rab2_max
   REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi, rij
   REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts, lattice, &
                                                         new_edge_cell_shifts, pos
   REAL(kind=dp), DIMENSION(:, :), POINTER            :: atomic_energy, forces
   REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: virial3d
   REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: lattice_sp, new_edge_cell_shifts_sp, &
                                                         pos_sp
   REAL(kind=sp), DIMENSION(:, :), POINTER            :: atomic_energy_sp, forces_sp
   TYPE(allegro_data_type), POINTER                   :: allegro_data
   TYPE(allegro_pot_type), POINTER                    :: allegro_ref
   TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
   TYPE(pair_potential_single_type), POINTER          :: pot
   TYPE(torch_dict_type)                              :: inputs, outputs
   TYPE(torch_tensor_type) :: atom_types_tensor, atomic_energy_tensor, forces_tensor, &
                              lattice_tensor, new_edge_cell_shifts_tensor, pos_tensor, &
                              t_edge_index_tensor, virial_tensor

   CALL timeset(routineN, handle)

   NULLIFY (atomic_energy, forces, atomic_energy_sp, forces_sp, virial3d, allegro_ref)
   n_atoms = SIZE(particle_set)
   ALLOCATE (use_atom(n_atoms))
   use_atom(:) = .FALSE.

   ! Flag every atom whose kind appears in a kind pair with an Allegro potential and keep a
   ! reference to the first Allegro parameter set, so that the model file and the unit
   ! factors are always taken from an entry that really is of Allegro type.
   DO ikind = 1, SIZE(atomic_kind_set)
      DO jkind = 1, SIZE(atomic_kind_set)
         pot => potparm%pot(ikind, jkind)%pot
         DO i = 1, SIZE(pot%type)
            IF (pot%type(i) /= allegro_type) CYCLE
            IF (.NOT. ASSOCIATED(allegro_ref)) allegro_ref => pot%set(i)%allegro
            DO iat = 1, n_atoms
               IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
                   particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
            END DO ! iat
         END DO ! i
      END DO ! jkind
   END DO ! ikind
   n_atoms_use = COUNT(use_atom)
   CPASSERT(ASSOCIATED(allegro_ref))
   allegro => allegro_ref

   ! get allegro_data to save force, virial info and to load model
   CALL fist_nonbond_env_get(fist_nonbond_env, allegro_data=allegro_data)
   IF (.NOT. ASSOCIATED(allegro_data)) THEN
      ALLOCATE (allegro_data)
      CALL fist_nonbond_env_set(fist_nonbond_env, allegro_data=allegro_data)
      NULLIFY (allegro_data%use_indices, allegro_data%force)
      CALL torch_model_load(allegro_data%model, allegro%allegro_file_name)
      CALL torch_model_freeze(allegro_data%model)
   END IF
   ! Resize the stored arrays when the number of Allegro atoms has changed
   IF (ASSOCIATED(allegro_data%force)) THEN
      IF (SIZE(allegro_data%force, 2) /= n_atoms_use) THEN
         DEALLOCATE (allegro_data%force, allegro_data%use_indices)
      END IF
   END IF
   IF (.NOT. ASSOCIATED(allegro_data%force)) THEN
      ALLOCATE (allegro_data%force(3, n_atoms_use))
      ALLOCATE (allegro_data%use_indices(n_atoms_use))
   END IF

   ! Two-way mapping between global atom numbers and the compact numbering used for the
   ! model tensors. The scan has to cover all n_atoms, because the flagged atoms need not
   ! be the first n_atoms_use entries of particle_set.
   ALLOCATE (local_index(n_atoms))
   local_index(:) = 0
   iat_use = 0
   DO iat = 1, n_atoms
      IF (.NOT. use_atom(iat)) CYCLE
      iat_use = iat_use + 1
      local_index(iat) = iat_use
      allegro_data%use_indices(iat_use) = iat
   END DO

   nedges = 0
   ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
   ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))

   DO ilist = 1, nonbonded%nlists
      neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      IF (neighbor_kind_pair%npairs == 0) CYCLE
      Kind_Group_Loop_Allegro: DO igrp = 1, neighbor_kind_pair%ngrp_kind
         istart = neighbor_kind_pair%grp_kind_start(igrp)
         iend = neighbor_kind_pair%grp_kind_end(igrp)
         ikind = neighbor_kind_pair%ij_kind(1, igrp)
         jkind = neighbor_kind_pair%ij_kind(2, igrp)
         list => neighbor_kind_pair%list
         cvi = neighbor_kind_pair%cell_vector
         pot => potparm%pot(ikind, jkind)%pot
         DO i = 1, SIZE(pot%type)
            IF (pot%type(i) /= allegro_type) CYCLE
            rab2_max = pot%set(i)%allegro%rcutsq
            cell_v = MATMUL(cell%hmat, cvi)
            npairs = iend - istart + 1
            IF (npairs /= 0) THEN
               ALLOCATE (sort_list(2, npairs), work_list(npairs))
               sort_list = list(:, istart:iend)
               ! Sort the list of neighbors, this increases the efficiency for single
               ! potential contributions
               CALL sort(sort_list(1, :), npairs, work_list)
               DO ipair = 1, npairs
                  work_list(ipair) = sort_list(2, work_list(ipair))
               END DO
               sort_list(2, :) = work_list
               ! find number of unique elements of array index 1
               nunique = 1
               DO ipair = 1, npairs - 1
                  IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
               END DO
               ipair = 1
               junique = sort_list(1, ipair)
               ifirst = 1
               DO iunique = 1, nunique
                  atom_a = junique
                  IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
                  ! ifirst:ilast is the span of atom_a inside glob_loc_list_a; it only
                  ! feeds the guard above for the next unique atom
                  DO mpair = ifirst, SIZE(glob_loc_list_a)
                     IF (glob_loc_list_a(mpair) == atom_a) EXIT
                  END DO
                  ifirst = mpair
                  DO mpair = ifirst, SIZE(glob_loc_list_a)
                     IF (glob_loc_list_a(mpair) /= atom_a) EXIT
                  END DO
                  ilast = mpair - 1
                  DO WHILE (ipair <= npairs)
                     IF (sort_list(1, ipair) /= junique) EXIT
                     atom_b = sort_list(2, ipair)
                     rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
                     drij = DOT_PRODUCT(rij, rij)
                     ipair = ipair + 1
                     IF (drij <= rab2_max) THEN
                        nedges = nedges + 1
                        ! zero-based indices in the compact numbering of the pos tensor
                        edge_index(1, nedges) = INT(local_index(atom_a) - 1, KIND=int_8)
                        edge_index(2, nedges) = INT(local_index(atom_b) - 1, KIND=int_8)
                        edge_cell_shifts(:, nedges) = cvi
                     END IF
                  END DO
                  ifirst = ilast + 1
                  IF (ipair <= npairs) junique = sort_list(1, ipair)
               END DO
               DEALLOCATE (sort_list, work_list)
            END IF
         END DO
      END DO Kind_Group_Loop_Allegro
   END DO

   ! Trim the edge arrays to the number of edges inside the cutoff; the edge index is
   ! handed over transposed, as (nedges, 2)
   ALLOCATE (new_edge_cell_shifts(3, nedges))
   new_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
   ALLOCATE (t_edge_index(nedges, 2))
   t_edge_index(:, :) = TRANSPOSE(edge_index(:, :nedges))
   DEALLOCATE (edge_cell_shifts, edge_index)

   ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
   lattice(:, :) = cell%hmat/allegro%unit_cell_val
   lattice_sp(:, :) = REAL(lattice, kind=sp)

   ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
   DO iat_use = 1, n_atoms_use
      iat = allegro_data%use_indices(iat_use)
      ! Find index of the element based on its position in config.yaml file to have correct mapping
      atom_idx = -1
      DO i = 1, SIZE(allegro%type_names_torch)
         IF (particle_set(iat)%atomic_kind%element_symbol == allegro%type_names_torch(i)) THEN
            atom_idx = i - 1
         END IF
      END DO
      IF (atom_idx < 0) THEN
         CPABORT("Element of an Allegro atom is not listed in the type names of the model")
      END IF
      atom_types(iat_use) = atom_idx
      pos(:, iat_use) = r_last_update_pbc(iat)%r(:)/allegro%unit_coords_val
   END DO

   CALL torch_dict_create(inputs)

   IF (allegro%do_allegro_sp) THEN
      ALLOCATE (new_edge_cell_shifts_sp(3, nedges), pos_sp(3, n_atoms_use))
      new_edge_cell_shifts_sp(:, :) = REAL(new_edge_cell_shifts(:, :), kind=sp)
      pos_sp(:, :) = REAL(pos(:, :), kind=sp)
      DEALLOCATE (pos, new_edge_cell_shifts)
      CALL torch_tensor_from_array(pos_tensor, pos_sp)
      CALL torch_tensor_from_array(new_edge_cell_shifts_tensor, new_edge_cell_shifts_sp)
      CALL torch_tensor_from_array(lattice_tensor, lattice_sp)
   ELSE
      CALL torch_tensor_from_array(pos_tensor, pos)
      CALL torch_tensor_from_array(new_edge_cell_shifts_tensor, new_edge_cell_shifts)
      CALL torch_tensor_from_array(lattice_tensor, lattice)
   END IF

   CALL torch_dict_insert(inputs, "pos", pos_tensor)
   CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts_tensor)
   CALL torch_dict_insert(inputs, "cell", lattice_tensor)
   CALL torch_tensor_release(pos_tensor)
   CALL torch_tensor_release(new_edge_cell_shifts_tensor)
   CALL torch_tensor_release(lattice_tensor)

   CALL torch_tensor_from_array(t_edge_index_tensor, t_edge_index)
   CALL torch_dict_insert(inputs, "edge_index", t_edge_index_tensor)
   CALL torch_tensor_release(t_edge_index_tensor)

   CALL torch_tensor_from_array(atom_types_tensor, atom_types)
   CALL torch_dict_insert(inputs, "atom_types", atom_types_tensor)
   CALL torch_tensor_release(atom_types_tensor)

   CALL torch_dict_create(outputs)
   CALL torch_model_forward(allegro_data%model, inputs, outputs)
   pot_allegro = 0.0_dp

   CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_tensor)
   CALL torch_dict_get(outputs, "forces", forces_tensor)
   IF (allegro%do_allegro_sp) THEN
      CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy_sp)
      CALL torch_tensor_data_ptr(forces_tensor, forces_sp)
      allegro_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*allegro%unit_forces_val
      DO iat_use = 1, SIZE(unique_list_a)
         ! unique_list_a holds global atom numbers, the model output is in compact order
         i = local_index(unique_list_a(iat_use))
         pot_allegro = pot_allegro + REAL(atomic_energy_sp(1, i), kind=dp)*allegro%unit_energy_val
      END DO
      DEALLOCATE (new_edge_cell_shifts_sp, pos_sp)
   ELSE
      CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy)
      CALL torch_tensor_data_ptr(forces_tensor, forces)
      allegro_data%force(:, :) = forces(:, :)*allegro%unit_forces_val
      DO iat_use = 1, SIZE(unique_list_a)
         ! unique_list_a holds global atom numbers, the model output is in compact order
         i = local_index(unique_list_a(iat_use))
         pot_allegro = pot_allegro + atomic_energy(1, i)*allegro%unit_energy_val
      END DO
      DEALLOCATE (pos, new_edge_cell_shifts)
   END IF
   CALL torch_tensor_release(atomic_energy_tensor)
   CALL torch_tensor_release(forces_tensor)

   IF (use_virial) THEN
      ! NOTE(review): the virial is always read through a double precision pointer, also
      ! when do_allegro_sp is set - confirm that the model output matches in that mode
      CALL torch_dict_get(outputs, "virial", virial_tensor)
      CALL torch_tensor_data_ptr(virial_tensor, virial3d)
      allegro_data%virial(:, :) = RESHAPE(virial3d, (/3, 3/))*allegro%unit_energy_val
      CALL torch_tensor_release(virial_tensor)
   END IF

   CALL torch_dict_release(inputs)
   CALL torch_dict_release(outputs)

   DEALLOCATE (t_edge_index, atom_types)

   ! The stored virial is scaled down by the number of processes
   IF (use_virial) allegro_data%virial(:, :) = allegro_data%virial/REAL(para_env%num_pe, dp)
   CALL timestop(handle)
END SUBROUTINE allegro_energy_store_force_virial
496 :
! **************************************************************************************************
!> \brief Adds the forces (and optionally the virial) stored in the allegro_data of
!>        fist_nonbond_env to the nonbonded accumulators.
!> \param fist_nonbond_env environment from which allegro_data is fetched
!> \param f_nonbond force array; column use_indices(k) receives column k of the stored force
!> \param pv_nonbond virial array; incremented by the stored virial when use_virial is set
!> \param use_virial whether the stored virial is added to pv_nonbond
! **************************************************************************************************
SUBROUTINE allegro_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)

   TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
   REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
   LOGICAL, INTENT(IN)                                :: use_virial

   INTEGER                                            :: slot, target_atom
   TYPE(allegro_data_type), POINTER                   :: allegro_data

   CALL fist_nonbond_env_get(fist_nonbond_env, allegro_data=allegro_data)

   IF (use_virial) pv_nonbond = pv_nonbond + allegro_data%virial

   ! Scatter the compactly stored forces back to the atoms they belong to
   DO slot = 1, SIZE(allegro_data%use_indices)
      target_atom = allegro_data%use_indices(slot)
      CPASSERT(target_atom >= 1 .AND. target_atom <= SIZE(f_nonbond, 2))
      f_nonbond(1:3, target_atom) = f_nonbond(1:3, target_atom) + allegro_data%force(1:3, slot)
   END DO

END SUBROUTINE allegro_add_force_virial
526 : END MODULE manybody_allegro
527 :
|