LCOV - code coverage report
Current view: top level - src - optbas_opt_utils.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 112 116 96.6 %
Date: 2024-11-21 06:45:46 Functions: 3 3 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       MODULE optbas_opt_utils
   USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                              get_atomic_kind
   USE basis_set_types,                 ONLY: get_gto_basis_set,&
                                              gto_basis_set_type
   USE cell_types,                      ONLY: cell_type
   USE cp_blacs_env,                    ONLY: cp_blacs_env_type
   USE cp_dbcsr_api,                    ONLY: dbcsr_create,&
                                              dbcsr_distribution_type,&
                                              dbcsr_get_info,&
                                              dbcsr_p_type,&
                                              dbcsr_release,&
                                              dbcsr_transposed,&
                                              dbcsr_type,&
                                              dbcsr_type_no_symmetry
   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm,&
                                              cp_dbcsr_sm_fm_multiply
   USE cp_fm_basic_linalg,              ONLY: cp_fm_invert,&
                                              cp_fm_trace
   USE cp_fm_diag,                      ONLY: cp_fm_power
   USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
                                              cp_fm_struct_release,&
                                              cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_info,&
                                              cp_fm_release,&
                                              cp_fm_type
   USE distribution_1d_types,           ONLY: distribution_1d_type
   USE distribution_2d_types,           ONLY: distribution_2d_type
   USE input_section_types,             ONLY: section_vals_val_get
   USE kinds,                           ONLY: dp
   USE molecule_types,                  ONLY: molecule_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE particle_types,                  ONLY: particle_type
   USE qs_condnum,                      ONLY: overlap_condnum
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              qs_kind_type
   USE qs_mo_types,                     ONLY: get_mo_set,&
                                              mo_set_type
   USE qs_neighbor_list_types,          ONLY: neighbor_list_set_p_type
   USE qs_neighbor_lists,               ONLY: atom2d_build,&
                                              atom2d_cleanup,&
                                              build_neighbor_lists,&
                                              local_atoms_type,&
                                              pair_radius_setup
#include "./base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   PUBLIC :: evaluate_optvals, fit_mo_coeffs, optbas_build_neighborlist

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'optbas_opt_utils'

CONTAINS

