LCOV - code coverage report
Current view: top level - src - pao_model.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 87 98 88.8 %
Date: 2024-12-21 06:28:57 Functions: 3 3 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Module for equivariant PAO-ML based on PyTorch.
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_model
      13             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      14             :                                               get_atomic_kind
      15             :    USE basis_set_types,                 ONLY: gto_basis_set_type
      16             :    USE cell_types,                      ONLY: cell_type,&
      17             :                                               pbc
      18             :    USE cp_dbcsr_api,                    ONLY: dbcsr_iterator_blocks_left,&
      19             :                                               dbcsr_iterator_next_block,&
      20             :                                               dbcsr_iterator_start,&
      21             :                                               dbcsr_iterator_stop,&
      22             :                                               dbcsr_iterator_type
      23             :    USE kinds,                           ONLY: default_path_length,&
      24             :                                               default_string_length,&
      25             :                                               dp,&
      26             :                                               sp
      27             :    USE message_passing,                 ONLY: mp_para_env_type
      28             :    USE pao_types,                       ONLY: pao_env_type,&
      29             :                                               pao_model_type
      30             :    USE particle_types,                  ONLY: particle_type
      31             :    USE physcon,                         ONLY: angstrom
      32             :    USE qs_environment_types,            ONLY: get_qs_env,&
      33             :                                               qs_environment_type
      34             :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      35             :                                               qs_kind_type
      36             :    USE torch_api,                       ONLY: torch_dict_create,&
      37             :                                               torch_dict_get,&
      38             :                                               torch_dict_insert,&
      39             :                                               torch_dict_release,&
      40             :                                               torch_dict_type,&
      41             :                                               torch_model_eval,&
      42             :                                               torch_model_get_attr,&
      43             :                                               torch_model_load
      44             :    USE util,                            ONLY: sort
      45             : #include "./base/base_uses.f90"
      46             : 
      47             :    IMPLICIT NONE
      48             : 
      49             :    PRIVATE
      50             : 
      51             :    PUBLIC :: pao_model_load, pao_model_predict, pao_model_type
      52             : 
      53             : CONTAINS
      54             : 
      55             : ! **************************************************************************************************
      56             : !> \brief Loads a PAO-ML model.
      57             : !> \param pao ...
      58             : !> \param qs_env ...
      59             : !> \param ikind ...
      60             : !> \param pao_model_file ...
      61             : !> \param model ...
      62             : ! **************************************************************************************************
      63           0 :    SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
      64             :       TYPE(pao_env_type), INTENT(IN)                     :: pao
      65             :       TYPE(qs_environment_type), INTENT(IN)              :: qs_env
      66             :       INTEGER, INTENT(IN)                                :: ikind
      67             :       CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
      68             :       TYPE(pao_model_type), INTENT(OUT)                  :: model
      69             : 
      70             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'
      71             : 
      72             :       CHARACTER(LEN=default_string_length)               :: kind_name
      73             :       CHARACTER(LEN=default_string_length), &
      74           4 :          ALLOCATABLE, DIMENSION(:)                       :: feature_kind_names
      75             :       INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
      76           4 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      77             :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
      78           4 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      79             : 
      80           4 :       CALL timeset(routineN, handle)
      81           4 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)
      82             : 
      83           4 :       IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
      84           4 :       CALL torch_model_load(model%torch_model, pao_model_file)
      85             : 
      86             :       ! Read model attributes.
      87           4 :       CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
      88           4 :       CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
      89           4 :       CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
      90           4 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
      91           4 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
      92           4 :       CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
      93           4 :       CALL torch_model_get_attr(model%torch_model, "num_neighbors", model%num_neighbors)
      94           4 :       CALL torch_model_get_attr(model%torch_model, "cutoff", model%cutoff)
      95           4 :       CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)
      96             : 
      97             :       ! Freeze model after all attributes have been read.
      98             :       ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
      99             :       ! https://github.com/pytorch/pytorch/issues/96726
     100             :       ! CALL torch_model_freeze(model%torch_model)
     101             : 
     102             :       ! For each feature kind name lookup its corresponding atomic kind number.
     103          12 :       ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
     104          12 :       model%feature_kinds(:) = -1
     105          12 :       DO jkind = 1, SIZE(feature_kind_names)
     106          24 :          DO kkind = 1, SIZE(atomic_kind_set)
     107          24 :             IF (TRIM(atomic_kind_set(kkind)%name) == TRIM(feature_kind_names(jkind))) THEN
     108           8 :                model%feature_kinds(jkind) = kkind
     109             :             END IF
     110             :          END DO
     111          12 :          IF (model%feature_kinds(jkind) < 0) THEN
     112           0 :             IF (pao%iw > 0) &
     113             :                WRITE (pao%iw, '(A)') " PAO| ML-model supports feature kind '"// &
     114           0 :                TRIM(feature_kind_names(jkind))//"' that is not present in subsys."
     115             :          END IF
     116             :       END DO
     117             : 
     118             :       ! Check for missing kinds.
     119          12 :       DO jkind = 1, SIZE(atomic_kind_set)
     120          16 :          IF (ALL(model%feature_kinds /= atomic_kind_set(jkind)%kind_number)) THEN
     121           0 :             IF (pao%iw > 0) &
     122             :                WRITE (pao%iw, '(A)') " PAO| ML-Model lacks feature kind '"// &
     123           0 :                TRIM(atomic_kind_set(jkind)%name)//"' that is present in subsys."
     124             :          END IF
     125             :       END DO
     126             : 
     127             :       ! Check compatibility
     128           4 :       CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
     129           4 :       CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
     130           4 :       IF (model%version /= 1) &
     131           0 :          CPABORT("Model version not supported.")
     132           4 :       IF (TRIM(model%kind_name) .NE. TRIM(kind_name)) &
     133           0 :          CPABORT("Kind name does not match.")
     134           4 :       IF (model%atomic_number /= z) &
     135           0 :          CPABORT("Atomic number does not match.")
     136           4 :       IF (TRIM(model%prim_basis_name) .NE. TRIM(basis_set%name)) &
     137           0 :          CPABORT("Primary basis set name does not match.")
     138           4 :       IF (model%prim_basis_size /= basis_set%nsgf) &
     139           0 :          CPABORT("Primary basis set size does not match.")
     140           4 :       IF (model%pao_basis_size /= pao_basis_size) &
     141           0 :          CPABORT("PAO basis size does not match.")
     142             : 
     143           4 :       CALL timestop(handle)
     144             : 
     145          12 :    END SUBROUTINE pao_model_load
     146             : 
     147             : ! **************************************************************************************************
     148             : !> \brief Fills pao%matrix_X based on machine learning predictions
     149             : !> \param pao ...
     150             : !> \param qs_env ...
     151             : ! **************************************************************************************************
     152           2 :    SUBROUTINE pao_model_predict(pao, qs_env)
     153             :       TYPE(pao_env_type), POINTER                        :: pao
     154             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     155             : 
     156             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'
     157             : 
     158             :       INTEGER                                            :: acol, arow, handle, iatom
     159           2 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
     160             :       TYPE(dbcsr_iterator_type)                          :: iter
     161             : 
     162           2 :       CALL timeset(routineN, handle)
     163             : 
     164           2 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(iter,arow,acol,iatom,block_X)
     165             :       CALL dbcsr_iterator_start(iter, pao%matrix_X)
     166             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     167             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     168             :          IF (SIZE(block_X) == 0) CYCLE ! pao disabled for iatom
     169             :          iatom = arow; CPASSERT(arow == acol)
     170             :          CALL predict_single_atom(pao, qs_env, iatom, block_X)
     171             :       END DO
     172             :       CALL dbcsr_iterator_stop(iter)
     173             : !$OMP END PARALLEL
     174             : 
     175           2 :       CALL timestop(handle)
     176             : 
     177           2 :    END SUBROUTINE pao_model_predict
     178             : 
     179             : ! **************************************************************************************************
     180             : !> \brief Predicts a single block_X.
     181             : !> \param pao ...
     182             : !> \param qs_env ...
     183             : !> \param iatom ...
     184             : !> \param block_X ...
     185             : ! **************************************************************************************************
     186           6 :    SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X)
     187             :       TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
     188             :       TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
     189             :       INTEGER, INTENT(IN)                                :: iatom
     190             :       REAL(dp), DIMENSION(:, :), INTENT(OUT)             :: block_X
     191             : 
     192             :       INTEGER                                            :: ikind, jatom, jkind, jneighbor, natoms
     193           6 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbors_index
     194             :       REAL(dp), DIMENSION(3)                             :: Ri, Rij, Rj
     195           6 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: neighbors_distance
     196           6 :       REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: neighbors_features, neighbors_relpos
     197           6 :       REAL(sp), DIMENSION(:, :), POINTER                 :: predicted_xblock
     198           6 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     199             :       TYPE(cell_type), POINTER                           :: cell
     200             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     201             :       TYPE(pao_model_type), POINTER                      :: model
     202           6 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     203           6 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     204             :       TYPE(torch_dict_type)                              :: model_inputs, model_outputs
     205             : 
     206             :       CALL get_qs_env(qs_env, &
     207             :                       para_env=para_env, &
     208             :                       cell=cell, &
     209             :                       particle_set=particle_set, &
     210             :                       atomic_kind_set=atomic_kind_set, &
     211             :                       qs_kind_set=qs_kind_set, &
     212           6 :                       natom=natoms)
     213             : 
     214           6 :       CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     215           6 :       model => pao%models(ikind)
     216           6 :       CPASSERT(model%version > 0)
     217             : 
     218             :       ! Find neighbors.
     219             :       ! TODO: this is a quadratic algorithm, use a neighbor-list instead
     220          30 :       ALLOCATE (neighbors_distance(natoms), neighbors_index(natoms))
     221          24 :       Ri = particle_set(iatom)%r
     222          42 :       DO jatom = 1, natoms
     223         144 :          Rj = particle_set(jatom)%r
     224          36 :          Rij = pbc(Ri, Rj, cell)
     225         150 :          neighbors_distance(jatom) = DOT_PRODUCT(Rij, Rij) ! using squared distances for performance
     226             :       END DO
     227           6 :       CALL sort(neighbors_distance, natoms, neighbors_index)
     228           6 :       CPASSERT(neighbors_index(1) == iatom) ! central atom should be closesd to itself
     229             : 
     230             :       ! Compute neighbors relative positions.
     231          18 :       ALLOCATE (neighbors_relpos(3, model%num_neighbors))
     232         126 :       neighbors_relpos(:, :) = 0.0_sp
     233          36 :       DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     234          30 :          jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     235         120 :          Rj = particle_set(jatom)%r
     236          30 :          Rij = pbc(Ri, Rj, cell)
     237         126 :          neighbors_relpos(:, jneighbor) = REAL(angstrom*Rij, kind=sp)
     238             :       END DO
     239             : 
     240             :       ! Compute neighbors features.
     241          24 :       ALLOCATE (neighbors_features(SIZE(model%feature_kinds), model%num_neighbors))
     242          96 :       neighbors_features(:, :) = 0.0_sp
     243          36 :       DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     244          30 :          jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     245          30 :          jkind = particle_set(jatom)%atomic_kind%kind_number
     246          96 :          WHERE (model%feature_kinds == jkind) neighbors_features(:, jneighbor) = 1.0_sp
     247             :       END DO
     248             : 
     249             :       ! Inference.
     250           6 :       CALL torch_dict_create(model_inputs)
     251           6 :       CALL torch_dict_insert(model_inputs, "neighbors_relpos", neighbors_relpos)
     252           6 :       CALL torch_dict_insert(model_inputs, "neighbors_features", neighbors_features)
     253           6 :       CALL torch_dict_create(model_outputs)
     254           6 :       CALL torch_model_eval(model%torch_model, model_inputs, model_outputs)
     255             : 
     256             :       ! Copy predicted XBlock.
     257           6 :       NULLIFY (predicted_xblock)
     258           6 :       CALL torch_dict_get(model_outputs, "xblock", predicted_xblock)
     259         220 :       block_X = RESHAPE(predicted_xblock, (/SIZE(block_X), 1/))
     260             : 
     261             :       ! Clean up.
     262           6 :       CALL torch_dict_release(model_inputs)
     263           6 :       CALL torch_dict_release(model_outputs)
     264           6 :       DEALLOCATE (neighbors_distance, neighbors_index)
     265           6 :       DEALLOCATE (predicted_xblock, neighbors_relpos, neighbors_features)
     266             : 
     267           6 :    END SUBROUTINE predict_single_atom
     268             : 
     269             : END MODULE pao_model

Generated by: LCOV version 1.15