Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \par History
10 : !> nequip implementation
11 : !> \author Gabriele Tocci
12 : ! **************************************************************************************************
13 : MODULE manybody_nequip
14 :
15 : USE atomic_kind_types, ONLY: atomic_kind_type
16 : USE cell_types, ONLY: cell_type
17 : USE fist_neighbor_list_types, ONLY: fist_neighbor_type,&
18 : neighbor_kind_pairs_type
19 : USE fist_nonbond_env_types, ONLY: fist_nonbond_env_get,&
20 : fist_nonbond_env_set,&
21 : fist_nonbond_env_type,&
22 : nequip_data_type,&
23 : pos_type
24 : USE kinds, ONLY: dp,&
25 : int_8,&
26 : sp
27 : USE message_passing, ONLY: mp_para_env_type
28 : USE pair_potential_types, ONLY: nequip_pot_type,&
29 : nequip_type,&
30 : pair_potential_pp_type,&
31 : pair_potential_single_type
32 : USE particle_types, ONLY: particle_type
33 : USE torch_api, ONLY: &
34 : torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
35 : torch_model_forward, torch_model_freeze, torch_model_load, torch_tensor_data_ptr, &
36 : torch_tensor_from_array, torch_tensor_release, torch_tensor_type
37 : USE util, ONLY: sort
38 : #include "./base/base_uses.f90"
39 :
40 : IMPLICIT NONE
41 :
42 : PRIVATE
43 : PUBLIC :: setup_nequip_arrays, destroy_nequip_arrays, &
44 : nequip_energy_store_force_virial, nequip_add_force_virial
45 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_nequip'
46 :
47 : CONTAINS
48 :
! **************************************************************************************************
!> \brief Gathers every neighbor pair whose kind pair carries a NequIP potential into flat
!>        arrays and orders them with respect to the first atom index of each pair.
!> \param nonbonded neighbor lists that are scanned
!> \param potparm pair potential table; a kind group contributes once per nequip_type entry
!>        of its potential, groups flagged no_mb are skipped
!> \param glob_loc_list allocated here as (2, npairs): the collected pairs, sorted w.r.t. row 1;
!>        must not be associated on entry
!> \param glob_cell_v allocated here as (3, npairs): cell%hmat applied to the cell vector of the
!>        neighbor list each pair was taken from, permuted like glob_loc_list;
!>        must not be associated on entry
!> \param glob_loc_list_a allocated here as (npairs): copy of row 1 of the sorted glob_loc_list;
!>        must not be associated on entry
!> \param cell cell whose hmat converts the cell vectors
!> \par History
!>      Implementation of the nequip potential - [gtocci] 2022
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
SUBROUTINE setup_nequip_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, cell)
   TYPE(fist_neighbor_type), POINTER                  :: nonbonded
   TYPE(pair_potential_pp_type), POINTER              :: potparm
   INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
   INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
   TYPE(cell_type), POINTER                           :: cell

   CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_nequip_arrays'

   INTEGER                                            :: first_pair, grp, handle, lst, n_filled, &
                                                         n_grp_pairs, n_total, p, rep
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: perm, second_sorted
   INTEGER, DIMENSION(:, :), POINTER                  :: pair_list
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: shift_sorted
   REAL(KIND=dp), DIMENSION(3)                        :: shift_cart, shift_frac
   TYPE(neighbor_kind_pairs_type), POINTER            :: nkp
   TYPE(pair_potential_single_type), POINTER          :: pot

   ! The output pointers are allocated below, so they have to come in empty
   CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
   CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
   CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
   CALL timeset(routineN, handle)

   ! First pass: count the pairs, a group counts once per nequip entry of its potential
   n_total = 0
   DO lst = 1, nonbonded%nlists
      nkp => nonbonded%neighbor_kind_pairs(lst)
      IF (nkp%npairs == 0) CYCLE
      DO grp = 1, nkp%ngrp_kind
         pot => potparm%pot(nkp%ij_kind(1, grp), nkp%ij_kind(2, grp))%pot
         IF (pot%no_mb) CYCLE
         n_grp_pairs = nkp%grp_kind_end(grp) - nkp%grp_kind_start(grp) + 1
         n_total = n_total + n_grp_pairs*COUNT(pot%type == nequip_type)
      END DO
   END DO

   ALLOCATE (perm(n_total))
   ALLOCATE (second_sorted(n_total))
   ALLOCATE (glob_loc_list(2, n_total))
   ALLOCATE (glob_cell_v(3, n_total))

   ! Second pass: copy the pairs and the matching shift vectors
   n_filled = 0
   DO lst = 1, nonbonded%nlists
      nkp => nonbonded%neighbor_kind_pairs(lst)
      IF (nkp%npairs == 0) CYCLE
      pair_list => nkp%list
      shift_frac = nkp%cell_vector
      DO grp = 1, nkp%ngrp_kind
         pot => potparm%pot(nkp%ij_kind(1, grp), nkp%ij_kind(2, grp))%pot
         IF (pot%no_mb) CYCLE
         first_pair = nkp%grp_kind_start(grp)
         n_grp_pairs = nkp%grp_kind_end(grp) - first_pair + 1
         shift_cart = MATMUL(cell%hmat, shift_frac)
         DO rep = 1, COUNT(pot%type == nequip_type)
            glob_loc_list(:, n_filled + 1:n_filled + n_grp_pairs) = &
               pair_list(:, first_pair:first_pair + n_grp_pairs - 1)
            DO p = 1, n_grp_pairs
               glob_cell_v(1:3, n_filled + p) = shift_cart(1:3)
            END DO
            n_filled = n_filled + n_grp_pairs
         END DO
      END DO
   END DO

   ! Sort row 1 and apply the resulting permutation to row 2 and to the shift vectors
   CALL sort(glob_loc_list(1, :), n_filled, perm)
   second_sorted(1:n_filled) = glob_loc_list(2, perm(1:n_filled))
   glob_loc_list(2, :) = second_sorted
   DEALLOCATE (second_sorted)

   ALLOCATE (shift_sorted(3, n_filled))
   shift_sorted(:, :) = glob_cell_v(:, perm(1:n_filled))
   glob_cell_v = shift_sorted
   DEALLOCATE (shift_sorted)
   DEALLOCATE (perm)

   ! Separate copy of the sorted first indices
   ALLOCATE (glob_loc_list_a(n_filled))
   glob_loc_list_a(:) = glob_loc_list(1, :)

   CALL timestop(handle)
