LCOV - code coverage report
Current view: top level - src - nnp_model.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 49 69 71.0 %
Date: 2024-11-21 06:45:46 Functions: 3 3 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief  Methods dealing with core routines for artificial neural networks
      10             : !> \author Christoph Schran (christoph.schran@rub.de)
      11             : !> \date   2020-10-10
      12             : ! **************************************************************************************************
      13             : MODULE nnp_model
      14             : 
      15             :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      16             :                                               cp_logger_get_default_unit_nr,&
      17             :                                               cp_logger_type
      18             :    USE kinds,                           ONLY: default_string_length,&
      19             :                                               dp
      20             :    USE message_passing,                 ONLY: mp_para_env_type
      21             :    USE nnp_environment_types,           ONLY: &
      22             :         nnp_actfnct_cos, nnp_actfnct_exp, nnp_actfnct_gaus, nnp_actfnct_invsig, nnp_actfnct_lin, &
      23             :         nnp_actfnct_quad, nnp_actfnct_sig, nnp_actfnct_softplus, nnp_actfnct_tanh, nnp_arc_type, &
      24             :         nnp_type
      25             : #include "./base/base_uses.f90"
      26             : 
      27             :    IMPLICIT NONE
      28             : 
      29             :    PRIVATE
      30             : 
      31             :    LOGICAL, PARAMETER, PRIVATE :: debug_this_module = .TRUE.
      32             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'nnp_model'
      33             : 
      34             :    PUBLIC :: nnp_write_arc, &
      35             :              nnp_predict, &
      36             :              nnp_gradients
      37             : 
      38             : CONTAINS
      39             : 
      40             : ! **************************************************************************************************
      41             : !> \brief Write neural network architecture information
      42             : !> \param nnp ...
      43             : !> \param para_env ...
      44             : !> \param printtag ...
      45             : ! **************************************************************************************************
      46          15 :    SUBROUTINE nnp_write_arc(nnp, para_env, printtag)
      47             :       TYPE(nnp_type), INTENT(IN)                         :: nnp
      48             :       TYPE(mp_para_env_type), INTENT(IN)                 :: para_env
      49             :       CHARACTER(LEN=*), INTENT(IN)                       :: printtag
      50             : 
      51             :       CHARACTER(len=default_string_length)               :: my_label
      52             :       INTEGER                                            :: i, j, unit_nr
      53             :       TYPE(cp_logger_type), POINTER                      :: logger
      54             : 
      55          15 :       NULLIFY (logger)
      56          15 :       logger => cp_get_default_logger()
      57             : 
      58          15 :       my_label = TRIM(printtag)//"| "
      59          15 :       IF (para_env%is_source()) THEN
      60           8 :          unit_nr = cp_logger_get_default_unit_nr(logger)
      61          25 :          DO i = 1, nnp%n_ele
      62             :             WRITE (unit_nr, *) TRIM(my_label)//" Neural network specification for element "// &
      63          17 :                nnp%ele(i)//":"
      64          93 :             DO j = 1, nnp%n_layer
      65          68 :                WRITE (unit_nr, '(1X,A,1X,I3,1X,A,1X,I2)') TRIM(my_label), &
      66         153 :                   nnp%arc(i)%n_nodes(j), "nodes in layer", j
      67             :             END DO
      68             :          END DO
      69             :       END IF
      70             : 
      71          15 :       RETURN
      72             : 
      73             :    END SUBROUTINE nnp_write_arc
      74             : 
      75             : ! **************************************************************************************************
      76             : !> \brief Predict energy by evaluating neural network
      77             : !> \param arc ...
      78             : !> \param nnp ...
      79             : !> \param i_com ...
      80             : ! **************************************************************************************************
      81      425816 :    SUBROUTINE nnp_predict(arc, nnp, i_com)
      82             :       TYPE(nnp_arc_type), INTENT(INOUT)                  :: arc
      83             :       TYPE(nnp_type), INTENT(INOUT)                      :: nnp
      84             :       INTEGER, INTENT(IN)                                :: i_com
      85             : 
      86             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'nnp_predict'
      87             : 
      88             :       INTEGER                                            :: handle, i, j
      89             :       REAL(KIND=dp)                                      :: norm
      90             : 
      91      425816 :       CALL timeset(routineN, handle)
      92             : 
      93     1703264 :       DO i = 2, nnp%n_layer
      94             :          ! Calculate node(i)
      95    20942784 :          arc%layer(i)%node(:) = 0.0_dp
      96             :          !Perform matrix-vector product
      97             :          !y := alpha*A*x + beta*y
      98             :          !with A = layer(i)*weights
      99             :          !and  x = layer(i-1)%node
     100             :          !and  y = layer(i)%node
     101             :          CALL DGEMV('T', & !transpose matrix A
     102             :                     arc%n_nodes(i - 1), & !number of rows of A
     103             :                     arc%n_nodes(i), & !number of columns of A
     104             :                     1.0_dp, & ! alpha
     105             :                     arc%layer(i)%weights(:, :, i_com), & !matrix A
     106             :                     arc%n_nodes(i - 1), & !leading dimension of A
     107             :                     arc%layer(i - 1)%node, & !vector x
     108             :                     1, & !increment for the elements of x
     109             :                     1.0_dp, & !beta
     110             :                     arc%layer(i)%node, & !vector y
     111     1277448 :                     1) !increment for the elements of y
     112             : 
     113             :          ! Add bias weight
     114    20942784 :          DO j = 1, arc%n_nodes(i)
     115    20942784 :             arc%layer(i)%node(j) = arc%layer(i)%node(j) + arc%layer(i)%bweights(j, i_com)
     116             :          END DO
     117             : 
     118             :          ! Normalize by number of nodes in previous layer if requested
     119     1277448 :          IF (nnp%normnodes) THEN
     120           0 :             norm = 1.0_dp/REAL(arc%n_nodes(i - 1), dp)
     121           0 :             DO j = 1, arc%n_nodes(i)
     122           0 :                arc%layer(i)%node(j) = arc%layer(i)%node(j)*norm
     123             :             END DO
     124             :          END IF
     125             : 
     126             :          ! Store node values before application of activation function
     127             :          ! (needed for derivatives)
     128    20942784 :          DO j = 1, arc%n_nodes(i)
     129    20942784 :             arc%layer(i)%node_grad(j) = arc%layer(i)%node(j)
     130             :          END DO
     131             : 
     132             :          ! Apply activation function:
     133      425816 :          SELECT CASE (nnp%actfnct(i - 1))
     134             :          CASE (nnp_actfnct_tanh)
     135    20091152 :             arc%layer(i)%node(:) = TANH(arc%layer(i)%node(:))
     136             :          CASE (nnp_actfnct_gaus)
     137           0 :             arc%layer(i)%node(:) = EXP(-0.5_dp*arc%layer(i)%node(:)**2)
     138             :          CASE (nnp_actfnct_lin)
     139           0 :             CONTINUE
     140             :          CASE (nnp_actfnct_cos)
     141           0 :             arc%layer(i)%node(:) = COS(arc%layer(i)%node(:))
     142             :          CASE (nnp_actfnct_sig)
     143           0 :             arc%layer(i)%node(:) = 1.0_dp/(1.0_dp + EXP(-1.0_dp*arc%layer(i)%node(:)))
     144             :          CASE (nnp_actfnct_invsig)
     145           0 :             arc%layer(i)%node(:) = 1.0_dp - 1.0_dp/(1.0_dp + EXP(-1.0_dp*arc%layer(i)%node(:)))
     146             :          CASE (nnp_actfnct_exp)
     147           0 :             arc%layer(i)%node(:) = EXP(-1.0_dp*arc%layer(i)%node(:))
     148             :          CASE (nnp_actfnct_softplus)
     149           0 :             arc%layer(i)%node(:) = LOG(EXP(arc%layer(i)%node(:)) + 1.0_dp)
     150             :          CASE (nnp_actfnct_quad)
     151           0 :             arc%layer(i)%node(:) = arc%layer(i)%node(:)**2
     152             :          CASE DEFAULT
     153     1277448 :             CPABORT("NNP| Error: Unknown activation function")
     154             :          END SELECT
     155             :       END DO
     156             : 
     157      425816 :       CALL timestop(handle)
     158             : 
     159      425816 :    END SUBROUTINE nnp_predict
     160             : 
     161             : ! **************************************************************************************************
     162             : !> \brief Calculate gradients of neural network
     163             : !> \param arc ...
     164             : !> \param nnp ...
     165             : !> \param i_com ...
     166             : !> \param denergydsym ...
     167             : ! **************************************************************************************************
     168       93256 :    SUBROUTINE nnp_gradients(arc, nnp, i_com, denergydsym)
     169             :       TYPE(nnp_arc_type), INTENT(INOUT)                  :: arc
     170             :       TYPE(nnp_type), INTENT(INOUT)                      :: nnp
     171             :       INTEGER, INTENT(IN)                                :: i_com
     172             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: denergydsym
     173             : 
     174             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'nnp_gradients'
     175             : 
     176             :       INTEGER                                            :: handle, i, j, k
     177             :       REAL(KIND=dp)                                      :: norm
     178             : 
     179       93256 :       CALL timeset(routineN, handle)
     180             : 
     181       93256 :       norm = 1.0_dp
     182             : 
     183      373024 :       DO i = 2, nnp%n_layer
     184             : 
     185             :          ! Apply activation function:
     186      466280 :          SELECT CASE (nnp%actfnct(i - 1))
     187             :          CASE (nnp_actfnct_tanh)
     188     3980752 :             arc%layer(i)%node_grad(:) = 1.0_dp - arc%layer(i)%node(:)**2 !tanh(x)'=1-tanh(x)**2
     189             :          CASE (nnp_actfnct_gaus)
     190           0 :             arc%layer(i)%node_grad(:) = -1.0_dp*arc%layer(i)%node(:)*arc%layer(i)%node_grad(:)
     191             :          CASE (nnp_actfnct_lin)
     192      186512 :             arc%layer(i)%node_grad(:) = 1.0_dp
     193             :          CASE (nnp_actfnct_cos)
     194           0 :             arc%layer(i)%node_grad(:) = -SIN(arc%layer(i)%node_grad(:))
     195             :          CASE (nnp_actfnct_sig)
     196             :             arc%layer(i)%node_grad(:) = EXP(-arc%layer(i)%node_grad(:))/ &
     197           0 :                                         (1.0_dp + EXP(-1.0_dp*arc%layer(i)%node_grad(:)))**2
     198             :          CASE (nnp_actfnct_invsig)
     199             :             arc%layer(i)%node_grad(:) = -1.0_dp*EXP(-1.0_dp*arc%layer(i)%node_grad(:))/ &
     200           0 :                                         (1.0_dp + EXP(-1.0_dp*arc%layer(i)%node_grad(:)))**2
     201             :          CASE (nnp_actfnct_exp)
     202           0 :             arc%layer(i)%node_grad(:) = -1.0_dp*arc%layer(i)%node(:)
     203             :          CASE (nnp_actfnct_softplus)
     204             :             arc%layer(i)%node_grad(:) = (EXP(arc%layer(i)%node(:)) + 1.0_dp)/ &
     205           0 :                                         EXP(arc%layer(i)%node(:))
     206             :          CASE (nnp_actfnct_quad)
     207           0 :             arc%layer(i)%node_grad(:) = 2.0_dp*arc%layer(i)%node_grad(:)
     208             :          CASE DEFAULT
     209      279768 :             CPABORT("NNP| Error: Unknown activation function")
     210             :          END SELECT
     211             :          ! Normalize by number of nodes in previous layer if requested
     212      373024 :          IF (nnp%normnodes) THEN
     213           0 :             norm = 1.0_dp/REAL(arc%n_nodes(i - 1), dp)
     214           0 :             arc%layer(i)%node_grad(:) = norm*arc%layer(i)%node_grad(:)
     215             :          END IF
     216             : 
     217             :       END DO
     218             : 
     219             :       ! calculate \frac{\partial f^1(x_j^1)}{\partial G_i}*a_{ij}^{01}
     220     1990376 :       DO j = 1, arc%n_nodes(2)
     221    53909736 :          DO i = 1, arc%n_nodes(1)
     222    53816480 :             arc%layer(2)%tmp_der(i, j) = arc%layer(2)%node_grad(j)*arc%layer(2)%weights(i, j, i_com)
     223             :          END DO
     224             :       END DO
     225             : 
     226      279768 :       DO k = 3, nnp%n_layer
     227             :          ! Reset tmp_der:
     228    56659416 :          arc%layer(k)%tmp_der(:, :) = 0.0_dp
     229             :          !Perform matrix-matrix product
     230             :          !C := alpha*A*B + beta*C
     231             :          !with A = layer(k-1)%tmp_der
     232             :          !and  B = layer(k)%weights
     233             :          !and  C = tmp
     234             :          CALL DGEMM('N', & !don't transpose matrix A
     235             :                     'N', & !don't transpose matrix B
     236             :                     arc%n_nodes(1), & !number of rows of A
     237             :                     arc%n_nodes(k), & !number of columns of B
     238             :                     arc%n_nodes(k - 1), & !number of col of A and nb of rows of B
     239             :                     1.0_dp, & !alpha
     240             :                     arc%layer(k - 1)%tmp_der, & !matrix A
     241             :                     arc%n_nodes(1), & !leading dimension of A
     242             :                     arc%layer(k)%weights(:, :, i_com), & !matrix B
     243             :                     arc%n_nodes(k - 1), & !leading dimension of B
     244             :                     1.0_dp, & !beta
     245             :                     arc%layer(k)%tmp_der, & !matrix C
     246      186512 :                     arc%n_nodes(1)) !leading dimension of C
     247             : 
     248             :          ! sum over all nodes in the target layer
     249     2270144 :          DO j = 1, arc%n_nodes(k)
     250             :             ! sum over input layer
     251    56659416 :             DO i = 1, arc%n_nodes(1)
     252             :                arc%layer(k)%tmp_der(i, j) = arc%layer(k)%node_grad(j)* &
     253    56472904 :                                             arc%layer(k)%tmp_der(i, j)
     254             :             END DO
     255             :          END DO
     256             :       END DO
     257             : 
     258     2656424 :       DO i = 1, arc%n_nodes(1)
     259     2656424 :          denergydsym(i) = arc%layer(nnp%n_layer)%tmp_der(i, 1)
     260             :       END DO
     261             : 
     262       93256 :       CALL timestop(handle)
     263             : 
     264       93256 :    END SUBROUTINE nnp_gradients
     265             : 
     266             : END MODULE nnp_model

Generated by: LCOV version 1.15