LCOV - code coverage report
Current view: top level - src - manybody_nequip.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b8e0b09) Lines: 265 273 97.1 %
Date: 2024-08-31 06:31:37 Functions: 4 4 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \par History
      10             : !>      nequip implementation
      11             : !> \author Gabriele Tocci
      12             : ! **************************************************************************************************
      13             : MODULE manybody_nequip
      14             : 
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type
      16             :    USE cell_types,                      ONLY: cell_type
      17             :    USE fist_neighbor_list_types,        ONLY: fist_neighbor_type,&
      18             :                                               neighbor_kind_pairs_type
      19             :    USE fist_nonbond_env_types,          ONLY: fist_nonbond_env_get,&
      20             :                                               fist_nonbond_env_set,&
      21             :                                               fist_nonbond_env_type,&
      22             :                                               nequip_data_type,&
      23             :                                               pos_type
      24             :    USE kinds,                           ONLY: dp,&
      25             :                                               int_8,&
      26             :                                               sp
      27             :    USE message_passing,                 ONLY: mp_para_env_type
      28             :    USE pair_potential_types,            ONLY: nequip_pot_type,&
      29             :                                               nequip_type,&
      30             :                                               pair_potential_pp_type,&
      31             :                                               pair_potential_single_type
      32             :    USE particle_types,                  ONLY: particle_type
      33             :    USE torch_api,                       ONLY: torch_dict_create,&
      34             :                                               torch_dict_get,&
      35             :                                               torch_dict_insert,&
      36             :                                               torch_dict_release,&
      37             :                                               torch_dict_type,&
      38             :                                               torch_model_eval,&
      39             :                                               torch_model_freeze,&
      40             :                                               torch_model_load
      41             :    USE util,                            ONLY: sort
      42             : #include "./base/base_uses.f90"
      43             : 
      44             :    IMPLICIT NONE
      45             : 
      46             :    PRIVATE
      47             :    PUBLIC :: setup_nequip_arrays, destroy_nequip_arrays, &
      48             :              nequip_energy_store_force_virial, nequip_add_force_virial
      49             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_nequip'
      50             : 
      51             : CONTAINS
      52             : 
      53             : ! **************************************************************************************************
      54             : !> \brief ...
      55             : !> \param nonbonded ...
      56             : !> \param potparm ...
      57             : !> \param glob_loc_list ...
      58             : !> \param glob_cell_v ...
      59             : !> \param glob_loc_list_a ...
      60             : !> \param cell ...
      61             : !> \par History
      62             : !>      Implementation of the nequip potential - [gtocci] 2022
      63             : !> \author Gabriele Tocci - University of Zurich
      64             : ! **************************************************************************************************
      65           4 :    SUBROUTINE setup_nequip_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, cell)
      66             :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      67             :       TYPE(pair_potential_pp_type), POINTER              :: potparm
      68             :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
      69             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
      70             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
      71             :       TYPE(cell_type), POINTER                           :: cell
      72             : 
      73             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_nequip_arrays'
      74             : 
      75             :       INTEGER                                            :: handle, i, iend, igrp, ikind, ilist, &
      76             :                                                             ipair, istart, jkind, nkinds, npairs, &
      77             :                                                             npairs_tot
      78           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: work_list, work_list2
      79           4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list
      80           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: rwork_list
      81             :       REAL(KIND=dp), DIMENSION(3)                        :: cell_v, cvi
      82             :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
      83             :       TYPE(pair_potential_single_type), POINTER          :: pot
      84             : 
      85           0 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
      86           4 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
      87           4 :       CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
      88           4 :       CALL timeset(routineN, handle)
      89           4 :       npairs_tot = 0
      90           4 :       nkinds = SIZE(potparm%pot, 1)
      91         112 :       DO ilist = 1, nonbonded%nlists
      92         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      93         108 :          npairs = neighbor_kind_pair%npairs
      94         108 :          IF (npairs == 0) CYCLE
      95         163 :          Kind_Group_Loop1: DO igrp = 1, neighbor_kind_pair%ngrp_kind
      96         116 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
      97         116 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
      98         116 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
      99         116 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     100         116 :             pot => potparm%pot(ikind, jkind)%pot
     101         116 :             npairs = iend - istart + 1
     102         116 :             IF (pot%no_mb) CYCLE
     103         340 :             DO i = 1, SIZE(pot%type)
     104         232 :                IF (pot%type(i) == nequip_type) npairs_tot = npairs_tot + npairs
     105             :             END DO
     106             :          END DO Kind_Group_Loop1
     107             :       END DO
     108          12 :       ALLOCATE (work_list(npairs_tot))
     109           8 :       ALLOCATE (work_list2(npairs_tot))
     110          12 :       ALLOCATE (glob_loc_list(2, npairs_tot))
     111          12 :       ALLOCATE (glob_cell_v(3, npairs_tot))
     112             :       ! Fill arrays with data
     113           4 :       npairs_tot = 0
     114         112 :       DO ilist = 1, nonbonded%nlists
     115         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     116         108 :          npairs = neighbor_kind_pair%npairs
     117         108 :          IF (npairs == 0) CYCLE
     118         163 :          Kind_Group_Loop2: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     119         116 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     120         116 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     121         116 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     122         116 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     123         116 :             list => neighbor_kind_pair%list
     124         464 :             cvi = neighbor_kind_pair%cell_vector
     125         116 :             pot => potparm%pot(ikind, jkind)%pot
     126         116 :             npairs = iend - istart + 1
     127         116 :             IF (pot%no_mb) CYCLE
     128        1508 :             cell_v = MATMUL(cell%hmat, cvi)
     129         340 :             DO i = 1, SIZE(pot%type)
     130             :                ! NEQUIP
     131         232 :                IF (pot%type(i) == nequip_type) THEN
     132        5096 :                   DO ipair = 1, npairs
     133       29880 :                      glob_loc_list(:, npairs_tot + ipair) = list(:, istart - 1 + ipair)
     134       20036 :                      glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
     135             :                   END DO
     136         116 :                   npairs_tot = npairs_tot + npairs
     137             :                END IF
     138             :             END DO
     139             :          END DO Kind_Group_Loop2
     140             :       END DO
     141             :       ! Order the arrays w.r.t. the first index of glob_loc_list
     142           4 :       CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
     143        4984 :       DO ipair = 1, npairs_tot
     144        4984 :          work_list2(ipair) = glob_loc_list(2, work_list(ipair))
     145             :       END DO
     146        4984 :       glob_loc_list(2, :) = work_list2
     147           4 :       DEALLOCATE (work_list2)
     148          12 :       ALLOCATE (rwork_list(3, npairs_tot))
     149        4984 :       DO ipair = 1, npairs_tot
     150       19924 :          rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
     151             :       END DO
     152       19924 :       glob_cell_v = rwork_list
     153           4 :       DEALLOCATE (rwork_list)
     154           4 :       DEALLOCATE (work_list)
     155          12 :       ALLOCATE (glob_loc_list_a(npairs_tot))
     156        9968 :       glob_loc_list_a = glob_loc_list(1, :)
     157           4 :       CALL timestop(handle)
     158           8 :    END SUBROUTINE setup_nequip_arrays
     159             : 
     160             : ! **************************************************************************************************
     161             : !> \brief ...
     162             : !> \param glob_loc_list ...
     163             : !> \param glob_cell_v ...
     164             : !> \param glob_loc_list_a ...
     165             : !> \par History
     166             : !>      Implementation of the nequip potential - [gtocci] 2022
     167             : !> \author Gabriele Tocci - University of Zurich
     168             : ! **************************************************************************************************
     169           4 :    SUBROUTINE destroy_nequip_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a)
     170             :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
     171             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
     172             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
     173             : 
     174           4 :       IF (ASSOCIATED(glob_loc_list)) THEN
     175           4 :          DEALLOCATE (glob_loc_list)
     176             :       END IF
     177           4 :       IF (ASSOCIATED(glob_loc_list_a)) THEN
     178           4 :          DEALLOCATE (glob_loc_list_a)
     179             :       END IF
     180           4 :       IF (ASSOCIATED(glob_cell_v)) THEN
     181           4 :          DEALLOCATE (glob_cell_v)
     182             :       END IF
     183             : 
     184           4 :    END SUBROUTINE destroy_nequip_arrays
     185             : ! **************************************************************************************************
     186             : !> \brief ...
     187             : !> \param nonbonded ...
     188             : !> \param particle_set ...
     189             : !> \param cell ...
     190             : !> \param atomic_kind_set ...
     191             : !> \param potparm ...
     192             : !> \param nequip ...
     193             : !> \param glob_loc_list_a ...
     194             : !> \param r_last_update_pbc ...
     195             : !> \param pot_nequip ...
     196             : !> \param fist_nonbond_env ...
     197             : !> \param para_env ...
     198             : !> \param use_virial ...
     199             : !> \par History
     200             : !>      Implementation of the nequip potential - [gtocci] 2022
     201             : !>      Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
     202             : !> \author Gabriele Tocci - University of Zurich
     203             : ! **************************************************************************************************
     204           4 :    SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
     205             :                                                potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
     206             :                                                pot_nequip, fist_nonbond_env, para_env, use_virial)
     207             : 
     208             :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
     209             :       TYPE(particle_type), POINTER                       :: particle_set(:)
     210             :       TYPE(cell_type), POINTER                           :: cell
     211             :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     212             :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     213             :       TYPE(nequip_pot_type), POINTER                     :: nequip
     214             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
     215             :       TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
     216             :       REAL(kind=dp)                                      :: pot_nequip
     217             :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     218             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     219             :       LOGICAL, INTENT(IN)                                :: use_virial
     220             : 
     221             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'nequip_energy_store_force_virial'
     222             : 
     223             :       INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
     224             :          ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, &
     225             :          nedges, nedges_tot, nloc_size, npairs, nunique
     226           4 :       INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:)
     227           4 :       INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: edge_index, t_edge_index, temp_edge_index
     228           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: displ, displ_cell, edge_count, &
     229           4 :                                                             edge_count_cell, work_list
     230           4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list, sort_list
     231           4 :       LOGICAL, ALLOCATABLE                               :: use_atom(:)
     232             :       REAL(kind=dp)                                      :: drij, lattice(3, 3), rab2_max, rij(3)
     233           4 :       REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts, pos, &
     234           4 :                                                             temp_edge_cell_shifts
     235             :       REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
     236           4 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: atomic_energy, forces, total_energy, &
     237           4 :                                                             virial
     238           4 :       REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: virial3d
     239             :       REAL(kind=sp)                                      :: lattice_sp(3, 3)
     240           4 :       REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts_sp, pos_sp
     241           4 :       REAL(kind=sp), DIMENSION(:, :), POINTER            :: atomic_energy_sp, forces_sp, &
     242           4 :                                                             total_energy_sp
     243             :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
     244             :       TYPE(nequip_data_type), POINTER                    :: nequip_data
     245             :       TYPE(pair_potential_single_type), POINTER          :: pot
     246             :       TYPE(torch_dict_type)                              :: inputs, outputs
     247             : 
     248           4 :       CALL timeset(routineN, handle)
     249             : 
     250           4 :       NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp, virial3d, virial)
     251           4 :       n_atoms = SIZE(particle_set)
     252          12 :       ALLOCATE (use_atom(n_atoms))
     253         202 :       use_atom = .FALSE.
     254             : 
     255          12 :       DO ikind = 1, SIZE(atomic_kind_set)
     256          28 :          DO jkind = 1, SIZE(atomic_kind_set)
     257          16 :             pot => potparm%pot(ikind, jkind)%pot
     258          40 :             DO i = 1, SIZE(pot%type)
     259          16 :                IF (pot%type(i) /= nequip_type) CYCLE
     260         824 :                DO iat = 1, n_atoms
     261         792 :                   IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
     262         610 :                       particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
     263             :                END DO ! iat
     264             :             END DO ! i
     265             :          END DO ! jkind
     266             :       END DO ! ikind
     267         202 :       n_atoms_use = COUNT(use_atom)
     268             : 
     269             :       ! get nequip_data to save force, virial info and to load model
     270           4 :       CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
     271           4 :       IF (.NOT. ASSOCIATED(nequip_data)) THEN
     272          52 :          ALLOCATE (nequip_data)
     273           4 :          CALL fist_nonbond_env_set(fist_nonbond_env, nequip_data=nequip_data)
     274           4 :          NULLIFY (nequip_data%use_indices, nequip_data%force)
     275           4 :          CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
     276           4 :          CALL torch_model_freeze(nequip_data%model)
     277             :       END IF
     278           4 :       IF (ASSOCIATED(nequip_data%force)) THEN
     279           0 :          IF (SIZE(nequip_data%force, 2) /= n_atoms_use) THEN
     280           0 :             DEALLOCATE (nequip_data%force, nequip_data%use_indices)
     281             :          END IF
     282             :       END IF
     283           4 :       IF (.NOT. ASSOCIATED(nequip_data%force)) THEN
     284          12 :          ALLOCATE (nequip_data%force(3, n_atoms_use))
     285          12 :          ALLOCATE (nequip_data%use_indices(n_atoms_use))
     286             :       END IF
     287             : 
     288             :       iat_use = 0
     289         202 :       DO iat = 1, n_atoms_use
     290         202 :          IF (use_atom(iat)) THEN
     291         198 :             iat_use = iat_use + 1
     292         198 :             nequip_data%use_indices(iat_use) = iat
     293             :          END IF
     294             :       END DO
     295             : 
     296           4 :       nedges = 0
     297          12 :       ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
     298          12 :       ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))
     299         112 :       DO ilist = 1, nonbonded%nlists
     300         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     301         108 :          npairs = neighbor_kind_pair%npairs
     302         108 :          IF (npairs == 0) CYCLE
     303         163 :          Kind_Group_Loop_Nequip: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     304         116 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     305         116 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     306         116 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     307         116 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     308         116 :             list => neighbor_kind_pair%list
     309         464 :             cvi = neighbor_kind_pair%cell_vector
     310         116 :             pot => potparm%pot(ikind, jkind)%pot
     311         340 :             DO i = 1, SIZE(pot%type)
     312         116 :                IF (pot%type(i) /= nequip_type) CYCLE
     313         116 :                rab2_max = pot%set(i)%nequip%rcutsq
     314        1508 :                cell_v = MATMUL(cell%hmat, cvi)
     315         116 :                pot => potparm%pot(ikind, jkind)%pot
     316         116 :                nequip => pot%set(i)%nequip
     317         116 :                npairs = iend - istart + 1
     318         232 :                IF (npairs /= 0) THEN
     319         580 :                   ALLOCATE (sort_list(2, npairs), work_list(npairs))
     320       29996 :                   sort_list = list(:, istart:iend)
     321             :                   ! Sort the list of neighbors, this increases the efficiency for single
     322             :                   ! potential contributions
     323         116 :                   CALL sort(sort_list(1, :), npairs, work_list)
     324        5096 :                   DO ipair = 1, npairs
     325        5096 :                      work_list(ipair) = sort_list(2, work_list(ipair))
     326             :                   END DO
     327        5096 :                   sort_list(2, :) = work_list
     328             :                   ! find number of unique elements of array index 1
     329             :                   nunique = 1
     330        4980 :                   DO ipair = 1, npairs - 1
     331        4980 :                      IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
     332             :                   END DO
     333         116 :                   ipair = 1
     334         116 :                   junique = sort_list(1, ipair)
     335         116 :                   ifirst = 1
     336         915 :                   DO iunique = 1, nunique
     337         799 :                      atom_a = junique
     338         799 :                      IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
     339      171934 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     340      171934 :                         IF (glob_loc_list_a(mpair) == atom_a) EXIT
     341             :                      END DO
     342       41700 :                      ifirst = mpair
     343       41700 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     344       41700 :                         IF (glob_loc_list_a(mpair) /= atom_a) EXIT
     345             :                      END DO
     346         799 :                      ilast = mpair - 1
     347         799 :                      nloc_size = 0
     348         799 :                      IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
     349        5779 :                      DO WHILE (ipair <= npairs)
     350        5663 :                         IF (sort_list(1, ipair) /= junique) EXIT
     351        4980 :                         atom_b = sort_list(2, ipair)
     352       19920 :                         rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
     353       19920 :                         drij = DOT_PRODUCT(rij, rij)
     354        4980 :                         ipair = ipair + 1
     355        5779 :                         IF (drij <= rab2_max) THEN
     356        2576 :                            nedges = nedges + 1
     357        7728 :                            edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
     358       10304 :                            edge_cell_shifts(:, nedges) = cvi
     359             :                         END IF
     360             :                      END DO
     361         799 :                      ifirst = ilast + 1
     362         915 :                      IF (ipair <= npairs) junique = sort_list(1, ipair)
     363             :                   END DO
     364         116 :                   DEALLOCATE (sort_list, work_list)
     365             :                END IF
     366             :             END DO
     367             :          END DO Kind_Group_Loop_Nequip
     368             :       END DO
     369             : 
     370           4 :       nequip => pot%set(1)%nequip
     371             : 
     372          12 :       ALLOCATE (edge_count(para_env%num_pe))
     373           8 :       ALLOCATE (edge_count_cell(para_env%num_pe))
     374           8 :       ALLOCATE (displ_cell(para_env%num_pe))
     375           8 :       ALLOCATE (displ(para_env%num_pe))
     376             : 
     377           4 :       CALL para_env%allgather(nedges, edge_count)
     378          12 :       nedges_tot = SUM(edge_count)
     379             : 
     380          12 :       ALLOCATE (temp_edge_index(2, nedges))
     381        7732 :       temp_edge_index(:, :) = edge_index(:, :nedges)
     382           4 :       DEALLOCATE (edge_index)
     383          12 :       ALLOCATE (temp_edge_cell_shifts(3, nedges))
     384       10308 :       temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
     385           4 :       DEALLOCATE (edge_cell_shifts)
     386             : 
     387          12 :       ALLOCATE (edge_index(2, nedges_tot))
     388          12 :       ALLOCATE (edge_cell_shifts(3, nedges_tot))
     389           8 :       ALLOCATE (t_edge_index(nedges_tot, 2))
     390             : 
     391          12 :       edge_count_cell(:) = edge_count*3
     392          12 :       edge_count = edge_count*2
     393           4 :       displ(1) = 0
     394           4 :       displ_cell(1) = 0
     395           8 :       DO ipair = 2, para_env%num_pe
     396           4 :          displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
     397           8 :          displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
     398             :       END DO
     399             : 
     400           4 :       CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
     401           4 :       CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)
     402             : 
     403       10316 :       t_edge_index(:, :) = TRANSPOSE(edge_index)
     404           4 :       DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)
     405             : 
     406          52 :       lattice = cell%hmat/nequip%unit_cell_val
     407          52 :       lattice_sp = REAL(lattice, kind=sp)
     408             : 
     409           4 :       iat_use = 0
     410          20 :       ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
     411             : 
     412         202 :       DO iat = 1, n_atoms_use
     413         198 :          IF (.NOT. use_atom(iat)) CYCLE
     414         198 :          iat_use = iat_use + 1
     415             :          ! Find index of the element based on its position in config.yaml file to have correct mapping
     416         594 :          DO i = 1, SIZE(nequip%type_names_torch)
     417         594 :             IF (particle_set(iat)%atomic_kind%element_symbol == nequip%type_names_torch(i)) THEN
     418         198 :                atom_idx = i - 1
     419             :             END IF
     420             :          END DO
     421         198 :          atom_types(iat_use) = atom_idx
     422         796 :          pos(:, iat) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
     423             :       END DO
     424             : 
     425           4 :       CALL torch_dict_create(inputs)
     426           4 :       IF (nequip%do_nequip_sp) THEN
     427          10 :          ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
     428          26 :          pos_sp(:, :) = REAL(pos(:, :), kind=sp)
     429          50 :          edge_cell_shifts_sp(:, :) = REAL(edge_cell_shifts(:, :), kind=sp)
     430           2 :          CALL torch_dict_insert(inputs, "pos", pos_sp)
     431           2 :          CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shifts_sp)
     432           2 :          CALL torch_dict_insert(inputs, "cell", lattice_sp)
     433             :       ELSE
     434           2 :          CALL torch_dict_insert(inputs, "pos", pos)
     435           2 :          CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shifts)
     436           2 :          CALL torch_dict_insert(inputs, "cell", lattice)
     437             :       END IF
     438             : 
     439           4 :       CALL torch_dict_insert(inputs, "edge_index", t_edge_index)
     440           4 :       CALL torch_dict_insert(inputs, "atom_types", atom_types)
     441             : 
     442           4 :       CALL torch_dict_create(outputs)
     443             : 
     444           4 :       CALL torch_model_eval(nequip_data%model, inputs, outputs)
     445             : 
     446           4 :       IF (nequip%do_nequip_sp) THEN
     447           2 :          CALL torch_dict_get(outputs, "total_energy", total_energy_sp)
     448           2 :          CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_sp)
     449           2 :          CALL torch_dict_get(outputs, "forces", forces_sp)
     450           2 :          IF (use_virial) THEN
     451           0 :             ALLOCATE (virial(3, 3))
     452           0 :             CALL torch_dict_get(outputs, "virial", virial3d)
     453           0 :             virial = RESHAPE(virial3d, (/3, 3/))
     454           0 :             nequip_data%virial(:, :) = virial(:, :)*nequip%unit_energy_val
     455           0 :             DEALLOCATE (virial, virial3d)
     456             :          END IF
     457           2 :          pot_nequip = REAL(total_energy_sp(1, 1), kind=dp)*nequip%unit_energy_val
     458          26 :          nequip_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*nequip%unit_forces_val
     459           2 :          DEALLOCATE (pos_sp, edge_cell_shifts_sp, total_energy_sp, atomic_energy_sp, forces_sp)
     460             :       ELSE
     461           2 :          CALL torch_dict_get(outputs, "total_energy", total_energy)
     462           2 :          CALL torch_dict_get(outputs, "atomic_energy", atomic_energy)
     463           2 :          CALL torch_dict_get(outputs, "forces", forces)
     464           2 :          IF (use_virial) THEN
     465           2 :             ALLOCATE (virial(3, 3))
     466           2 :             CALL torch_dict_get(outputs, "virial", virial3d)
     467          26 :             virial = RESHAPE(virial3d, (/3, 3/))
     468          50 :             nequip_data%virial(:, :) = virial(:, :)*nequip%unit_energy_val
     469           2 :             DEALLOCATE (virial, virial3d)
     470             :          END IF
     471           2 :          pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
     472        1538 :          nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
     473           2 :          DEALLOCATE (pos, edge_cell_shifts, total_energy, atomic_energy, forces)
     474             :       END IF
     475             : 
     476           4 :       CALL torch_dict_release(inputs)
     477           4 :       CALL torch_dict_release(outputs)
     478             : 
     479           4 :       DEALLOCATE (t_edge_index, atom_types)
     480             : 
     481             :       ! account for double counting from multiple MPI processes
     482           4 :       pot_nequip = pot_nequip/REAL(para_env%num_pe, dp)
     483         796 :       nequip_data%force = nequip_data%force/REAL(para_env%num_pe, dp)
     484          28 :       IF (use_virial) nequip_data%virial(:, :) = nequip_data%virial/REAL(para_env%num_pe, dp)
     485             : 
     486           4 :       CALL timestop(handle)
     487           8 :    END SUBROUTINE nequip_energy_store_force_virial
     488             : 
     489             : ! **************************************************************************************************
     490             : !> \brief ...
     491             : !> \param fist_nonbond_env ...
     492             : !> \param f_nonbond ...
     493             : !> \param pv_nonbond ...
     494             : !> \param use_virial ...
     495             : ! **************************************************************************************************
     496           4 :    SUBROUTINE nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
     497             : 
     498             :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     499             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
     500             :       LOGICAL, INTENT(IN)                                :: use_virial
     501             : 
     502             :       INTEGER                                            :: iat, iat_use
     503             :       TYPE(nequip_data_type), POINTER                    :: nequip_data
     504             : 
     505           4 :       CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
     506             : 
     507           4 :       IF (use_virial) THEN
     508          26 :          pv_nonbond = pv_nonbond + nequip_data%virial
     509             :       END IF
     510             : 
     511         202 :       DO iat_use = 1, SIZE(nequip_data%use_indices)
     512         198 :          iat = nequip_data%use_indices(iat_use)
     513         198 :          CPASSERT(iat >= 1 .AND. iat <= SIZE(f_nonbond, 2))
     514         796 :          f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + nequip_data%force(1:3, iat_use)
     515             :       END DO
     516             : 
     517           4 :    END SUBROUTINE nequip_add_force_virial
     518             : END MODULE manybody_nequip
     519             : 

Generated by: LCOV version 1.15