END SUBROUTINE setup_nequip_arrays
155 :
! **************************************************************************************************
!> \brief Releases the three pair arrays; a pointer that is not associated is left untouched.
!> \param glob_loc_list pair index array, deallocated if associated
!> \param glob_cell_v shift vector array, deallocated if associated
!> \param glob_loc_list_a first-index array, deallocated if associated
!> \par History
!>      Implementation of the nequip potential - [gtocci] 2022
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
SUBROUTINE destroy_nequip_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a)
   INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
   INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a

   ! Each array is released independently of the other two
   IF (ASSOCIATED(glob_loc_list)) DEALLOCATE (glob_loc_list)
   IF (ASSOCIATED(glob_loc_list_a)) DEALLOCATE (glob_loc_list_a)
   IF (ASSOCIATED(glob_cell_v)) DEALLOCATE (glob_cell_v)

END SUBROUTINE destroy_nequip_arrays
! **************************************************************************************************
!> \brief Evaluates the NequIP torch model: builds the edge list from the neighbor lists, gathers
!>        it over all MPI ranks, runs the model and stores energy, forces and (optionally) the
!>        virial. Forces and virial are kept in the nequip_data of fist_nonbond_env.
!> \param nonbonded neighbor lists from which the edges are built
!> \param particle_set particles; kind numbers select the atoms handed to the model and the
!>        element symbols are mapped onto nequip%type_names_torch
!> \param cell cell; hmat is used for the pair distances and, scaled by unit_cell_val, as lattice
!> \param atomic_kind_set only its size is used, to loop over all kind pairs
!> \param potparm pair potential table
!> \param nequip pointer that is re-associated inside the routine with the nequip parameters
!>        of the potential being processed; its association on entry is not read
!> \param glob_loc_list_a sorted first atom indices of all nequip pairs
!>        (presumably the array built by setup_nequip_arrays - confirm against the caller)
!> \param r_last_update_pbc positions used for the distance check and as model input
!> \param pot_nequip on exit the model energy times unit_energy_val, divided by the number of ranks
!> \param fist_nonbond_env environment holding nequip_data (model, use_indices, force, virial);
!>        nequip_data is created and the model loaded on the first call
!> \param para_env parallel environment used to gather the edges of all ranks
!> \param use_virial if true the "virial" output of the model is stored as well
!> \par History
!>      Implementation of the nequip potential - [gtocci] 2022
!>      Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
                                            potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
                                            pot_nequip, fist_nonbond_env, para_env, use_virial)

   TYPE(fist_neighbor_type), POINTER                  :: nonbonded
   TYPE(particle_type), POINTER                       :: particle_set(:)
   TYPE(cell_type), POINTER                           :: cell
   TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
   TYPE(pair_potential_pp_type), POINTER              :: potparm
   TYPE(nequip_pot_type), POINTER                     :: nequip
   INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
   TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
   REAL(kind=dp)                                      :: pot_nequip
   TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
   TYPE(mp_para_env_type), POINTER                    :: para_env
   LOGICAL, INTENT(IN)                                :: use_virial

   CHARACTER(LEN=*), PARAMETER :: routineN = 'nequip_energy_store_force_virial'

   INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
      ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, &
      nedges, nedges_tot, npairs, nunique
   INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:)
   INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: edge_index, t_edge_index, temp_edge_index
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: displ, displ_cell, edge_count, &
                                                         edge_count_cell, use_map, work_list
   INTEGER, DIMENSION(:, :), POINTER                  :: list, sort_list
   LOGICAL, ALLOCATABLE                               :: use_atom(:)
   REAL(kind=dp)                                      :: drij, rab2_max, rij(3)
   REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts, lattice, pos, &
                                                         temp_edge_cell_shifts
   REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
   REAL(kind=dp), DIMENSION(:, :), POINTER            :: atomic_energy, forces, total_energy
   REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: virial3d
   REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts_sp, lattice_sp, pos_sp
   REAL(kind=sp), DIMENSION(:, :), POINTER            :: atomic_energy_sp, forces_sp, &
                                                         total_energy_sp
   TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
   TYPE(nequip_data_type), POINTER                    :: nequip_data
   TYPE(pair_potential_single_type), POINTER          :: pot
   TYPE(torch_dict_type)                              :: inputs, outputs
   TYPE(torch_tensor_type) :: atom_types_tensor, atomic_energy_tensor, edge_cell_shifts_tensor, &
      forces_tensor, lattice_tensor, pos_tensor, t_edge_index_tensor, total_energy_tensor, &
      virial_tensor

   CALL timeset(routineN, handle)

   NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp, virial3d)
   n_atoms = SIZE(particle_set)
   ALLOCATE (use_atom(n_atoms))
   use_atom = .FALSE.

   ! Flag every atom whose kind takes part in at least one kind pair with a nequip potential
   DO ikind = 1, SIZE(atomic_kind_set)
      DO jkind = 1, SIZE(atomic_kind_set)
         pot => potparm%pot(ikind, jkind)%pot
         DO i = 1, SIZE(pot%type)
            IF (pot%type(i) /= nequip_type) CYCLE
            DO iat = 1, n_atoms
               IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
                   particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
            END DO ! iat
         END DO ! i
      END DO ! jkind
   END DO ! ikind
   n_atoms_use = COUNT(use_atom)

   ! get nequip_data to save force, virial info and to load model
   CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
   IF (.NOT. ASSOCIATED(nequip_data)) THEN
      ALLOCATE (nequip_data)
      CALL fist_nonbond_env_set(fist_nonbond_env, nequip_data=nequip_data)
      NULLIFY (nequip_data%use_indices, nequip_data%force)
      ! NOTE(review): pot is the potential of the last kind pair visited above; this assumes
      ! that its first set is a nequip set - confirm
      CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
      CALL torch_model_freeze(nequip_data%model)
   END IF
   ! Re-allocate the stored arrays if the number of used atoms changed since the last call
   IF (ASSOCIATED(nequip_data%force)) THEN
      IF (SIZE(nequip_data%force, 2) /= n_atoms_use) THEN
         DEALLOCATE (nequip_data%force, nequip_data%use_indices)
      END IF
   END IF
   IF (.NOT. ASSOCIATED(nequip_data%force)) THEN
      ALLOCATE (nequip_data%force(3, n_atoms_use))
      ALLOCATE (nequip_data%use_indices(n_atoms_use))
   END IF

   ! Index maps between all atoms and the used atoms:
   !   use_indices(compressed index) = global index, use_map(global index) = compressed index.
   ! The scan runs over all n_atoms because use_atom is indexed by the global atom index.
   ALLOCATE (use_map(n_atoms))
   use_map(:) = 0
   iat_use = 0
   DO iat = 1, n_atoms
      IF (.NOT. use_atom(iat)) CYCLE
      iat_use = iat_use + 1
      nequip_data%use_indices(iat_use) = iat
      use_map(iat) = iat_use
   END DO

   ! Build the local edge list: all pairs of nequip kind groups within the cutoff
   nedges = 0
   ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
   ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))
   DO ilist = 1, nonbonded%nlists
      neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      npairs = neighbor_kind_pair%npairs
      IF (npairs == 0) CYCLE
      Kind_Group_Loop_Nequip: DO igrp = 1, neighbor_kind_pair%ngrp_kind
         istart = neighbor_kind_pair%grp_kind_start(igrp)
         iend = neighbor_kind_pair%grp_kind_end(igrp)
         ikind = neighbor_kind_pair%ij_kind(1, igrp)
         jkind = neighbor_kind_pair%ij_kind(2, igrp)
         list => neighbor_kind_pair%list
         cvi = neighbor_kind_pair%cell_vector
         pot => potparm%pot(ikind, jkind)%pot
         DO i = 1, SIZE(pot%type)
            IF (pot%type(i) /= nequip_type) CYCLE
            rab2_max = pot%set(i)%nequip%rcutsq
            cell_v = MATMUL(cell%hmat, cvi)
            nequip => pot%set(i)%nequip
            npairs = iend - istart + 1
            IF (npairs /= 0) THEN
               ALLOCATE (sort_list(2, npairs), work_list(npairs))
               sort_list = list(:, istart:iend)
               ! Sort the list of neighbors, this increases the efficiency for single
               ! potential contributions
               CALL sort(sort_list(1, :), npairs, work_list)
               DO ipair = 1, npairs
                  work_list(ipair) = sort_list(2, work_list(ipair))
               END DO
               sort_list(2, :) = work_list
               ! find number of unique elements of array index 1
               nunique = 1
               DO ipair = 1, npairs - 1
                  IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
               END DO
               ipair = 1
               junique = sort_list(1, ipair)
               ifirst = 1
               DO iunique = 1, nunique
                  atom_a = junique
                  IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
                  ! locate the range ifirst:ilast of atom_a in the global sorted list
                  DO mpair = ifirst, SIZE(glob_loc_list_a)
                     IF (glob_loc_list_a(mpair) == atom_a) EXIT
                  END DO
                  ifirst = mpair
                  DO mpair = ifirst, SIZE(glob_loc_list_a)
                     IF (glob_loc_list_a(mpair) /= atom_a) EXIT
                  END DO
                  ilast = mpair - 1
                  ! walk over all partners of atom_a in this kind group
                  DO WHILE (ipair <= npairs)
                     IF (sort_list(1, ipair) /= junique) EXIT
                     atom_b = sort_list(2, ipair)
                     rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
                     drij = DOT_PRODUCT(rij, rij)
                     ipair = ipair + 1
                     IF (drij <= rab2_max) THEN
                        nedges = nedges + 1
                        ! zero-based indices into the compressed atom list that is passed
                        ! to the model as pos / atom_types
                        CPASSERT(use_map(atom_a) > 0 .AND. use_map(atom_b) > 0)
                        edge_index(:, nedges) = [use_map(atom_a) - 1, use_map(atom_b) - 1]
                        edge_cell_shifts(:, nedges) = cvi
                     END IF
                  END DO
                  ifirst = ilast + 1
                  IF (ipair <= npairs) junique = sort_list(1, ipair)
               END DO
               DEALLOCATE (sort_list, work_list)
            END IF
         END DO
      END DO Kind_Group_Loop_Nequip
   END DO

   ! NOTE(review): takes the first set of the last potential visited - assumes a single
   ! nequip parameter set for the whole system; confirm
   nequip => pot%set(1)%nequip

   ! Gather the edges of all ranks so that every rank holds the complete graph
   ALLOCATE (edge_count(para_env%num_pe))
   ALLOCATE (edge_count_cell(para_env%num_pe))
   ALLOCATE (displ_cell(para_env%num_pe))
   ALLOCATE (displ(para_env%num_pe))

   CALL para_env%allgather(nedges, edge_count)
   nedges_tot = SUM(edge_count)

   ! shrink the local buffers to the number of edges actually found
   ALLOCATE (temp_edge_index(2, nedges))
   temp_edge_index(:, :) = edge_index(:, :nedges)
   DEALLOCATE (edge_index)
   ALLOCATE (temp_edge_cell_shifts(3, nedges))
   temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
   DEALLOCATE (edge_cell_shifts)

   ALLOCATE (edge_index(2, nedges_tot))
   ALLOCATE (edge_cell_shifts(3, nedges_tot))
   ALLOCATE (t_edge_index(nedges_tot, 2))

   ! counts and displacements in numbers of array elements: 3 per shift, 2 per index pair
   edge_count_cell(:) = edge_count*3
   edge_count = edge_count*2
   displ(1) = 0
   displ_cell(1) = 0
   DO ipair = 2, para_env%num_pe
      displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
      displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
   END DO

   CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
   CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)

   t_edge_index(:, :) = TRANSPOSE(edge_index)
   DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)

   ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
   lattice(:, :) = cell%hmat/nequip%unit_cell_val
   lattice_sp(:, :) = REAL(lattice, kind=sp)

   ! Positions and model type indices of the used atoms, in compressed order
   ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))

   DO iat_use = 1, n_atoms_use
      iat = nequip_data%use_indices(iat_use)
      ! Find index of the element based on its position in config.yaml file to have correct mapping
      atom_idx = -1
      DO i = 1, SIZE(nequip%type_names_torch)
         IF (particle_set(iat)%atomic_kind%element_symbol == nequip%type_names_torch(i)) THEN
            atom_idx = i - 1
         END IF
      END DO
      IF (atom_idx < 0) THEN
         CPABORT("NequIP: element symbol not found in type_names_torch")
      END IF
      atom_types(iat_use) = atom_idx
      pos(:, iat_use) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
   END DO

   ! Assemble the model input; the real-valued tensors are built in the precision of the model
   CALL torch_dict_create(inputs)
   IF (nequip%do_nequip_sp) THEN
      ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
      pos_sp(:, :) = REAL(pos(:, :), kind=sp)
      edge_cell_shifts_sp(:, :) = REAL(edge_cell_shifts(:, :), kind=sp)
      CALL torch_tensor_from_array(pos_tensor, pos_sp)
      CALL torch_tensor_from_array(edge_cell_shifts_tensor, edge_cell_shifts_sp)
      CALL torch_tensor_from_array(lattice_tensor, lattice_sp)
   ELSE
      CALL torch_tensor_from_array(pos_tensor, pos)
      CALL torch_tensor_from_array(edge_cell_shifts_tensor, edge_cell_shifts)
      CALL torch_tensor_from_array(lattice_tensor, lattice)
   END IF

   CALL torch_dict_insert(inputs, "pos", pos_tensor)
   CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shifts_tensor)
   CALL torch_dict_insert(inputs, "cell", lattice_tensor)
   CALL torch_tensor_release(pos_tensor)
   CALL torch_tensor_release(edge_cell_shifts_tensor)
   CALL torch_tensor_release(lattice_tensor)

   CALL torch_tensor_from_array(t_edge_index_tensor, t_edge_index)
   CALL torch_dict_insert(inputs, "edge_index", t_edge_index_tensor)
   CALL torch_tensor_release(t_edge_index_tensor)

   CALL torch_tensor_from_array(atom_types_tensor, atom_types)
   CALL torch_dict_insert(inputs, "atom_types", atom_types_tensor)
   CALL torch_tensor_release(atom_types_tensor)

   CALL torch_dict_create(outputs)
   CALL torch_model_forward(nequip_data%model, inputs, outputs)

   ! Fetch the results and convert them with the unit factors of the potential.
   ! The input arrays are released only after the forward pass.
   CALL torch_dict_get(outputs, "total_energy", total_energy_tensor)
   CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_tensor)
   CALL torch_dict_get(outputs, "forces", forces_tensor)
   IF (nequip%do_nequip_sp) THEN
      CALL torch_tensor_data_ptr(total_energy_tensor, total_energy_sp)
      CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy_sp)
      CALL torch_tensor_data_ptr(forces_tensor, forces_sp)
      pot_nequip = REAL(total_energy_sp(1, 1), kind=dp)*nequip%unit_energy_val
      nequip_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*nequip%unit_forces_val
      DEALLOCATE (pos_sp, edge_cell_shifts_sp)
   ELSE
      CALL torch_tensor_data_ptr(total_energy_tensor, total_energy)
      CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy)
      CALL torch_tensor_data_ptr(forces_tensor, forces)
      pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
      nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
      DEALLOCATE (pos, edge_cell_shifts)
   END IF
   CALL torch_tensor_release(total_energy_tensor)
   CALL torch_tensor_release(atomic_energy_tensor)
   CALL torch_tensor_release(forces_tensor)

   IF (use_virial) THEN
      CALL torch_dict_get(outputs, "virial", virial_tensor)
      CALL torch_tensor_data_ptr(virial_tensor, virial3d)
      nequip_data%virial(:, :) = RESHAPE(virial3d, (/3, 3/))*nequip%unit_energy_val
      CALL torch_tensor_release(virial_tensor)
   END IF

   CALL torch_dict_release(inputs)
   CALL torch_dict_release(outputs)

   DEALLOCATE (t_edge_index, atom_types)

   ! account for double counting from multiple MPI processes
   pot_nequip = pot_nequip/REAL(para_env%num_pe, dp)
   nequip_data%force = nequip_data%force/REAL(para_env%num_pe, dp)
   IF (use_virial) nequip_data%virial(:, :) = nequip_data%virial/REAL(para_env%num_pe, dp)

   CALL timestop(handle)
