LCOV - code coverage report
Current view: top level - src - pao_param_equi.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 54 55 98.2 %
Date: 2024-12-21 06:28:57 Functions: 5 5 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Equivariant parametrization
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_param_equi
      13             :    USE basis_set_types,                 ONLY: gto_basis_set_type
      14             :    USE cp_dbcsr_api,                    ONLY: &
      15             :         dbcsr_complete_redistribute, dbcsr_create, dbcsr_distribution_type, dbcsr_get_block_p, &
      16             :         dbcsr_get_info, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
      17             :         dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, &
      18             :         dbcsr_release, dbcsr_reserve_diag_blocks, dbcsr_type
      19             :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      20             :                                               ls_scf_env_type
      21             :    USE kinds,                           ONLY: dp
      22             :    USE mathlib,                         ONLY: diamat_all
      23             :    USE message_passing,                 ONLY: mp_comm_type
      24             :    USE pao_param_methods,               ONLY: pao_calc_grad_lnv_wrt_AB
      25             :    USE pao_potentials,                  ONLY: pao_guess_initial_potential
      26             :    USE pao_types,                       ONLY: pao_env_type
      27             :    USE qs_environment_types,            ONLY: get_qs_env,&
      28             :                                               qs_environment_type
      29             :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      30             :                                               qs_kind_type
      31             : #include "./base/base_uses.f90"
      32             : 
      33             :    IMPLICIT NONE
      34             : 
      35             :    PRIVATE
      36             : 
      37             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_param_equi'
      38             : 
      39             :    PUBLIC :: pao_param_init_equi, pao_param_finalize_equi, pao_calc_AB_equi
      40             :    PUBLIC :: pao_param_count_equi, pao_param_initguess_equi
      41             : 
      42             : CONTAINS
      43             : 
      44             : ! **************************************************************************************************
      45             : !> \brief Initialize equivariant parametrization
      46             : !> \param pao ...
      47             : ! **************************************************************************************************
      48          12 :    SUBROUTINE pao_param_init_equi(pao)
      49             :       TYPE(pao_env_type), POINTER                        :: pao
      50             : 
      51          12 :       IF (pao%precondition) &
      52           0 :          CPABORT("PAO preconditioning not supported for selected parametrization.")
      53             : 
      54          12 :    END SUBROUTINE pao_param_init_equi
      55             : 
      56             : ! **************************************************************************************************
      57             : !> \brief Finalize equivariant parametrization
      58             : ! **************************************************************************************************
      59          12 :    SUBROUTINE pao_param_finalize_equi()
      60             : 
      61             :       ! Nothing to do.
      62             : 
      63          12 :    END SUBROUTINE pao_param_finalize_equi
      64             : 
      65             : ! **************************************************************************************************
      66             : !> \brief Returns the number of parameters for given atomic kind
      67             : !> \param qs_env ...
      68             : !> \param ikind ...
      69             : !> \param nparams ...
      70             : ! **************************************************************************************************
      71          88 :    SUBROUTINE pao_param_count_equi(qs_env, ikind, nparams)
      72             :       TYPE(qs_environment_type), POINTER                 :: qs_env
      73             :       INTEGER, INTENT(IN)                                :: ikind
      74             :       INTEGER, INTENT(OUT)                               :: nparams
      75             : 
      76             :       INTEGER                                            :: pao_basis_size, pri_basis_size
      77             :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
      78          44 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      79             : 
      80          44 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set)
      81             :       CALL get_qs_kind(qs_kind_set(ikind), &
      82             :                        basis_set=basis_set, &
      83          44 :                        pao_basis_size=pao_basis_size)
      84          44 :       pri_basis_size = basis_set%nsgf
      85             : 
      86          44 :       nparams = pao_basis_size*pri_basis_size
      87             : 
      88          44 :    END SUBROUTINE pao_param_count_equi
      89             : 
      90             : ! **************************************************************************************************
      91             : !> \brief Fills matrix_X with an initial guess
      92             : !> \param pao ...
      93             : !> \param qs_env ...
      94             : ! **************************************************************************************************
      95          10 :    SUBROUTINE pao_param_initguess_equi(pao, qs_env)
      96             :       TYPE(pao_env_type), POINTER                        :: pao
      97             :       TYPE(qs_environment_type), POINTER                 :: qs_env
      98             : 
      99             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_param_initguess_equi'
     100             : 
     101             :       INTEGER                                            :: acol, arow, handle, i, iatom, m, n
     102          10 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
     103             :       LOGICAL                                            :: found
     104          10 :       REAL(dp), DIMENSION(:), POINTER                    :: H_evals
     105          10 :       REAL(dp), DIMENSION(:, :), POINTER                 :: A, block_H0, block_N, block_N_inv, &
     106          10 :                                                             block_X, H, H_evecs, V0
     107             :       TYPE(dbcsr_iterator_type)                          :: iter
     108             : 
     109          10 :       CALL timeset(routineN, handle)
     110             : 
     111          10 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
     112             : 
     113             : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,blk_sizes_pri,blk_sizes_pao) &
     114             : !$OMP PRIVATE(iter,arow,acol,iatom,n,m,i,found) &
     115          10 : !$OMP PRIVATE(block_X,block_H0,block_N,block_N_inv,A,H,H_evecs,H_evals,V0)
     116             :       CALL dbcsr_iterator_start(iter, pao%matrix_X)
     117             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     118             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     119             :          iatom = arow; CPASSERT(arow == acol)
     120             : 
     121             :          CALL dbcsr_get_block_p(matrix=pao%matrix_H0, row=iatom, col=iatom, block=block_H0, found=found)
     122             :          CALL dbcsr_get_block_p(matrix=pao%matrix_N_diag, row=iatom, col=iatom, block=block_N, found=found)
     123             :          CALL dbcsr_get_block_p(matrix=pao%matrix_N_inv_diag, row=iatom, col=iatom, block=block_N_inv, found=found)
     124             :          CPASSERT(ASSOCIATED(block_H0) .AND. ASSOCIATED(block_N) .AND. ASSOCIATED(block_N_inv))
     125             : 
     126             :          n = blk_sizes_pri(iatom) ! size of primary basis
     127             :          m = blk_sizes_pao(iatom) ! size of pao basis
     128             : 
     129             :          ALLOCATE (V0(n, n))
     130             :          CALL pao_guess_initial_potential(qs_env, iatom, V0)
     131             : 
     132             :          ! construct H
     133             :          ALLOCATE (H(n, n))
     134             :          H = MATMUL(MATMUL(block_N, block_H0 + V0), block_N) ! transform into orthonormal basis
     135             : 
     136             :          ! diagonalize H
     137             :          ALLOCATE (H_evecs(n, n), H_evals(n))
     138             :          H_evecs = H
     139             :          CALL diamat_all(H_evecs, H_evals)
     140             : 
     141             :          ! use first m eigenvectors as initial guess
     142             :          ALLOCATE (A(n, m))
     143             :          A = MATMUL(block_N_inv, H_evecs(:, 1:m))
     144             : 
     145             :          ! normalize vectors
     146             :          DO i = 1, m
     147             :             A(:, i) = A(:, i)/NORM2(A(:, i))
     148             :          END DO
     149             : 
     150             :          block_X = RESHAPE(A, (/n*m, 1/))
     151             :          DEALLOCATE (H, V0, A, H_evecs, H_evals)
     152             : 
     153             :       END DO
     154             :       CALL dbcsr_iterator_stop(iter)
     155             : !$OMP END PARALLEL
     156             : 
     157          10 :       CALL timestop(handle)
     158             : 
     159          10 :    END SUBROUTINE pao_param_initguess_equi
     160             : 
     161             : ! **************************************************************************************************
     162             : !> \brief Takes current matrix_X and calculates the matrices A and B.
     163             : !> \param pao ...
     164             : !> \param qs_env ...
     165             : !> \param ls_scf_env ...
     166             : !> \param gradient ...
     167             : !> \param penalty ...
     168             : ! **************************************************************************************************
     169        3380 :    SUBROUTINE pao_calc_AB_equi(pao, qs_env, ls_scf_env, gradient, penalty)
     170             :       TYPE(pao_env_type), POINTER                        :: pao
     171             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     172             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     173             :       LOGICAL, INTENT(IN)                                :: gradient
     174             :       REAL(dp), INTENT(INOUT), OPTIONAL                  :: penalty
     175             : 
     176             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_calc_AB_equi'
     177             : 
     178             :       INTEGER                                            :: acol, arow, group_handle, handle, i, &
     179             :                                                             iatom, j, k, m, n
     180             :       LOGICAL                                            :: found
     181             :       REAL(dp)                                           :: denom, w
     182        1690 :       REAL(dp), DIMENSION(:), POINTER                    :: ANNA_evals
     183        1690 :       REAL(dp), DIMENSION(:, :), POINTER                 :: ANNA, ANNA_evecs, ANNA_inv, block_A, &
     184        1690 :                                                             block_B, block_G, block_Ma, block_Mb, &
     185        1690 :                                                             block_N, block_X, D, G, M1, M2, M3, &
     186        1690 :                                                             M4, M5, NN
     187             :       TYPE(dbcsr_distribution_type)                      :: main_dist
     188             :       TYPE(dbcsr_iterator_type)                          :: iter
     189        1690 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     190             :       TYPE(dbcsr_type)                                   :: matrix_G_nondiag, matrix_Ma, matrix_Mb, &
     191             :                                                             matrix_X_nondiag
     192             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     193             :       TYPE(mp_comm_type)                                 :: group
     194             : 
     195        1690 :       CALL timeset(routineN, handle)
     196        1690 :       ls_mstruct => ls_scf_env%ls_mstruct
     197             : 
     198        1690 :       IF (gradient) THEN
     199         232 :          CALL pao_calc_grad_lnv_wrt_AB(qs_env, ls_scf_env, matrix_Ma, matrix_Mb)
     200             :       END IF
     201             : 
     202             :       ! Redistribute matrix_X from diag_distribution to distribution of matrix_s.
     203        1690 :       CALL get_qs_env(qs_env, matrix_s=matrix_s)
     204        1690 :       CALL dbcsr_get_info(matrix=matrix_s(1)%matrix, distribution=main_dist)
     205             :       CALL dbcsr_create(matrix_X_nondiag, &
     206             :                         name="PAO matrix_X_nondiag", &
     207             :                         dist=main_dist, &
     208        1690 :                         template=pao%matrix_X)
     209        1690 :       CALL dbcsr_reserve_diag_blocks(matrix_X_nondiag)
     210        1690 :       CALL dbcsr_complete_redistribute(pao%matrix_X, matrix_X_nondiag)
     211             : 
     212             :       ! Compuation of matrix_G uses distr. of matrix_s, afterwards we redistribute to diag_distribution.
     213        1690 :       IF (gradient) THEN
     214             :          CALL dbcsr_create(matrix_G_nondiag, &
     215             :                            name="PAO matrix_G_nondiag", &
     216             :                            dist=main_dist, &
     217         232 :                            template=pao%matrix_G)
     218         232 :          CALL dbcsr_reserve_diag_blocks(matrix_G_nondiag)
     219             :       END IF
     220             : 
     221             : !$OMP PARALLEL DEFAULT(NONE) &
     222             : !$OMP SHARED(pao,ls_mstruct,matrix_X_nondiag,matrix_G_nondiag,matrix_Ma,matrix_Mb,gradient,penalty) &
     223             : !$OMP PRIVATE(iter,arow,acol,iatom,found,n,m,w,i,j,k,denom) &
     224             : !$OMP PRIVATE(NN,ANNA,ANNA_evals,ANNA_evecs,ANNA_inv,D,G,M1,M2,M3,M4,M5) &
     225        1690 : !$OMP PRIVATE(block_X,block_A,block_B,block_N,block_Ma, block_Mb, block_G)
     226             :       CALL dbcsr_iterator_start(iter, matrix_X_nondiag)
     227             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     228             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     229             :          iatom = arow; CPASSERT(arow == acol)
     230             :          CALL dbcsr_get_block_p(matrix=ls_mstruct%matrix_A, row=iatom, col=iatom, block=block_A, found=found)
     231             :          CPASSERT(ASSOCIATED(block_A))
     232             :          CALL dbcsr_get_block_p(matrix=ls_mstruct%matrix_B, row=iatom, col=iatom, block=block_B, found=found)
     233             :          CPASSERT(ASSOCIATED(block_B))
     234             :          CALL dbcsr_get_block_p(matrix=pao%matrix_N, row=iatom, col=iatom, block=block_N, found=found)
     235             :          CPASSERT(ASSOCIATED(block_N))
     236             : 
     237             :          n = SIZE(block_A, 1) ! size of primary basis
     238             :          m = SIZE(block_A, 2) ! size of pao basis
     239             :          block_A = RESHAPE(block_X, (/n, m/))
     240             : 
     241             :          ! restrain pao basis vectors to unit norm
     242             :          IF (PRESENT(penalty)) THEN
     243             :             DO i = 1, m
     244             :                w = 1.0_dp - SUM(block_A(:, i)**2)
     245             :                penalty = penalty + pao%penalty_strength*w**2
     246             :             END DO
     247             :          END IF
     248             : 
     249             :          ALLOCATE (NN(n, n), ANNA(m, m))
     250             :          NN = MATMUL(block_N, block_N) ! it's actually S^{-1}
     251             :          ANNA = MATMUL(MATMUL(TRANSPOSE(block_A), NN), block_A)
     252             : 
     253             :          ! diagonalize ANNA
     254             :          ALLOCATE (ANNA_evecs(m, m), ANNA_evals(m))
     255             :          ANNA_evecs(:, :) = ANNA
     256             :          CALL diamat_all(ANNA_evecs, ANNA_evals)
     257             :          IF (MINVAL(ABS(ANNA_evals)) < 1e-10_dp) CPABORT("PAO basis singualar.")
     258             : 
     259             :          ! build ANNA_inv
     260             :          ALLOCATE (ANNA_inv(m, m))
     261             :          ANNA_inv(:, :) = 0.0_dp
     262             :          DO k = 1, m
     263             :             w = 1.0_dp/ANNA_evals(k)
     264             :             DO i = 1, m
     265             :             DO j = 1, m
     266             :                ANNA_inv(i, j) = ANNA_inv(i, j) + w*ANNA_evecs(i, k)*ANNA_evecs(j, k)
     267             :             END DO
     268             :             END DO
     269             :          END DO
     270             : 
     271             :          !B = 1/S * A * 1/(A^T 1/S A)
     272             :          block_B = MATMUL(MATMUL(NN, block_A), ANNA_inv)
     273             : 
     274             :          ! TURNING POINT (if calc grad) ------------------------------------------
     275             :          IF (gradient) THEN
     276             :             CALL dbcsr_get_block_p(matrix=matrix_G_nondiag, row=iatom, col=iatom, block=block_G, found=found)
     277             :             CPASSERT(ASSOCIATED(block_G))
     278             :             CALL dbcsr_get_block_p(matrix=matrix_Ma, row=iatom, col=iatom, block=block_Ma, found=found)
     279             :             CALL dbcsr_get_block_p(matrix=matrix_Mb, row=iatom, col=iatom, block=block_Mb, found=found)
     280             :             ! don't check ASSOCIATED(block_M), it might have been filtered out.
     281             : 
     282             :             ALLOCATE (G(n, m))
     283             :             G(:, :) = 0.0_dp
     284             : 
     285             :             IF (PRESENT(penalty)) THEN
     286             :                DO i = 1, m
     287             :                   w = 1.0_dp - SUM(block_A(:, i)**2)
     288             :                   G(:, i) = -4.0_dp*pao%penalty_strength*w*block_A(:, i)
     289             :                END DO
     290             :             END IF
     291             : 
     292             :             IF (ASSOCIATED(block_Ma)) THEN
     293             :                G = G + block_Ma
     294             :             END IF
     295             : 
     296             :             IF (ASSOCIATED(block_Mb)) THEN
     297             :                G = G + MATMUL(MATMUL(NN, block_Mb), ANNA_inv)
     298             : 
     299             :                ! calculate derivatives dAA_inv/ dAA
     300             :                ALLOCATE (D(m, m), M1(m, m), M2(m, m), M3(m, m), M4(m, m), M5(m, m))
     301             : 
     302             :                DO i = 1, m
     303             :                DO j = 1, m
     304             :                   denom = ANNA_evals(i) - ANNA_evals(j)
     305             :                   IF (i == j) THEN
     306             :                      D(i, i) = -1.0_dp/ANNA_evals(i)**2 ! diagonal elements
     307             :                   ELSE IF (ABS(denom) > 1e-10_dp) THEN
     308             :                      D(i, j) = (1.0_dp/ANNA_evals(i) - 1.0_dp/ANNA_evals(j))/denom
     309             :                   ELSE
     310             :                      D(i, j) = -1.0_dp ! limit according to L'Hospital's rule
     311             :                   END IF
     312             :                END DO
     313             :                END DO
     314             : 
     315             :                M1 = MATMUL(MATMUL(TRANSPOSE(block_A), NN), block_Mb)
     316             :                M2 = MATMUL(MATMUL(TRANSPOSE(ANNA_evecs), M1), ANNA_evecs)
     317             :                M3 = M2*D ! Hadamard product
     318             :                M4 = MATMUL(MATMUL(ANNA_evecs, M3), TRANSPOSE(ANNA_evecs))
     319             :                M5 = 0.5_dp*(M4 + TRANSPOSE(M4))
     320             :                G = G + 2.0_dp*MATMUL(MATMUL(NN, block_A), M5)
     321             : 
     322             :                DEALLOCATE (D, M1, M2, M3, M4, M5)
     323             :             END IF
     324             : 
     325             :             block_G = RESHAPE(G, (/n*m, 1/))
     326             :             DEALLOCATE (G)
     327             :          END IF
     328             : 
     329             :          DEALLOCATE (NN, ANNA, ANNA_evecs, ANNA_evals, ANNA_inv)
     330             :       END DO
     331             :       CALL dbcsr_iterator_stop(iter)
     332             : !$OMP END PARALLEL
     333             : 
     334             :       ! sum penalty energies across ranks
     335        1690 :       IF (PRESENT(penalty)) THEN
     336        1678 :          CALL dbcsr_get_info(pao%matrix_X, group=group_handle)
     337        1678 :          CALL group%set_handle(group_handle)
     338        1678 :          CALL group%sum(penalty)
     339             :       END IF
     340             : 
     341        1690 :       CALL dbcsr_release(matrix_X_nondiag)
     342             : 
     343        1690 :       IF (gradient) THEN
     344         232 :          CALL dbcsr_complete_redistribute(matrix_G_nondiag, pao%matrix_G)
     345         232 :          CALL dbcsr_release(matrix_G_nondiag)
     346         232 :          CALL dbcsr_release(matrix_Ma)
     347         232 :          CALL dbcsr_release(matrix_Mb)
     348             :       END IF
     349             : 
     350        1690 :       CALL timestop(handle)
     351             : 
     352        1690 :    END SUBROUTINE pao_calc_AB_equi
     353             : 
     354             : END MODULE pao_param_equi

Generated by: LCOV version 1.15