! **************************************************************************************************
!> \brief Evaluates the fit objective, a spin-weighted Kohn-Sham energy term and the condition
!>        number of an overlap matrix for the current fit (auxiliary) MO coefficients.
!> \param mos MO sets in the orbital basis, one entry per spin; SIZE(mos) gives the spin count
!> \param mos_aux_fit MO sets in the fit basis, one entry per spin
!> \param matrix_ks Kohn-Sham matrices, indexed by spin
!> \param Q mixed matrix that is used through its transpose; its column count is taken as the
!>        orbital dimension norb
!> \param Snew overlap matrix whose condition number is reported in S_cond_number
!> \param S_inv_orb dense norb x norb matrix applied to Q^T * C_aux
!>        (presumably the inverse orbital overlap, going by its name; confirm with caller)
!> \param fval on exit: sum over spins of 2*nmo - 2*trace(tmp1, C), where tmp1 = Q^T * C_aux
!> \param energy on exit: sum over spins of trace(tmp2, KS * tmp2) with tmp2 = S_inv_orb * tmp1,
!>        weighted by (3 - nspins), i.e. 2 for one spin channel and 1 for two
!> \param S_cond_number on exit: second entry of the condition-number pair that
!>        overlap_condnum returns for Snew
! **************************************************************************************************
   SUBROUTINE evaluate_optvals(mos, mos_aux_fit, matrix_ks, Q, Snew, S_inv_orb, &
                               fval, energy, S_cond_number)
      TYPE(mo_set_type), DIMENSION(:), INTENT(IN)        :: mos, mos_aux_fit
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks
      TYPE(dbcsr_type), POINTER                          :: Q, Snew
      TYPE(cp_fm_type), INTENT(IN)                       :: S_inv_orb
      REAL(KIND=dp), INTENT(OUT)                         :: fval, energy, S_cond_number

      CHARACTER(len=*), PARAMETER                        :: routineN = 'evaluate_optvals'

      INTEGER                                            :: handle, ispin, iunit, nmo, norb, &
                                                            nspins
      INTEGER, DIMENSION(:), POINTER                     :: col_blk_sizes, row_blk_sizes
      REAL(KIND=dp)                                      :: spin_factor, tmp_energy, trace
      REAL(KIND=dp), DIMENSION(2)                        :: condnum
      TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
      TYPE(cp_fm_type)                                   :: tmp1, tmp2
      TYPE(cp_fm_type), POINTER                          :: mo_coeff, mo_coeff_aux_fit
      TYPE(dbcsr_distribution_type)                      :: dbcsr_dist
      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: smat
      TYPE(dbcsr_type)                                   :: Qt

      CALL timeset(routineN, handle)

      nspins = SIZE(mos)
      ! Weight of each spin channel: 2 for a single (closed-shell) channel, 1 for two channels.
      spin_factor = 3.0_dp - REAL(nspins, KIND=dp)

      ! Build Qt = transpose(Q) with swapped block sizes on Q's distribution.
      NULLIFY (col_blk_sizes, row_blk_sizes)
      CALL dbcsr_get_info(Q, distribution=dbcsr_dist, nfullcols_total=norb, &
                          row_blk_size=row_blk_sizes, col_blk_size=col_blk_sizes)
      CALL dbcsr_create(matrix=Qt, name="Qt", &
                        dist=dbcsr_dist, matrix_type=dbcsr_type_no_symmetry, &
                        row_blk_size=col_blk_sizes, col_blk_size=row_blk_sizes, &
                        nze=0)
      CALL dbcsr_transposed(Qt, Q)

      fval = 0.0_dp
      energy = 0.0_dp
      DO ispin = 1, nspins
         CALL get_mo_set(mos(ispin), mo_coeff=mo_coeff)
         CALL get_mo_set(mos_aux_fit(ispin), mo_coeff=mo_coeff_aux_fit)
         CALL cp_fm_get_info(mo_coeff, ncol_global=nmo)

         ! tmp1 = Q^T * C_aux; the objective compares it with C through their trace.
         CALL cp_fm_create(tmp1, matrix_struct=mo_coeff%matrix_struct)
         CALL cp_dbcsr_sm_fm_multiply(Qt, mo_coeff_aux_fit, tmp1, nmo)
         CALL cp_fm_trace(tmp1, mo_coeff, trace)
         fval = fval - 2.0_dp*trace + 2.0_dp*nmo

         ! tmp2 = S_inv_orb * tmp1, then tmp1 is reused as KS * tmp2 for the energy trace.
         CALL cp_fm_create(tmp2, matrix_struct=mo_coeff%matrix_struct)
         CALL parallel_gemm('N', 'N', norb, nmo, norb, 1.0_dp, S_inv_orb, tmp1, 0.0_dp, tmp2)
         CALL cp_dbcsr_sm_fm_multiply(matrix_ks(ispin)%matrix, tmp2, tmp1, nmo)
         CALL cp_fm_trace(tmp2, tmp1, tmp_energy)
         energy = energy + tmp_energy*spin_factor

         CALL cp_fm_release(tmp1)
         CALL cp_fm_release(tmp2)
      END DO
      CALL dbcsr_release(Qt)

      ! overlap_condnum expects a (1,1) array of matrix pointers; wrap Snew without copying it.
      ALLOCATE (smat(1, 1))
      smat(1, 1)%matrix => Snew
      iunit = -1
      CALL cp_fm_get_info(S_inv_orb, context=blacs_env)
      CALL overlap_condnum(smat, condnum, iunit, .FALSE., .TRUE., .FALSE., blacs_env)
      S_cond_number = condnum(2)
      DEALLOCATE (smat)

      CALL timestop(handle)

   END SUBROUTINE evaluate_optvals

