LCOV - code coverage report
Current view: top level - src - manybody_allegro.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 256 270 94.8 %
Date: 2024-12-21 06:28:57 Functions: 4 4 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \par History
      10             : !>      allegro implementation
      11             : !> \author Gabriele Tocci
      12             : ! **************************************************************************************************
      13             : MODULE manybody_allegro
      14             : 
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type
      16             :    USE cell_types,                      ONLY: cell_type
      17             :    USE fist_neighbor_list_types,        ONLY: fist_neighbor_type,&
      18             :                                               neighbor_kind_pairs_type
      19             :    USE fist_nonbond_env_types,          ONLY: allegro_data_type,&
      20             :                                               fist_nonbond_env_get,&
      21             :                                               fist_nonbond_env_set,&
      22             :                                               fist_nonbond_env_type,&
      23             :                                               pos_type
      24             :    USE kinds,                           ONLY: dp,&
      25             :                                               int_8,&
      26             :                                               sp
      27             :    USE message_passing,                 ONLY: mp_para_env_type
      28             :    USE pair_potential_types,            ONLY: allegro_pot_type,&
      29             :                                               allegro_type,&
      30             :                                               pair_potential_pp_type,&
      31             :                                               pair_potential_single_type
      32             :    USE particle_types,                  ONLY: particle_type
      33             :    USE torch_api,                       ONLY: torch_dict_create,&
      34             :                                               torch_dict_get,&
      35             :                                               torch_dict_insert,&
      36             :                                               torch_dict_release,&
      37             :                                               torch_dict_type,&
      38             :                                               torch_model_eval,&
      39             :                                               torch_model_freeze,&
      40             :                                               torch_model_load
      41             :    USE util,                            ONLY: sort
      42             : #include "./base/base_uses.f90"
      43             : 
      44             :    IMPLICIT NONE
      45             : 
      46             :    PRIVATE
      47             :    PUBLIC :: setup_allegro_arrays, destroy_allegro_arrays, &
      48             :              allegro_energy_store_force_virial, allegro_add_force_virial
      49             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_allegro'
      50             : 
      51             : CONTAINS
      52             : 
      53             : ! **************************************************************************************************
      54             : !> \brief ...
      55             : !> \param nonbonded ...
      56             : !> \param potparm ...
      57             : !> \param glob_loc_list ...
      58             : !> \param glob_cell_v ...
      59             : !> \param glob_loc_list_a ...
      60             : !> \param unique_list_a ...
      61             : !> \param cell ...
      62             : !> \par History
      63             : !>      Implementation of the allegro potential - [gtocci] 2023
      64             : !> \author Gabriele Tocci - University of Zurich
      65             : ! **************************************************************************************************
      66           4 :    SUBROUTINE setup_allegro_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, &
      67             :                                    unique_list_a, cell)
      68             :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      69             :       TYPE(pair_potential_pp_type), POINTER              :: potparm
      70             :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
      71             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
      72             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a, unique_list_a
      73             :       TYPE(cell_type), POINTER                           :: cell
      74             : 
      75             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_allegro_arrays'
      76             : 
      77             :       INTEGER                                            :: handle, i, iend, igrp, ikind, ilist, &
      78             :                                                             ipair, istart, jkind, nkinds, nlocal, &
      79             :                                                             npairs, npairs_tot
      80           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: temp_unique_list_a, work_list, work_list2
      81           4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list
      82           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: rwork_list
      83             :       REAL(KIND=dp), DIMENSION(3)                        :: cell_v, cvi
      84             :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
      85             :       TYPE(pair_potential_single_type), POINTER          :: pot
      86             : 
      87           0 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
      88           4 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
      89           4 :       CPASSERT(.NOT. ASSOCIATED(unique_list_a))
      90           4 :       CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
      91           4 :       CALL timeset(routineN, handle)
      92           4 :       npairs_tot = 0
      93           4 :       nkinds = SIZE(potparm%pot, 1)
      94         112 :       DO ilist = 1, nonbonded%nlists
      95         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      96         108 :          npairs = neighbor_kind_pair%npairs
      97         108 :          IF (npairs == 0) CYCLE
      98         258 :          Kind_Group_Loop1: DO igrp = 1, neighbor_kind_pair%ngrp_kind
      99         169 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     100         169 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     101         169 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     102         169 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     103         169 :             pot => potparm%pot(ikind, jkind)%pot
     104         169 :             npairs = iend - istart + 1
     105         169 :             IF (pot%no_mb) CYCLE
     106         446 :             DO i = 1, SIZE(pot%type)
     107         338 :                IF (pot%type(i) == allegro_type) npairs_tot = npairs_tot + npairs
     108             :             END DO
     109             :          END DO Kind_Group_Loop1
     110             :       END DO
     111          12 :       ALLOCATE (work_list(npairs_tot))
     112           8 :       ALLOCATE (work_list2(npairs_tot))
     113          12 :       ALLOCATE (glob_loc_list(2, npairs_tot))
     114          12 :       ALLOCATE (glob_cell_v(3, npairs_tot))
     115             :       ! Fill arrays with data
     116           4 :       npairs_tot = 0
     117         112 :       DO ilist = 1, nonbonded%nlists
     118         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     119         108 :          npairs = neighbor_kind_pair%npairs
     120         108 :          IF (npairs == 0) CYCLE
     121         258 :          Kind_Group_Loop2: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     122         169 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     123         169 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     124         169 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     125         169 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     126         169 :             list => neighbor_kind_pair%list
     127         676 :             cvi = neighbor_kind_pair%cell_vector
     128         169 :             pot => potparm%pot(ikind, jkind)%pot
     129         169 :             npairs = iend - istart + 1
     130         169 :             IF (pot%no_mb) CYCLE
     131        2197 :             cell_v = MATMUL(cell%hmat, cvi)
     132         446 :             DO i = 1, SIZE(pot%type)
     133             :                ! ALLEGRO
     134         338 :                IF (pot%type(i) == allegro_type) THEN
     135       10533 :                   DO ipair = 1, npairs
     136       62184 :                      glob_loc_list(:, npairs_tot + ipair) = list(:, istart - 1 + ipair)
     137       41625 :                      glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
     138             :                   END DO
     139         169 :                   npairs_tot = npairs_tot + npairs
     140             :                END IF
     141             :             END DO
     142             :          END DO Kind_Group_Loop2
     143             :       END DO
     144             :       ! Order the arrays w.r.t. the first index of glob_loc_list
     145           4 :       CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
     146       10368 :       DO ipair = 1, npairs_tot
     147       10368 :          work_list2(ipair) = glob_loc_list(2, work_list(ipair))
     148             :       END DO
     149       10368 :       glob_loc_list(2, :) = work_list2
     150           4 :       DEALLOCATE (work_list2)
     151          12 :       ALLOCATE (rwork_list(3, npairs_tot))
     152       10368 :       DO ipair = 1, npairs_tot
     153       41460 :          rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
     154             :       END DO
     155       41460 :       glob_cell_v = rwork_list
     156           4 :       DEALLOCATE (rwork_list)
     157           4 :       DEALLOCATE (work_list)
     158          12 :       ALLOCATE (glob_loc_list_a(npairs_tot))
     159       20736 :       glob_loc_list_a = glob_loc_list(1, :)
     160           8 :       ALLOCATE (temp_unique_list_a(npairs_tot))
     161           4 :       nlocal = 1
     162           4 :       temp_unique_list_a(1) = glob_loc_list_a(1)
     163       10364 :       DO ipair = 2, npairs_tot
     164       10364 :          IF (glob_loc_list_a(ipair - 1) /= glob_loc_list_a(ipair)) THEN
     165         156 :             nlocal = nlocal + 1
     166         156 :             temp_unique_list_a(nlocal) = glob_loc_list_a(ipair)
     167             :          END IF
     168             :       END DO
     169          12 :       ALLOCATE (unique_list_a(nlocal))
     170         164 :       unique_list_a(:) = temp_unique_list_a(:nlocal)
     171           4 :       DEALLOCATE (temp_unique_list_a)
     172           4 :       CALL timestop(handle)
     173           8 :    END SUBROUTINE setup_allegro_arrays
     174             : 
     175             : ! **************************************************************************************************
     176             : !> \brief ...
     177             : !> \param glob_loc_list ...
     178             : !> \param glob_cell_v ...
     179             : !> \param glob_loc_list_a ...
     180             : !> \param unique_list_a ...
     181             : !> \par History
     182             : !>      Implementation of the allegro potential - [gtocci] 2023
     183             : !> \author Gabriele Tocci - University of Zurich
     184             : ! **************************************************************************************************
     185           4 :    SUBROUTINE destroy_allegro_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a, unique_list_a)
     186             :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
     187             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
     188             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a, unique_list_a
     189             : 
     190           4 :       IF (ASSOCIATED(glob_loc_list)) THEN
     191           4 :          DEALLOCATE (glob_loc_list)
     192             :       END IF
     193           4 :       IF (ASSOCIATED(glob_loc_list_a)) THEN
     194           4 :          DEALLOCATE (glob_loc_list_a)
     195             :       END IF
     196           4 :       IF (ASSOCIATED(glob_cell_v)) THEN
     197           4 :          DEALLOCATE (glob_cell_v)
     198             :       END IF
     199           4 :       IF (ASSOCIATED(unique_list_a)) THEN
     200           4 :          DEALLOCATE (unique_list_a)
     201             :       END IF
     202             : 
     203           4 :    END SUBROUTINE destroy_allegro_arrays
     204             : 
     205             : ! **************************************************************************************************
     206             : !> \brief ...
     207             : !> \param nonbonded ...
     208             : !> \param particle_set ...
     209             : !> \param cell ...
     210             : !> \param atomic_kind_set ...
     211             : !> \param potparm ...
     212             : !> \param allegro ...
     213             : !> \param glob_loc_list_a ...
     214             : !> \param r_last_update_pbc ...
     215             : !> \param pot_allegro ...
     216             : !> \param fist_nonbond_env ...
     217             : !> \param unique_list_a ...
     218             : !> \param para_env ...
     219             : !> \param use_virial ...
     220             : !> \par History
     221             : !>      Implementation of the allegro potential - [gtocci] 2023
     222             : !>      Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
     223             : !> \author Gabriele Tocci - University of Zurich
     224             : ! **************************************************************************************************
     225           4 :    SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
     226             :                                                 potparm, allegro, glob_loc_list_a, r_last_update_pbc, &
     227             :                                                 pot_allegro, fist_nonbond_env, unique_list_a, para_env, use_virial)
     228             : 
     229             :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
     230             :       TYPE(particle_type), POINTER                       :: particle_set(:)
     231             :       TYPE(cell_type), POINTER                           :: cell
     232             :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     233             :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     234             :       TYPE(allegro_pot_type), POINTER                    :: allegro
     235             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
     236             :       TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
     237             :       REAL(kind=dp)                                      :: pot_allegro
     238             :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     239             :       INTEGER, DIMENSION(:), POINTER                     :: unique_list_a
     240             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     241             :       LOGICAL, INTENT(IN)                                :: use_virial
     242             : 
     243             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'allegro_energy_store_force_virial'
     244             : 
     245             :       INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
     246             :          ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, &
     247             :          nedges, nloc_size, npairs, nunique
     248           4 :       INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:), temp_atom_types(:)
     249           4 :       INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: edge_index, t_edge_index, temp_edge_index
     250           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: work_list
     251           4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list, sort_list
     252           4 :       LOGICAL, ALLOCATABLE                               :: use_atom(:)
     253             :       REAL(kind=dp)                                      :: drij, rab2_max, rij(3)
     254           4 :       REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts, lattice, &
     255           4 :                                                             new_edge_cell_shifts, pos
     256             :       REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
     257           4 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: atomic_energy, forces, virial
     258           4 :       REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: virial3d
     259           4 :       REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: lattice_sp, new_edge_cell_shifts_sp, &
     260           4 :                                                             pos_sp
     261           4 :       REAL(kind=sp), DIMENSION(:, :), POINTER            :: atomic_energy_sp, forces_sp
     262             :       TYPE(allegro_data_type), POINTER                   :: allegro_data
     263             :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
     264             :       TYPE(pair_potential_single_type), POINTER          :: pot
     265             :       TYPE(torch_dict_type)                              :: inputs, outputs
     266             : 
     267           4 :       CALL timeset(routineN, handle)
     268             : 
     269           4 :       NULLIFY (atomic_energy, forces, atomic_energy_sp, forces_sp, virial3d, virial)
     270           4 :       n_atoms = SIZE(particle_set)
     271          12 :       ALLOCATE (use_atom(n_atoms))
     272         324 :       use_atom = .FALSE.
     273             : 
     274          10 :       DO ikind = 1, SIZE(atomic_kind_set)
     275          20 :          DO jkind = 1, SIZE(atomic_kind_set)
     276          10 :             pot => potparm%pot(ikind, jkind)%pot
     277          26 :             DO i = 1, SIZE(pot%type)
     278          10 :                IF (pot%type(i) /= allegro_type) CYCLE
     279         916 :                DO iat = 1, n_atoms
     280         896 :                   IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
     281         714 :                       particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
     282             :                END DO ! iat
     283             :             END DO ! i
     284             :          END DO ! jkind
     285             :       END DO ! ikind
     286         324 :       n_atoms_use = COUNT(use_atom)
     287             : 
     288             :       ! get allegro_data to save force, virial info and to load model
     289           4 :       CALL fist_nonbond_env_get(fist_nonbond_env, allegro_data=allegro_data)
     290           4 :       IF (.NOT. ASSOCIATED(allegro_data)) THEN
     291          52 :          ALLOCATE (allegro_data)
     292           4 :          CALL fist_nonbond_env_set(fist_nonbond_env, allegro_data=allegro_data)
     293           4 :          NULLIFY (allegro_data%use_indices, allegro_data%force)
     294           4 :          CALL torch_model_load(allegro_data%model, pot%set(1)%allegro%allegro_file_name)
     295           4 :          CALL torch_model_freeze(allegro_data%model)
     296             :       END IF
     297           4 :       IF (ASSOCIATED(allegro_data%force)) THEN
     298           0 :          IF (SIZE(allegro_data%force, 2) /= n_atoms_use) THEN
     299           0 :             DEALLOCATE (allegro_data%force, allegro_data%use_indices)
     300             :          END IF
     301             :       END IF
     302           4 :       IF (.NOT. ASSOCIATED(allegro_data%force)) THEN
     303          12 :          ALLOCATE (allegro_data%force(3, n_atoms_use))
     304          12 :          ALLOCATE (allegro_data%use_indices(n_atoms_use))
     305             :       END IF
     306             : 
     307             :       iat_use = 0
     308         324 :       DO iat = 1, n_atoms_use
     309         324 :          IF (use_atom(iat)) THEN
     310         320 :             iat_use = iat_use + 1
     311         320 :             allegro_data%use_indices(iat_use) = iat
     312             :          END IF
     313             :       END DO
     314             : 
     315           4 :       nedges = 0
     316             : 
     317          12 :       ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
     318          12 :       ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))
     319          12 :       ALLOCATE (temp_atom_types(SIZE(glob_loc_list_a)))
     320             : 
     321         112 :       DO ilist = 1, nonbonded%nlists
     322         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     323         108 :          npairs = neighbor_kind_pair%npairs
     324         108 :          IF (npairs == 0) CYCLE
     325         258 :          Kind_Group_Loop_Allegro: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     326         169 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     327         169 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     328         169 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     329         169 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     330         169 :             list => neighbor_kind_pair%list
     331         676 :             cvi = neighbor_kind_pair%cell_vector
     332         169 :             pot => potparm%pot(ikind, jkind)%pot
     333         446 :             DO i = 1, SIZE(pot%type)
     334         169 :                IF (pot%type(i) /= allegro_type) CYCLE
     335         169 :                rab2_max = pot%set(i)%allegro%rcutsq
     336        2197 :                cell_v = MATMUL(cell%hmat, cvi)
     337         169 :                pot => potparm%pot(ikind, jkind)%pot
     338         169 :                allegro => pot%set(i)%allegro
     339         169 :                npairs = iend - istart + 1
     340         338 :                IF (npairs /= 0) THEN
     341         845 :                   ALLOCATE (sort_list(2, npairs), work_list(npairs))
     342       62353 :                   sort_list = list(:, istart:iend)
     343             :                   ! Sort the list of neighbors, this increases the efficiency for single
     344             :                   ! potential contributions
     345         169 :                   CALL sort(sort_list(1, :), npairs, work_list)
     346       10533 :                   DO ipair = 1, npairs
     347       10533 :                      work_list(ipair) = sort_list(2, work_list(ipair))
     348             :                   END DO
     349       10533 :                   sort_list(2, :) = work_list
     350             :                   ! find number of unique elements of array index 1
     351             :                   nunique = 1
     352       10364 :                   DO ipair = 1, npairs - 1
     353       10364 :                      IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
     354             :                   END DO
     355         169 :                   ipair = 1
     356         169 :                   junique = sort_list(1, ipair)
     357         169 :                   ifirst = 1
     358        1538 :                   DO iunique = 1, nunique
     359        1369 :                      atom_a = junique
     360        1369 :                      IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
     361      360440 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     362      360440 :                         IF (glob_loc_list_a(mpair) == atom_a) EXIT
     363             :                      END DO
     364      106529 :                      ifirst = mpair
     365      106529 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     366      106529 :                         IF (glob_loc_list_a(mpair) /= atom_a) EXIT
     367             :                      END DO
     368        1369 :                      ilast = mpair - 1
     369        1369 :                      nloc_size = 0
     370        1369 :                      IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
     371       11733 :                      DO WHILE (ipair <= npairs)
     372       11564 :                         IF (sort_list(1, ipair) /= junique) EXIT
     373       10364 :                         atom_b = sort_list(2, ipair)
     374       41456 :                         rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
     375       41456 :                         drij = DOT_PRODUCT(rij, rij)
     376       10364 :                         ipair = ipair + 1
     377       11733 :                         IF (drij <= rab2_max) THEN
     378        5998 :                            nedges = nedges + 1
     379       17994 :                            edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
     380       23992 :                            edge_cell_shifts(:, nedges) = cvi
     381             :                         END IF
     382             :                      END DO
     383        1369 :                      ifirst = ilast + 1
     384        1538 :                      IF (ipair <= npairs) junique = sort_list(1, ipair)
     385             :                   END DO
     386         169 :                   DEALLOCATE (sort_list, work_list)
     387             :                END IF
     388             :             END DO
     389             :          END DO Kind_Group_Loop_Allegro
     390             :       END DO
     391             : 
     392           4 :       allegro => pot%set(1)%allegro
     393             : 
     394          12 :       ALLOCATE (temp_edge_index(2, nedges))
     395       17998 :       temp_edge_index(:, :) = edge_index(:, :nedges)
     396          12 :       ALLOCATE (new_edge_cell_shifts(3, nedges))
     397       23996 :       new_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
     398           4 :       DEALLOCATE (edge_cell_shifts)
     399             : 
     400           8 :       ALLOCATE (t_edge_index(nedges, 2))
     401             : 
     402       12008 :       t_edge_index(:, :) = TRANSPOSE(temp_edge_index)
     403           4 :       DEALLOCATE (temp_edge_index, edge_index)
     404           4 :       ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
     405          52 :       lattice(:, :) = cell%hmat/pot%set(1)%allegro%unit_cell_val
     406          52 :       lattice_sp(:, :) = REAL(lattice, kind=sp)
     407           4 :       iat_use = 0
     408          20 :       ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
     409         324 :       DO iat = 1, n_atoms_use
     410         320 :          IF (.NOT. use_atom(iat)) CYCLE
     411         320 :          iat_use = iat_use + 1
     412             :          ! Find index of the element based on its position in config.yaml file to have correct mapping
     413        1024 :          DO i = 1, SIZE(allegro%type_names_torch)
     414        1024 :             IF (particle_set(iat)%atomic_kind%element_symbol == allegro%type_names_torch(i)) THEN
     415         320 :                atom_idx = i - 1
     416             :             END IF
     417             :          END DO
     418         320 :          atom_types(iat_use) = atom_idx
     419        1284 :          pos(:, iat) = r_last_update_pbc(iat)%r(:)/allegro%unit_coords_val
     420             :       END DO
     421             : 
     422           4 :       CALL torch_dict_create(inputs)
     423             : 
     424           4 :       IF (allegro%do_allegro_sp) THEN
     425          10 :          ALLOCATE (new_edge_cell_shifts_sp(3, nedges), pos_sp(3, n_atoms_use))
     426       19898 :          new_edge_cell_shifts_sp(:, :) = REAL(new_edge_cell_shifts(:, :), kind=sp)
     427         770 :          pos_sp(:, :) = REAL(pos(:, :), kind=sp)
     428           2 :          DEALLOCATE (pos, new_edge_cell_shifts)
     429           2 :          CALL torch_dict_insert(inputs, "pos", pos_sp)
     430           2 :          CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts_sp)
     431           2 :          CALL torch_dict_insert(inputs, "cell", lattice_sp)
     432             :       ELSE
     433           2 :          CALL torch_dict_insert(inputs, "pos", pos)
     434           2 :          CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts)
     435           2 :          CALL torch_dict_insert(inputs, "cell", lattice)
     436             :       END IF
     437           4 :       CALL torch_dict_insert(inputs, "edge_index", t_edge_index)
     438           4 :       CALL torch_dict_insert(inputs, "atom_types", atom_types)
     439           4 :       CALL torch_dict_create(outputs)
     440           4 :       CALL torch_model_eval(allegro_data%model, inputs, outputs)
     441           4 :       pot_allegro = 0.0_dp
     442             : 
     443           4 :       IF (allegro%do_allegro_sp) THEN
     444           2 :          CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_sp)
     445           2 :          CALL torch_dict_get(outputs, "forces", forces_sp)
     446           2 :          IF (use_virial) THEN
     447           0 :             ALLOCATE (virial(3, 3))
     448           0 :             CALL torch_dict_get(outputs, "virial", virial3d)
     449           0 :             virial = RESHAPE(virial3d, (/3, 3/))
     450           0 :             allegro_data%virial(:, :) = virial(:, :)*allegro%unit_energy_val
     451           0 :             DEALLOCATE (virial, virial3d)
     452             :          END IF
     453         770 :          allegro_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*allegro%unit_forces_val
     454          98 :          DO iat_use = 1, SIZE(unique_list_a)
     455          96 :             i = unique_list_a(iat_use)
     456          98 :             pot_allegro = pot_allegro + REAL(atomic_energy_sp(1, i), kind=dp)*allegro%unit_energy_val
     457             :          END DO
     458           2 :          DEALLOCATE (forces_sp, atomic_energy_sp, new_edge_cell_shifts_sp, pos_sp)
     459             :       ELSE
     460           2 :          CALL torch_dict_get(outputs, "atomic_energy", atomic_energy)
     461           2 :          CALL torch_dict_get(outputs, "forces", forces)
     462           2 :          IF (use_virial) THEN
     463           0 :             ALLOCATE (virial(3, 3))
     464           0 :             CALL torch_dict_get(outputs, "virial", virial3d)
     465           0 :             virial = RESHAPE(virial3d, (/3, 3/))
     466           0 :             allegro_data%virial(:, :) = virial(:, :)*allegro%unit_energy_val
     467           0 :             DEALLOCATE (virial, virial3d)
     468             :          END IF
     469        1026 :          allegro_data%force(:, :) = forces(:, :)*allegro%unit_forces_val
     470          66 :          DO iat_use = 1, SIZE(unique_list_a)
     471          64 :             i = unique_list_a(iat_use)
     472          66 :             pot_allegro = pot_allegro + atomic_energy(1, i)*allegro%unit_energy_val
     473             :          END DO
     474           2 :          DEALLOCATE (forces, atomic_energy, pos, new_edge_cell_shifts)
     475             :       END IF
     476             : 
     477           4 :       CALL torch_dict_release(inputs)
     478           4 :       CALL torch_dict_release(outputs)
     479             : 
     480           4 :       DEALLOCATE (t_edge_index, atom_types)
     481             : 
     482           4 :       IF (use_virial) allegro_data%virial(:, :) = allegro_data%virial/REAL(para_env%num_pe, dp)
     483           4 :       CALL timestop(handle)
     484           8 :    END SUBROUTINE allegro_energy_store_force_virial
     485             : 
     ! **************************************************************************************************
     !> \brief Accumulates the Allegro results held in the nonbond environment into the caller's
     !>        force array and, when requested, into the caller's virial tensor.
     !> \param fist_nonbond_env nonbond environment; its allegro_data component is read here
     !> \param f_nonbond force accumulator indexed as (1:3, atom); column use_indices(k) receives
     !>        allegro_data%force(1:3, k)
     !> \param pv_nonbond virial accumulator; allegro_data%virial is added to it when use_virial is set
     !> \param use_virial if .TRUE., the stored Allegro virial is added to pv_nonbond
     ! **************************************************************************************************
   SUBROUTINE allegro_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)

      TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
      LOGICAL, INTENT(IN)                                :: use_virial

      INTEGER                                            :: atom_idx, k, n_used
      TYPE(allegro_data_type), POINTER                   :: allegro_data

      CALL fist_nonbond_env_get(fist_nonbond_env, allegro_data=allegro_data)

      ! The stored forces are packed over the atoms handled by Allegro; use_indices maps each
      ! packed column k back to its position in the caller's force array.
      n_used = SIZE(allegro_data%use_indices)
      DO k = 1, n_used
         atom_idx = allegro_data%use_indices(k)
         ! Refuse to scatter into a column that does not exist in f_nonbond.
         CPASSERT(atom_idx >= 1 .AND. atom_idx <= SIZE(f_nonbond, 2))
         f_nonbond(1:3, atom_idx) = f_nonbond(1:3, atom_idx) + allegro_data%force(1:3, k)
      END DO

      ! The virial contribution is independent of the force scatter above.
      IF (use_virial) pv_nonbond = pv_nonbond + allegro_data%virial

   END SUBROUTINE allegro_add_force_virial
     515             : END MODULE manybody_allegro
     516             : 

Generated by: LCOV version 1.15