END SUBROUTINE nequip_energy_store_force_virial
496 :
! **************************************************************************************************
!> \brief Adds the forces (and optionally the virial) stored in the nequip_data of
!>        fist_nonbond_env to the nonbonded force and virial arrays.
!> \param fist_nonbond_env environment from which nequip_data is fetched
!> \param f_nonbond forces, second index is the atom; the stored force of entry i is added to
!>        column use_indices(i)
!> \param pv_nonbond virial, incremented by the stored virial if use_virial is set
!> \param use_virial whether the virial is added
! **************************************************************************************************
SUBROUTINE nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)

   TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
   REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
   LOGICAL, INTENT(IN)                                :: use_virial

   INTEGER                                            :: idx, target_atom
   TYPE(nequip_data_type), POINTER                    :: stored

   CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=stored)

   IF (use_virial) pv_nonbond = pv_nonbond + stored%virial

   ! Scatter the stored forces back onto the atoms they belong to
   DO idx = 1, SIZE(stored%use_indices)
      target_atom = stored%use_indices(idx)
      CPASSERT(target_atom >= 1 .AND. target_atom <= SIZE(f_nonbond, 2))
      f_nonbond(1:3, target_atom) = f_nonbond(1:3, target_atom) + stored%force(1:3, idx)
   END DO

END SUBROUTINE nequip_add_force_virial
526 : END MODULE manybody_nequip
527 :
|