! **************************************************************************************************
!> \brief Projects the orbital-basis MO coefficients onto the auxiliary basis and orthonormalizes
!>        the result in the auxiliary metric, overwriting the coefficients stored in mosaux.
!> \details For every spin: tmp1 = Sao * C, tmp2 = Saux^-1 * tmp1, T = tmp1^T * tmp2 and
!>          C_aux = tmp2 * T^(-1/2). Assuming cp_fm_power returns the symmetric matrix power
!>          (eigenvalues below 1.0E-12 are subject to its threshold handling), this gives
!>          C_aux^T * Saux * C_aux = 1.
!> \param saux auxiliary-basis overlap; only saux(1)%matrix is used
!> \param sauxorb mixed auxiliary/orbital overlap; only sauxorb(1)%matrix is used
!> \param mos MO sets in the orbital basis, one entry per spin
!> \param mosaux MO sets in the auxiliary basis; their mo_coeff targets are overwritten
! **************************************************************************************************
   SUBROUTINE fit_mo_coeffs(saux, sauxorb, mos, mosaux)
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: saux, sauxorb
      TYPE(mo_set_type), DIMENSION(:), INTENT(IN)        :: mos, mosaux

      CHARACTER(len=*), PARAMETER                        :: routineN = 'fit_mo_coeffs'
      REAL(KIND=dp), PARAMETER                           :: threshold = 1.E-12_dp

      INTEGER                                            :: handle, ispin, naux, ndep, nmo, nspins
      TYPE(cp_fm_struct_type), POINTER                   :: fm_struct
      TYPE(cp_fm_type)                                   :: fm_s, fm_sinv, tmat, tmp1, tmp2, work
      TYPE(cp_fm_type), POINTER                          :: mo_coeff, mo_coeff_aux

      CALL timeset(routineN, handle)

      CALL dbcsr_get_info(saux(1)%matrix, nfullrows_total=naux)
      ! The first spin's MO matrix supplies the BLACS context and parallel environment.
      CALL get_mo_set(mos(1), mo_coeff=mo_coeff)

      ! Dense inverse of the auxiliary overlap, shared by all spins.
      CALL cp_fm_struct_create(fm_struct, nrow_global=naux, ncol_global=naux, &
                               context=mo_coeff%matrix_struct%context, &
                               para_env=mo_coeff%matrix_struct%para_env)
      CALL cp_fm_create(fm_s, fm_struct, name="s_aux")
      CALL cp_fm_create(fm_sinv, fm_struct, name="s_aux_inv")
      CALL copy_dbcsr_to_fm(saux(1)%matrix, fm_s)
      CALL cp_fm_invert(fm_s, fm_sinv)
      CALL cp_fm_release(fm_s)
      CALL cp_fm_struct_release(fm_struct)

      nspins = SIZE(mos)
      DO ispin = 1, nspins
         CALL get_mo_set(mos(ispin), mo_coeff=mo_coeff)
         CALL get_mo_set(mosaux(ispin), mo_coeff=mo_coeff_aux)
         CALL cp_fm_get_info(mo_coeff, ncol_global=nmo)
         CALL cp_fm_create(tmp1, matrix_struct=mo_coeff_aux%matrix_struct)
         CALL cp_fm_create(tmp2, matrix_struct=mo_coeff_aux%matrix_struct)
         CALL cp_fm_struct_create(fm_struct, nrow_global=nmo, ncol_global=nmo, &
                                  context=mo_coeff%matrix_struct%context, &
                                  para_env=mo_coeff%matrix_struct%para_env)
         CALL cp_fm_create(tmat, fm_struct, name="tmat")
         CALL cp_fm_create(work, fm_struct, name="work")
         CALL cp_fm_struct_release(fm_struct)

         ! Projection: tmp1 = Sao * C, tmp2 = Saux^-1 * tmp1.
         CALL cp_dbcsr_sm_fm_multiply(sauxorb(1)%matrix, mo_coeff, tmp1, nmo)
         CALL parallel_gemm('N', 'N', naux, nmo, naux, 1.0_dp, fm_sinv, tmp1, 0.0_dp, tmp2)
         ! Metric of the projected vectors, T = tmp1^T * tmp2, replaced in place by T^(-1/2).
         CALL parallel_gemm('T', 'N', nmo, nmo, naux, 1.0_dp, tmp1, tmp2, 0.0_dp, tmat)
         ! ndep is an output of cp_fm_power and is not used here.
         CALL cp_fm_power(tmat, work, -0.5_dp, threshold, ndep)
         CALL parallel_gemm('N', 'N', naux, nmo, nmo, 1.0_dp, tmp2, tmat, 0.0_dp, mo_coeff_aux)

         CALL cp_fm_release(work)
         CALL cp_fm_release(tmat)
         CALL cp_fm_release(tmp1)
         CALL cp_fm_release(tmp2)
      END DO
      CALL cp_fm_release(fm_sinv)

      CALL timestop(handle)

   END SUBROUTINE fit_mo_coeffs

