LCOV - code coverage report
Current view: top level - src - pao_model.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4c33f95) Lines: 122 133 91.7 %
Date: 2025-01-30 06:53:08 Functions: 4 4 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Module for equivariant PAO-ML based on PyTorch.
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_model
      13             :    USE OMP_LIB,                         ONLY: omp_init_lock,&
      14             :                                               omp_set_lock,&
      15             :                                               omp_unset_lock
      16             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      17             :                                               get_atomic_kind
      18             :    USE basis_set_types,                 ONLY: gto_basis_set_type
      19             :    USE cell_types,                      ONLY: cell_type,&
      20             :                                               pbc
      21             :    USE cp_dbcsr_api,                    ONLY: dbcsr_get_info,&
      22             :                                               dbcsr_iterator_blocks_left,&
      23             :                                               dbcsr_iterator_next_block,&
      24             :                                               dbcsr_iterator_start,&
      25             :                                               dbcsr_iterator_stop,&
      26             :                                               dbcsr_iterator_type,&
      27             :                                               dbcsr_type
      28             :    USE kinds,                           ONLY: default_path_length,&
      29             :                                               default_string_length,&
      30             :                                               dp,&
      31             :                                               sp
      32             :    USE message_passing,                 ONLY: mp_para_env_type
      33             :    USE pao_types,                       ONLY: pao_env_type,&
      34             :                                               pao_model_type
      35             :    USE particle_types,                  ONLY: particle_type
      36             :    USE physcon,                         ONLY: angstrom
      37             :    USE qs_environment_types,            ONLY: get_qs_env,&
      38             :                                               qs_environment_type
      39             :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      40             :                                               qs_kind_type
      41             :    USE torch_api,                       ONLY: &
      42             :         torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
      43             :         torch_model_forward, torch_model_get_attr, torch_model_load, torch_tensor_backward, &
      44             :         torch_tensor_data_ptr, torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
      45             :         torch_tensor_type
      46             :    USE util,                            ONLY: sort
      47             : #include "./base/base_uses.f90"
      48             : 
      49             :    IMPLICIT NONE
      50             : 
      51             :    PRIVATE
      52             : 
      53             :    PUBLIC :: pao_model_load, pao_model_predict, pao_model_forces, pao_model_type
      54             : 
      55             : CONTAINS
      56             : 
      57             : ! **************************************************************************************************
      58             : !> \brief Loads a PAO-ML model.
      59             : !> \param pao ...
      60             : !> \param qs_env ...
      61             : !> \param ikind ...
      62             : !> \param pao_model_file ...
      63             : !> \param model ...
      64             : ! **************************************************************************************************
      65           0 :    SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
      66             :       TYPE(pao_env_type), INTENT(IN)                     :: pao
      67             :       TYPE(qs_environment_type), INTENT(IN)              :: qs_env
      68             :       INTEGER, INTENT(IN)                                :: ikind
      69             :       CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
      70             :       TYPE(pao_model_type), INTENT(OUT)                  :: model
      71             : 
      72             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'
      73             : 
      74             :       CHARACTER(LEN=default_string_length)               :: kind_name
      75             :       CHARACTER(LEN=default_string_length), &
      76           8 :          ALLOCATABLE, DIMENSION(:)                       :: feature_kind_names
      77             :       INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
      78           8 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      79             :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
      80           8 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      81             : 
      82           8 :       CALL timeset(routineN, handle)
      83           8 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)
      84             : 
      85           8 :       IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
      86           8 :       CALL torch_model_load(model%torch_model, pao_model_file)
      87             : 
      88             :       ! Read model attributes.
      89           8 :       CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
      90           8 :       CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
      91           8 :       CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
      92           8 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
      93           8 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
      94           8 :       CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
      95           8 :       CALL torch_model_get_attr(model%torch_model, "num_neighbors", model%num_neighbors)
      96           8 :       CALL torch_model_get_attr(model%torch_model, "cutoff", model%cutoff)
      97           8 :       CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)
      98             : 
      99             :       ! Freeze model after all attributes have been read.
     100             :       ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
     101             :       ! https://github.com/pytorch/pytorch/issues/96726
     102             :       ! CALL torch_model_freeze(model%torch_model)
     103             : 
     104             :       ! For each feature kind name lookup its corresponding atomic kind number.
     105          24 :       ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
     106          24 :       model%feature_kinds(:) = -1
     107          24 :       DO jkind = 1, SIZE(feature_kind_names)
     108          48 :          DO kkind = 1, SIZE(atomic_kind_set)
     109          48 :             IF (TRIM(atomic_kind_set(kkind)%name) == TRIM(feature_kind_names(jkind))) THEN
     110          16 :                model%feature_kinds(jkind) = kkind
     111             :             END IF
     112             :          END DO
     113          24 :          IF (model%feature_kinds(jkind) < 0) THEN
     114           0 :             IF (pao%iw > 0) &
     115             :                WRITE (pao%iw, '(A)') " PAO| ML-model supports feature kind '"// &
     116           0 :                TRIM(feature_kind_names(jkind))//"' that is not present in subsys."
     117             :          END IF
     118             :       END DO
     119             : 
     120             :       ! Check for missing kinds.
     121          24 :       DO jkind = 1, SIZE(atomic_kind_set)
     122          32 :          IF (ALL(model%feature_kinds /= atomic_kind_set(jkind)%kind_number)) THEN
     123           0 :             IF (pao%iw > 0) &
     124             :                WRITE (pao%iw, '(A)') " PAO| ML-Model lacks feature kind '"// &
     125           0 :                TRIM(atomic_kind_set(jkind)%name)//"' that is present in subsys."
     126             :          END IF
     127             :       END DO
     128             : 
     129             :       ! Check compatibility
     130           8 :       CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
     131           8 :       CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
     132           8 :       IF (model%version /= 1) &
     133           0 :          CPABORT("Model version not supported.")
     134           8 :       IF (TRIM(model%kind_name) .NE. TRIM(kind_name)) &
     135           0 :          CPABORT("Kind name does not match.")
     136           8 :       IF (model%atomic_number /= z) &
     137           0 :          CPABORT("Atomic number does not match.")
     138           8 :       IF (TRIM(model%prim_basis_name) .NE. TRIM(basis_set%name)) &
     139           0 :          CPABORT("Primary basis set name does not match.")
     140           8 :       IF (model%prim_basis_size /= basis_set%nsgf) &
     141           0 :          CPABORT("Primary basis set size does not match.")
     142           8 :       IF (model%pao_basis_size /= pao_basis_size) &
     143           0 :          CPABORT("PAO basis size does not match.")
     144             : 
     145           8 :       CALL omp_init_lock(model%lock)
     146           8 :       CALL timestop(handle)
     147             : 
     148          24 :    END SUBROUTINE pao_model_load
     149             : 
     150             : ! **************************************************************************************************
     151             : !> \brief Fills pao%matrix_X based on machine learning predictions
     152             : !> \param pao ...
     153             : !> \param qs_env ...
     154             : ! **************************************************************************************************
     155          16 :    SUBROUTINE pao_model_predict(pao, qs_env)
     156             :       TYPE(pao_env_type), POINTER                        :: pao
     157             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     158             : 
     159             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'
     160             : 
     161             :       INTEGER                                            :: acol, arow, handle, iatom
     162          16 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
     163             :       TYPE(dbcsr_iterator_type)                          :: iter
     164             : 
     165          16 :       CALL timeset(routineN, handle)
     166             : 
     167          16 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(iter,arow,acol,iatom,block_X)
     168             :       CALL dbcsr_iterator_start(iter, pao%matrix_X)
     169             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     170             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     171             :          IF (SIZE(block_X) == 0) CYCLE ! pao disabled for iatom
     172             :          iatom = arow; CPASSERT(arow == acol)
     173             :          CALL predict_single_atom(pao, qs_env, iatom, block_X=block_X)
     174             :       END DO
     175             :       CALL dbcsr_iterator_stop(iter)
     176             : !$OMP END PARALLEL
     177             : 
     178          16 :       CALL timestop(handle)
     179             : 
     180          16 :    END SUBROUTINE pao_model_predict
     181             : 
     182             : ! **************************************************************************************************
     183             : !> \brief Calculate forces contributed by machine learning
     184             : !> \param pao ...
     185             : !> \param qs_env ...
     186             : !> \param matrix_G ...
     187             : !> \param forces ...
     188             : ! **************************************************************************************************
     189           2 :    SUBROUTINE pao_model_forces(pao, qs_env, matrix_G, forces)
     190             :       TYPE(pao_env_type), POINTER                        :: pao
     191             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     192             :       TYPE(dbcsr_type)                                   :: matrix_G
     193             :       REAL(dp), DIMENSION(:, :), INTENT(INOUT)           :: forces
     194             : 
     195             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_forces'
     196             : 
     197             :       INTEGER                                            :: acol, arow, handle, iatom
     198           2 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_G
     199             :       TYPE(dbcsr_iterator_type)                          :: iter
     200             : 
     201           2 :       CALL timeset(routineN, handle)
     202             : 
     203           2 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,matrix_G,forces) PRIVATE(iter,arow,acol,iatom,block_G)
     204             :       CALL dbcsr_iterator_start(iter, matrix_G)
     205             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     206             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_G)
     207             :          iatom = arow; CPASSERT(arow == acol)
     208             :          IF (SIZE(block_G) == 0) CYCLE ! pao disabled for iatom
     209             :          CALL predict_single_atom(pao, qs_env, iatom, block_G=block_G, forces=forces)
     210             :       END DO
     211             :       CALL dbcsr_iterator_stop(iter)
     212             : !$OMP END PARALLEL
     213             : 
     214           2 :       CALL timestop(handle)
     215             : 
     216           2 :    END SUBROUTINE pao_model_forces
     217             : 
     218             : ! **************************************************************************************************
     219             : !> \brief Predicts a single block_X.
     220             : !> \param pao ...
     221             : !> \param qs_env ...
     222             : !> \param iatom ...
     223             : !> \param block_X ...
     224             : !> \param block_G ...
     225             : !> \param forces ...
     226             : ! **************************************************************************************************
     227          54 :    SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
     228             :       TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
     229             :       TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
     230             :       INTEGER, INTENT(IN)                                :: iatom
     231             :       REAL(dp), DIMENSION(:, :), OPTIONAL                :: block_X, block_G, forces
     232             : 
     233             :       INTEGER                                            :: ikind, jatom, jkind, jneighbor, m, n, &
     234             :                                                             natoms
     235          54 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbors_index
     236          54 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
     237             :       REAL(dp), DIMENSION(3)                             :: Ri, Rij, Rj
     238          54 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: neighbors_distance
     239          54 :       REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: features, outer_grad, relpos
     240          54 :       REAL(sp), DIMENSION(:, :), POINTER                 :: predicted_xblock, relpos_grad
     241          54 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     242             :       TYPE(cell_type), POINTER                           :: cell
     243             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     244             :       TYPE(pao_model_type), POINTER                      :: model
     245          54 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     246          54 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     247             :       TYPE(torch_dict_type)                              :: model_inputs, model_outputs
     248             :       TYPE(torch_tensor_type)                            :: features_tensor, outer_grad_tensor, &
     249             :                                                             predicted_xblock_tensor, &
     250             :                                                             relpos_grad_tensor, relpos_tensor
     251             : 
     252          54 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
     253          54 :       n = blk_sizes_pri(iatom) ! size of primary basis
     254          54 :       m = blk_sizes_pao(iatom) ! size of pao basis
     255             : 
     256             :       CALL get_qs_env(qs_env, &
     257             :                       para_env=para_env, &
     258             :                       cell=cell, &
     259             :                       particle_set=particle_set, &
     260             :                       atomic_kind_set=atomic_kind_set, &
     261             :                       qs_kind_set=qs_kind_set, &
     262          54 :                       natom=natoms)
     263             : 
     264          54 :       CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     265          54 :       model => pao%models(ikind)
     266          54 :       CPASSERT(model%version > 0)
     267          54 :       CALL omp_set_lock(model%lock) ! TODO: might not be needed for inference.
     268             : 
     269             :       ! Find neighbors.
     270             :       ! TODO: this is a quadratic algorithm, use a neighbor-list instead
     271         270 :       ALLOCATE (neighbors_distance(natoms), neighbors_index(natoms))
     272         216 :       Ri = particle_set(iatom)%r
     273         378 :       DO jatom = 1, natoms
     274        1296 :          Rj = particle_set(jatom)%r
     275         324 :          Rij = pbc(Ri, Rj, cell)
     276        1350 :          neighbors_distance(jatom) = DOT_PRODUCT(Rij, Rij) ! using squared distances for performance
     277             :       END DO
     278          54 :       CALL sort(neighbors_distance, natoms, neighbors_index)
     279          54 :       CPASSERT(neighbors_index(1) == iatom) ! central atom should be closesd to itself
     280             : 
     281             :       ! Compute neighbors relative positions.
     282         162 :       ALLOCATE (relpos(3, model%num_neighbors))
     283        1134 :       relpos(:, :) = 0.0_sp
     284         324 :       DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     285         270 :          jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     286        1080 :          Rj = particle_set(jatom)%r
     287         270 :          Rij = pbc(Ri, Rj, cell)
     288        1134 :          relpos(:, jneighbor) = REAL(angstrom*Rij, kind=sp)
     289             :       END DO
     290             : 
     291             :       ! Compute neighbors features.
     292         216 :       ALLOCATE (features(SIZE(model%feature_kinds), model%num_neighbors))
     293         864 :       features(:, :) = 0.0_sp
     294         324 :       DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     295         270 :          jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     296         270 :          jkind = particle_set(jatom)%atomic_kind%kind_number
     297         864 :          WHERE (model%feature_kinds == jkind) features(:, jneighbor) = 1.0_sp
     298             :       END DO
     299             : 
     300             :       ! Inference.
     301          54 :       CALL torch_dict_create(model_inputs)
     302             : 
     303          54 :       CALL torch_tensor_from_array(relpos_tensor, relpos, requires_grad=PRESENT(block_G))
     304          54 :       CALL torch_dict_insert(model_inputs, "neighbors_relpos", relpos_tensor)
     305          54 :       CALL torch_tensor_from_array(features_tensor, features)
     306          54 :       CALL torch_dict_insert(model_inputs, "neighbors_features", features_tensor)
     307             : 
     308          54 :       CALL torch_dict_create(model_outputs)
     309          54 :       CALL torch_model_forward(model%torch_model, model_inputs, model_outputs)
     310             : 
     311             :       ! Copy predicted XBlock.
     312          54 :       NULLIFY (predicted_xblock)
     313          54 :       CALL torch_dict_get(model_outputs, "xblock", predicted_xblock_tensor)
     314          54 :       CALL torch_tensor_data_ptr(predicted_xblock_tensor, predicted_xblock)
     315          54 :       CPASSERT(SIZE(predicted_xblock, 1) == n .AND. SIZE(predicted_xblock, 2) == m)
     316          54 :       IF (PRESENT(block_X)) THEN
     317        1664 :          block_X = RESHAPE(predicted_xblock, [n*m, 1])
     318             :       END IF
     319             : 
     320             :       ! TURNING POINT (if calc forces) ------------------------------------------
     321          54 :       IF (PRESENT(block_G)) THEN
     322          24 :          ALLOCATE (outer_grad(n, m))
     323         226 :          outer_grad(:, :) = REAL(RESHAPE(block_G, [n, m]), kind=sp)
     324           6 :          CALL torch_tensor_from_array(outer_grad_tensor, outer_grad)
     325           6 :          CALL torch_tensor_backward(predicted_xblock_tensor, outer_grad_tensor)
     326           6 :          CALL torch_tensor_grad(relpos_tensor, relpos_grad_tensor)
     327           6 :          NULLIFY (relpos_grad)
     328           6 :          CALL torch_tensor_data_ptr(relpos_grad_tensor, relpos_grad)
     329           6 :          CPASSERT(SIZE(relpos_grad, 1) == 3 .AND. SIZE(relpos_grad, 2) == model%num_neighbors)
     330          36 :          DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     331          30 :             jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     332         120 :             forces(iatom, :) = forces(iatom, :) + relpos_grad(:, jneighbor)*angstrom
     333         126 :             forces(jatom, :) = forces(jatom, :) - relpos_grad(:, jneighbor)*angstrom
     334             :          END DO
     335           6 :          CALL torch_tensor_release(outer_grad_tensor)
     336           6 :          CALL torch_tensor_release(relpos_grad_tensor)
     337             :       END IF
     338             : 
     339             :       ! Clean up.
     340          54 :       CALL torch_tensor_release(relpos_tensor)
     341          54 :       CALL torch_tensor_release(features_tensor)
     342          54 :       CALL torch_tensor_release(predicted_xblock_tensor)
     343          54 :       CALL torch_dict_release(model_inputs)
     344          54 :       CALL torch_dict_release(model_outputs)
     345          54 :       DEALLOCATE (neighbors_distance, neighbors_index, relpos, features)
     346          54 :       CALL omp_unset_lock(model%lock)
     347             : 
     348         162 :    END SUBROUTINE predict_single_atom
     349             : 
     350             : END MODULE pao_model

Generated by: LCOV version 1.15