! **************************************************************************************************
!> \brief Rebuilds the neighbor lists for basis sets: aux-aux (sab_aux) and aux-orb (sab_aux_orb),
!>        where "aux" is the basis selected by basis_type and "orb" is the ORB basis.
!> \param qs_env QS environment supplying kinds, particles, cell, distributions and input
!> \param sab_aux on exit: neighbor list named "sab_aux" built from aux/aux pair radii
!> \param sab_aux_orb on exit: non-symmetric neighbor list named "sab_aux_orb" built from
!>        aux/orb pair radii
!> \param basis_type basis-type label passed to get_qs_kind to select the auxiliary basis
!> \par History
!>       adapted from kg_build_neighborlist
! **************************************************************************************************
   SUBROUTINE optbas_build_neighborlist(qs_env, sab_aux, sab_aux_orb, basis_type)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_aux, sab_aux_orb
      CHARACTER(*), INTENT(IN)                           :: basis_type

      CHARACTER(LEN=*), PARAMETER :: routineN = 'optbas_build_neighborlist'

      INTEGER                                            :: handle, ikind, nkind
      LOGICAL                                            :: mic, molecule_only
      LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: aux_fit_present, orb_present
      REAL(dp)                                           :: subcells
      REAL(dp), ALLOCATABLE, DIMENSION(:)                :: aux_fit_radius, orb_radius
      REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: pair_radius
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(distribution_1d_type), POINTER                :: distribution_1d
      TYPE(distribution_2d_type), POINTER                :: distribution_2d
      TYPE(gto_basis_set_type), POINTER                  :: aux_fit_basis_set, orb_basis_set
      TYPE(local_atoms_type), ALLOCATABLE, DIMENSION(:)  :: atom2d
      TYPE(molecule_type), DIMENSION(:), POINTER         :: molecule_set
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

      CALL timeset(routineN, handle)

      ! Switch for restricting the lists to molecular subgroups; currently disabled,
      ! and the minimum-image flag follows it.
      molecule_only = .FALSE.
      mic = molecule_only

      CALL get_qs_env(qs_env=qs_env, &
                      atomic_kind_set=atomic_kind_set, &
                      qs_kind_set=qs_kind_set, &
                      cell=cell, &
                      distribution_2d=distribution_2d, &
                      molecule_set=molecule_set, &
                      local_particles=distribution_1d, &
                      particle_set=particle_set)

      CALL section_vals_val_get(qs_env%input, "DFT%SUBCELLS", r_val=subcells)

      ! Per-kind work storage.
      nkind = SIZE(atomic_kind_set)
      ALLOCATE (orb_radius(nkind), aux_fit_radius(nkind))
      orb_radius(:) = 0.0_dp
      aux_fit_radius(:) = 0.0_dp
      ALLOCATE (orb_present(nkind), aux_fit_present(nkind))
      ALLOCATE (pair_radius(nkind, nkind))
      ALLOCATE (atom2d(nkind))

      CALL atom2d_build(atom2d, distribution_1d, distribution_2d, atomic_kind_set, &
                        molecule_set, molecule_only, particle_set=particle_set)

      ! Record, per kind, whether each basis exists and its kind radius (0 when absent).
      DO ikind = 1, nkind
         CALL get_atomic_kind(atomic_kind_set(ikind), atom_list=atom2d(ikind)%list)

         CALL get_qs_kind(qs_kind_set(ikind), basis_set=orb_basis_set, basis_type="ORB")
         orb_present(ikind) = ASSOCIATED(orb_basis_set)
         IF (orb_present(ikind)) THEN
            CALL get_gto_basis_set(gto_basis_set=orb_basis_set, kind_radius=orb_radius(ikind))
         ELSE
            orb_radius(ikind) = 0.0_dp
         END IF

         CALL get_qs_kind(qs_kind_set(ikind), basis_set=aux_fit_basis_set, basis_type=basis_type)
         aux_fit_present(ikind) = ASSOCIATED(aux_fit_basis_set)
         IF (aux_fit_present(ikind)) THEN
            CALL get_gto_basis_set(gto_basis_set=aux_fit_basis_set, kind_radius=aux_fit_radius(ikind))
         ELSE
            aux_fit_radius(ikind) = 0.0_dp
         END IF
      END DO

      ! aux-aux list (default symmetry handling of build_neighbor_lists).
      CALL pair_radius_setup(aux_fit_present, aux_fit_present, aux_fit_radius, aux_fit_radius, pair_radius)
      CALL build_neighbor_lists(sab_aux, particle_set, atom2d, cell, pair_radius, &
                                mic=mic, molecular=molecule_only, subcells=subcells, nlname="sab_aux")
      ! aux-orb list; the two bases differ, so the list is built non-symmetric.
      CALL pair_radius_setup(aux_fit_present, orb_present, aux_fit_radius, orb_radius, pair_radius)
      CALL build_neighbor_lists(sab_aux_orb, particle_set, atom2d, cell, pair_radius, &
                                mic=mic, symmetric=.FALSE., molecular=molecule_only, subcells=subcells, &
                                nlname="sab_aux_orb")

      ! Release work storage.
      CALL atom2d_cleanup(atom2d)
      DEALLOCATE (atom2d)
      DEALLOCATE (orb_present, aux_fit_present)
      DEALLOCATE (orb_radius, aux_fit_radius)
      DEALLOCATE (pair_radius)

      CALL timestop(handle)

   END SUBROUTINE optbas_build_neighborlist

END MODULE optbas_opt_utils

Generated by: LCOV version 1.15