LCOV - code coverage report
Current view: top level - src - rpa_exchange.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b4bd748) Lines: 364 373 97.6 %
Date: 2025-03-09 07:56:22 Functions: 9 15 60.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Auxiliary routines needed for RPA-exchange
      10             : !>        (exchange corrections selected by rpa_exchange_axk or rpa_exchange_sosex)
      11             : !> \par History
      12             : !>      09.2016 created [Vladimir Rybkin]
      13             : !>      03.2019 Renamed [Frederick Stein]
      14             : !>      03.2019 Moved Functions from rpa_ri_gpw.F [Frederick Stein]
      15             : !>      04.2024 Added open-shell calculations, SOSEX [Frederick Stein]
      16             : !> \author Vladimir Rybkin
      17             : ! **************************************************************************************************
      18             : MODULE rpa_exchange
      19             :    USE atomic_kind_types,               ONLY: atomic_kind_type
      20             :    USE cell_types,                      ONLY: cell_type
      21             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_type
      22             :    USE cp_control_types,                ONLY: dft_control_type
      23             :    USE cp_dbcsr_api,                    ONLY: &
      24             :         dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_init_p, dbcsr_multiply, dbcsr_p_type, &
      25             :         dbcsr_release, dbcsr_set, dbcsr_type, dbcsr_type_no_symmetry
      26             :    USE cp_dbcsr_contrib,                ONLY: dbcsr_trace
      27             :    USE cp_dbcsr_operations,             ONLY: dbcsr_allocate_matrix_set
      28             :    USE cp_fm_basic_linalg,              ONLY: cp_fm_column_scale
      29             :    USE cp_fm_diag,                      ONLY: choose_eigv_solver
      30             :    USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
      31             :                                               cp_fm_struct_p_type,&
      32             :                                               cp_fm_struct_release
      33             :    USE cp_fm_types,                     ONLY: cp_fm_create,&
      34             :                                               cp_fm_get_info,&
      35             :                                               cp_fm_release,&
      36             :                                               cp_fm_set_all,&
      37             :                                               cp_fm_to_fm,&
      38             :                                               cp_fm_to_fm_submat_general,&
      39             :                                               cp_fm_type
      40             :    USE group_dist_types,                ONLY: create_group_dist,&
      41             :                                               get_group_dist,&
      42             :                                               group_dist_d1_type,&
      43             :                                               group_dist_proc,&
      44             :                                               maxsize,&
      45             :                                               release_group_dist
      46             :    USE hfx_admm_utils,                  ONLY: tddft_hfx_matrix
      47             :    USE hfx_types,                       ONLY: hfx_create,&
      48             :                                               hfx_release,&
      49             :                                               hfx_type
      50             :    USE input_constants,                 ONLY: rpa_exchange_axk,&
      51             :                                               rpa_exchange_none,&
      52             :                                               rpa_exchange_sosex
      53             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      54             :                                               section_vals_type
      55             :    USE kinds,                           ONLY: dp,&
      56             :                                               int_8
      57             :    USE local_gemm_api,                  ONLY: LOCAL_GEMM_PU_GPU
      58             :    USE mathconstants,                   ONLY: sqrthalf
      59             :    USE message_passing,                 ONLY: mp_para_env_type,&
      60             :                                               mp_proc_null
      61             :    USE mp2_types,                       ONLY: mp2_type
      62             :    USE parallel_gemm_api,               ONLY: parallel_gemm
      63             :    USE particle_types,                  ONLY: particle_type
      64             :    USE qs_environment_types,            ONLY: get_qs_env,&
      65             :                                               qs_environment_type
      66             :    USE qs_kind_types,                   ONLY: qs_kind_type
      67             :    USE qs_subsys_types,                 ONLY: qs_subsys_get,&
      68             :                                               qs_subsys_type
      69             :    USE rpa_communication,               ONLY: gamma_fm_to_dbcsr
      70             :    USE rpa_util,                        ONLY: calc_fm_mat_S_rpa,&
      71             :                                               remove_scaling_factor_rpa
      72             :    USE scf_control_types,               ONLY: scf_control_type
      73             : #include "./base/base_uses.f90"
      74             : 
      75             :    IMPLICIT NONE
      76             : 
      77             :    PRIVATE ! entities are module-private unless named in a PUBLIC statement
      78             : 
      79             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rpa_exchange' ! name of this module
      80             : 
      81             :    PUBLIC :: rpa_exchange_work_type, rpa_exchange_needed_mem ! exported work type and memory estimator
      82             : 
      83             :    TYPE rpa_exchange_env_type ! data of the HFX-based evaluation, see hfx_create_subgroup
      84             :       PRIVATE
      85             :       TYPE(qs_environment_type), POINTER             :: qs_env => NULL() ! set to the qs_env passed to create
      86             :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER      :: mat_hfx => NULL() ! per spin, non-symmetric copy of mat_munu
      87             :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER      :: dbcsr_Gamma_munu_P => NULL() ! per spin, zeroed in create
      88             :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)    :: dbcsr_Gamma_inu_P ! per spin, from mo_coeff_o, zeroed
      89             :       ! Workaround GCC 8; both pointers are taken over from qs_env%mp2_env%ri_rpa in create
      90             :       TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_o => NULL()
      91             :       TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_v => NULL()
      92             :       TYPE(dbcsr_type)                               :: work_ao ! non-symmetric, created from mat_munu
      93             :       TYPE(hfx_type), DIMENSION(:, :), POINTER       :: x_data => NULL() ! built by hfx_create
      94             :       TYPE(mp_para_env_type), POINTER                :: para_env => NULL() ! points to para_env_sub
      95             :       TYPE(section_vals_type), POINTER               :: hfx_sections => NULL() ! RI_RPA%HF input section
      96             :       LOGICAL :: my_recalc_hfx_integrals = .FALSE. ! set to .TRUE. in create
      97             :       REAL(KIND=dp) :: eps_filter = 0.0_dp ! copied from mp2_env%mp2_gpw%eps_filter
      98             :       TYPE(cp_fm_struct_p_type), DIMENSION(:), ALLOCATABLE :: struct_Gamma ! per spin, dimen_ia x dimen_RI
      99             :    CONTAINS
     100             :       PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: create => hfx_create_subgroup
     101             :       !PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: integrate => integrate_exchange
     102             :       PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: release => hfx_release_subgroup
     103             :    END TYPE
     104             : 
     105             :    TYPE dbcsr_matrix_p_set ! wrapper around an allocatable array of dbcsr matrices
     106             :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: matrix_set
     107             :    END TYPE
     108             : 
     109             :    TYPE rpa_exchange_work_type ! work data of the exchange correction, see rpa_exchange_work_create
     110             :       PRIVATE
     111             :       INTEGER :: exchange_correction = rpa_exchange_none ! copied from mp2_env%ri_rpa%exchange_correction
     112             :       TYPE(rpa_exchange_env_type) :: exchange_env ! only created if ri_rpa%use_hfx_implementation
     113             :       INTEGER, DIMENSION(:), ALLOCATABLE :: homo, virtual, dimen_ia ! per spin; dimen_ia = homo*virtual
     114             :       TYPE(group_dist_d1_type) :: aux_func_dist = group_dist_d1_type() ! created from ngroup and dimen_RI
     115             :       INTEGER, DIMENSION(:), ALLOCATABLE :: aux2send ! (0:ngroup-1), local row count per group
     116             :       INTEGER :: dimen_RI = 0
     117             :       INTEGER :: block_size = 0 ! ri_rpa%exchange_block_size, or dimen_RI if that is not positive
     118             :       INTEGER :: color_sub = 0 ! para_env%mepos/para_env_sub%num_pe
     119             :       INTEGER :: ngroup = 0 ! para_env%num_pe/para_env_sub%num_pe
     120             :       TYPE(cp_fm_type) :: fm_mat_Q_tmp = cp_fm_type() ! structure of fm_mat_Q
     121             :       TYPE(cp_fm_type) :: fm_mat_R_half_gemm = cp_fm_type() ! structure of fm_mat_Q_gemm
     122             :       TYPE(cp_fm_type) :: fm_mat_U = cp_fm_type() ! structure of fm_mat_Q
     123             :       TYPE(mp_para_env_type), POINTER :: para_env_sub => NULL() ! pointer only, nullified on release
     124             :    CONTAINS
     125             :       PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: create => rpa_exchange_work_create
     126             :       PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: compute => rpa_exchange_work_compute
     127             :       PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: release => rpa_exchange_work_release
     128             :       PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: redistribute_into_subgroups
     129             :       PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_fm => rpa_exchange_work_compute_fm
     130             :       PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_hfx => rpa_exchange_work_compute_hfx
     131             :    END TYPE
     132             : 
     133             : CONTAINS
     134             : 
     135             : ! **************************************************************************************************
     136             : !> \brief Adds the estimated memory demand of the exchange correction to two running totals
     137             : !> \param mp2_env provides ri_rpa%exchange_block_size
     138             : !> \param homo per-spin integers; MAXVAL(homo) and homo*virtual enter the estimate
     139             : !> \param virtual per-spin integers, multiplied elementwise with homo
     140             : !> \param dimen_RI extent used for the dimen_RI x dimen_RI arrays and the Gamma-sized arrays
     141             : !> \param para_env its num_pe bounds the block size if none was requested
     142             : !> \param mem_per_rank incremented by the product matrix term (count*8/1024**2, presumably MiB)
     143             : !> \param mem_per_repl incremented by the work array, Gamma copy and buffer terms (same scaling)
     144             : ! **************************************************************************************************
     145         128 :    SUBROUTINE rpa_exchange_needed_mem(mp2_env, homo, virtual, dimen_RI, para_env, mem_per_rank, mem_per_repl)
     146             :       TYPE(mp2_type), INTENT(IN)                         :: mp2_env
     147             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: homo, virtual
     148             :       INTEGER, INTENT(IN)                                :: dimen_RI
     149             :       TYPE(mp_para_env_type), INTENT(IN)                 :: para_env
     150             :       REAL(KIND=dp), INTENT(INOUT)                       :: mem_per_rank, mem_per_repl
     151             : 
     152             :       INTEGER                                            :: block_size
     153             : 
     154             :       ! Requested block size; if it is not positive, use CEILING(dimen_RI/num_pe), but at least 1
     155         128 :       block_size = mp2_env%ri_rpa%exchange_block_size
     156         128 :       IF (block_size <= 0) block_size = MAX(1, (dimen_RI + para_env%num_pe - 1)/para_env%num_pe)
     157             : 
     158             :       ! product matrix: MAXVAL(homo)**2*block_size**2 values; NOTE(review): block_size**2 is INTEGER arithmetic
     159         284 :       mem_per_rank = mem_per_rank + REAL(MAXVAL(homo), KIND=dp)**2*block_size**2*8.0_dp/(1024_dp**2)
     160             : 
     161             :       ! R (2x) and U; Gamma (2x) and a communication buffer. NOTE(review): 1024_dp is an INTEGER literal of kind dp
     162             :       mem_per_repl = mem_per_repl + 3.0_dp*dimen_RI*dimen_RI*8.0_dp/(1024_dp**2) &
     163         284 :                      + 3.0_dp*MAXVAL(homo*virtual)*dimen_RI*8.0_dp/(1024_dp**2)
     164         128 :    END SUBROUTINE rpa_exchange_needed_mem
     165             : 
     166             : ! **************************************************************************************************
     167             : !> \brief Initialises the work type; only the correction type is stored if it is rpa_exchange_none
     168             : !> \param exchange_work object to set up
     169             : !> \param qs_env provides mp2_env%ri_rpa settings; its ri_rpa%mo_coeff_o/v are handed over or released
     170             : !> \param para_env_sub subgroup environment; stored as a pointer and passed to the HFX setup
     171             : !> \param mat_munu its matrix is passed on to exchange_env%create
     172             : !> \param dimen_RI number of entries distributed over the groups; stored in the work type
     173             : !> \param fm_mat_S per-spin matrices; fm_mat_S(1) supplies para_env and the local row distribution
     174             : !> \param fm_mat_Q its structure is used for fm_mat_U and fm_mat_Q_tmp
     175             : !> \param fm_mat_Q_gemm its structure is used for fm_mat_R_half_gemm
     176             : !> \param homo per-spin values copied into the work type
     177             : !> \param virtual per-spin values copied into the work type; dimen_ia is set to homo*virtual
     178             : ! **************************************************************************************************
     179         132 :    SUBROUTINE rpa_exchange_work_create(exchange_work, qs_env, para_env_sub, mat_munu, dimen_RI, &
     180         132 :                                        fm_mat_S, fm_mat_Q, fm_mat_Q_gemm, homo, virtual)
     181             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     182             :       TYPE(qs_environment_type), POINTER :: qs_env
     183             :       TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
     184             :       TYPE(dbcsr_p_type), INTENT(IN) :: mat_munu
     185             :       INTEGER, INTENT(IN) :: dimen_RI
     186             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
     187             :       TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q, fm_mat_Q_gemm
     188             :       INTEGER, DIMENSION(SIZE(fm_mat_S)), INTENT(IN) :: homo, virtual
     189             : 
     190             :       INTEGER :: nspins, aux_global, aux_local, my_process_row, proc, ispin
     191         132 :       INTEGER, DIMENSION(:), POINTER :: row_indices, aux_distribution_fm
     192             :       TYPE(cp_blacs_env_type), POINTER :: context
     193             : 
     194         132 :       exchange_work%exchange_correction = qs_env%mp2_env%ri_rpa%exchange_correction
     195             : 
     196         132 :       IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN ! nothing else is set up
     197             :       ! Group layout from the environment of fm_mat_S(1); color_sub is presumably the index of the own group
     198             :       ASSOCIATE (para_env => fm_mat_S(1)%matrix_struct%para_env)
     199          12 :          exchange_work%para_env_sub => para_env_sub
     200          12 :          exchange_work%ngroup = para_env%num_pe/para_env_sub%num_pe ! integer division
     201          12 :          exchange_work%color_sub = para_env%mepos/para_env_sub%num_pe ! integer division
     202             :       END ASSOCIATE
     203             : 
     204          12 :       CALL cp_fm_get_info(fm_mat_S(1), row_indices=row_indices, nrow_locals=aux_distribution_fm, context=context)
     205          12 :       CALL context%get(my_process_row=my_process_row)
     206             :       ! aux2send(g): number of local rows of fm_mat_S(1) whose global index group_dist_proc maps to group g
     207          12 :       CALL create_group_dist(exchange_work%aux_func_dist, exchange_work%ngroup, dimen_RI)
     208          36 :       ALLOCATE (exchange_work%aux2send(0:exchange_work%ngroup - 1))
     209          36 :       exchange_work%aux2send = 0
     210         499 :       DO aux_local = 1, aux_distribution_fm(my_process_row) ! loop over the local rows
     211         487 :          aux_global = row_indices(aux_local)
     212         487 :          proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
     213         499 :          exchange_work%aux2send(proc) = exchange_work%aux2send(proc) + 1
     214             :       END DO
     215             : 
     216          12 :       nspins = SIZE(fm_mat_S)
     217             :       ! Keep per-spin copies of the dimensions
     218          60 :       ALLOCATE (exchange_work%homo(nspins), exchange_work%virtual(nspins), exchange_work%dimen_ia(nspins))
     219          26 :       exchange_work%homo(:) = homo
     220          26 :       exchange_work%virtual(:) = virtual
     221          26 :       exchange_work%dimen_ia(:) = homo*virtual
     222          12 :       exchange_work%dimen_RI = dimen_RI
     223             :       ! NOTE(review): the fallback is dimen_RI here, rpa_exchange_needed_mem uses CEILING(dimen_RI/num_pe)
     224          12 :       exchange_work%block_size = qs_env%mp2_env%ri_rpa%exchange_block_size
     225          12 :       IF (exchange_work%block_size <= 0) exchange_work%block_size = dimen_RI
     226             :       ! Work matrices with the structures of fm_mat_Q and fm_mat_Q_gemm
     227          12 :       CALL cp_fm_create(exchange_work%fm_mat_U, fm_mat_Q%matrix_struct, name="fm_mat_U")
     228          12 :       CALL cp_fm_create(exchange_work%fm_mat_Q_tmp, fm_mat_Q%matrix_struct, name="fm_mat_Q_tmp")
     229          12 :       CALL cp_fm_create(exchange_work%fm_mat_R_half_gemm, fm_mat_Q_gemm%matrix_struct)
     230             : 
     231          12 :       IF (qs_env%mp2_env%ri_rpa%use_hfx_implementation) THEN
     232           2 :          CALL exchange_work%exchange_env%create(qs_env, mat_munu%matrix, para_env_sub, fm_mat_S)
     233             :       END IF
     234             :       ! hfx_create_subgroup takes over and nullifies mo_coeff_o/v, so they are only released here otherwise
     235          12 :       IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_o)) THEN
     236          22 :          DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_o)
     237          22 :             CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_o(ispin))
     238             :          END DO
     239          10 :          DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_o)
     240             :       END IF
     241             : 
     242          12 :       IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_v)) THEN
     243          22 :          DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_v)
     244          22 :             CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_v(ispin))
     245             :          END DO
     246          10 :          DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_v)
     247             :       END IF
     248         132 :    END SUBROUTINE
     249             : 
     250             : ! **************************************************************************************************
     251             : !> \brief Initializes x_data and the per-spin work matrices and structures on a subgroup
     252             : !> \param exchange_env environment to set up
     253             : !> \param qs_env provides input, subsys, dft_control and mp2_env; ri_rpa%mo_coeff_o/v are taken over
     254             : !> \param mat_munu template and copy source of the matrices created here
     255             : !> \param para_env_sub passed to hfx_create and stored as exchange_env%para_env
     256             : !> \param fm_mat_S per-spin matrices whose dimensions define struct_Gamma
     257             : !> \author Vladimir Rybkin
     258             : ! **************************************************************************************************
     259           2 :    SUBROUTINE hfx_create_subgroup(exchange_env, qs_env, mat_munu, para_env_sub, fm_mat_S)
     260             :       CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
     261             :       TYPE(dbcsr_type), INTENT(IN) :: mat_munu
     262             :       TYPE(qs_environment_type), POINTER   :: qs_env
     263             :       TYPE(mp_para_env_type), POINTER, INTENT(IN)            :: para_env_sub
     264             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
     265             : 
     266             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_create_subgroup'
     267             : 
     268             :       INTEGER                                            :: handle, nelectron_total, ispin, &
     269             :                                                             number_of_aos, nspins, dimen_RI, dimen_ia
     270           2 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     271             :       TYPE(cell_type), POINTER                           :: my_cell
     272             :       TYPE(dft_control_type), POINTER                    :: dft_control
     273           2 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     274           2 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     275             :       TYPE(qs_subsys_type), POINTER                      :: subsys
     276             :       TYPE(scf_control_type), POINTER                    :: scf_control
     277             :       TYPE(section_vals_type), POINTER                   :: input
     278             : 
     279           2 :       CALL timeset(routineN, handle)
     280             :       ! Take over the MO coefficient matrices; qs_env no longer points to them afterwards
     281           2 :       exchange_env%mo_coeff_o => qs_env%mp2_env%ri_rpa%mo_coeff_o
     282           2 :       exchange_env%mo_coeff_v => qs_env%mp2_env%ri_rpa%mo_coeff_v
     283           2 :       NULLIFY (qs_env%mp2_env%ri_rpa%mo_coeff_o, qs_env%mp2_env%ri_rpa%mo_coeff_v)
     284             : 
     285           2 :       nspins = SIZE(exchange_env%mo_coeff_o) ! NOTE(review): assumes mo_coeff_o is associated
     286             : 
     287           2 :       exchange_env%qs_env => qs_env
     288           2 :       exchange_env%para_env => para_env_sub
     289           2 :       exchange_env%eps_filter = qs_env%mp2_env%mp2_gpw%eps_filter
     290             : 
     291           2 :       NULLIFY (my_cell, atomic_kind_set, particle_set, dft_control, qs_kind_set, scf_control)
     292             :       ! NOTE(review): scf_control is retrieved below but not used in this routine
     293             :       CALL get_qs_env(qs_env, &
     294             :                       subsys=subsys, &
     295             :                       input=input, &
     296             :                       scf_control=scf_control, &
     297           2 :                       nelectron_total=nelectron_total)
     298             : 
     299             :       CALL qs_subsys_get(subsys, &
     300             :                          cell=my_cell, &
     301             :                          atomic_kind_set=atomic_kind_set, &
     302             :                          qs_kind_set=qs_kind_set, &
     303           2 :                          particle_set=particle_set)
     304             : 
     305           2 :       exchange_env%hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%WF_CORRELATION%RI_RPA%HF")
     306           2 :       CALL get_qs_env(qs_env, dft_control=dft_control)
     307             : 
     308             :       ! Create x_data on the subgroup from the RI_RPA%HF input section
     309             :       CALL hfx_create(exchange_env%x_data, para_env_sub, exchange_env%hfx_sections, atomic_kind_set, &
     310             :                       qs_kind_set, particle_set, dft_control, my_cell, orb_basis='ORB', &
     311           2 :                       nelectron_total=nelectron_total)
     312             : 
     313           2 :       exchange_env%my_recalc_hfx_integrals = .TRUE.
     314             :       ! mat_hfx: one non-symmetric copy of mat_munu per spin
     315           2 :       CALL dbcsr_allocate_matrix_set(exchange_env%mat_hfx, nspins)
     316           4 :       DO ispin = 1, nspins
     317           2 :          ALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
     318           2 :          CALL dbcsr_init_p(exchange_env%mat_hfx(ispin)%matrix)
     319             :          CALL dbcsr_create(exchange_env%mat_hfx(ispin)%matrix, template=mat_munu, &
     320           2 :                            matrix_type=dbcsr_type_no_symmetry)
     321           4 :          CALL dbcsr_copy(exchange_env%mat_hfx(ispin)%matrix, mat_munu)
     322             :       END DO
     323             :       ! NOTE(review): number_of_aos is queried here but not used in this routine
     324           2 :       CALL dbcsr_get_info(mat_munu, nfullcols_total=number_of_aos)
     325             : 
     326             :       CALL dbcsr_create(exchange_env%work_ao, template=mat_munu, &
     327           2 :                         matrix_type=dbcsr_type_no_symmetry)
     328             :       ! Gamma work matrices: copied from mat_munu and mo_coeff_o, respectively, then set to zero
     329           8 :       ALLOCATE (exchange_env%dbcsr_Gamma_inu_P(nspins))
     330           2 :       CALL dbcsr_allocate_matrix_set(exchange_env%dbcsr_Gamma_munu_P, nspins)
     331           4 :       DO ispin = 1, nspins
     332           2 :          ALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     333             :          CALL dbcsr_create(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, template=mat_munu, &
     334           2 :                            matrix_type=dbcsr_type_no_symmetry)
     335           2 :          CALL dbcsr_copy(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, mat_munu)
     336           2 :          CALL dbcsr_set(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, 0.0_dp)
     337             : 
     338           2 :          CALL dbcsr_create(exchange_env%dbcsr_Gamma_inu_P(ispin), template=exchange_env%mo_coeff_o(ispin))
     339           2 :          CALL dbcsr_copy(exchange_env%dbcsr_Gamma_inu_P(ispin), exchange_env%mo_coeff_o(ispin))
     340           4 :          CALL dbcsr_set(exchange_env%dbcsr_Gamma_inu_P(ispin), 0.0_dp)
     341             :       END DO
     342             :       ! struct_Gamma(ispin): structure of fm_mat_S(ispin) with row and column dimensions swapped
     343           8 :       ALLOCATE (exchange_env%struct_Gamma(nspins))
     344           4 :       DO ispin = 1, nspins
     345           2 :          CALL cp_fm_get_info(fm_mat_S(ispin), nrow_global=dimen_RI, ncol_global=dimen_ia)
     346             :          CALL cp_fm_struct_create(exchange_env%struct_Gamma(ispin)%struct, template_fmstruct=fm_mat_S(ispin)%matrix_struct, &
     347           4 :                                   nrow_global=dimen_ia, ncol_global=dimen_RI)
     348             :       END DO
     349             : 
     350           2 :       CALL timestop(handle)
     351             : 
     352           4 :    END SUBROUTINE hfx_create_subgroup
     353             : 
     354             : ! **************************************************************************************************
     355             : !> \brief Frees the arrays, work matrices and exchange environment held by the work type
     356             : !> \param exchange_work object to release
     357             : ! **************************************************************************************************
     358         132 :    SUBROUTINE rpa_exchange_work_release(exchange_work)
     359             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     360             : 
     361         132 :       IF (ALLOCATED(exchange_work%homo)) DEALLOCATE (exchange_work%homo)
     362         132 :       IF (ALLOCATED(exchange_work%virtual)) DEALLOCATE (exchange_work%virtual)
     363         132 :       IF (ALLOCATED(exchange_work%dimen_ia)) DEALLOCATE (exchange_work%dimen_ia)
     364         132 :       NULLIFY (exchange_work%para_env_sub) ! pointer only, the environment itself is not released here
     365         132 :       CALL release_group_dist(exchange_work%aux_func_dist)
     366         132 :       IF (ALLOCATED(exchange_work%aux2send)) DEALLOCATE (exchange_work%aux2send)
     367         132 :       CALL cp_fm_release(exchange_work%fm_mat_Q_tmp)
     368         132 :       CALL cp_fm_release(exchange_work%fm_mat_U)
     369         132 :       CALL cp_fm_release(exchange_work%fm_mat_R_half_gemm)
     370             :       ! Release the contained HFX environment as well
     371         132 :       CALL exchange_work%exchange_env%release()
     372         132 :    END SUBROUTINE
     373             : 
     374             : ! **************************************************************************************************
     375             : !> \brief Releases x_data, the per-spin matrices and the Gamma structures of the exchange environment
     376             : !> \param exchange_env environment to release
     377             : ! **************************************************************************************************
     378         132 :    SUBROUTINE hfx_release_subgroup(exchange_env)
     379             :       CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
     380             : 
     381             :       INTEGER :: ispin
     382             :       ! NOTE(review): exchange_env%qs_env is not nullified in this routine
     383         132 :       NULLIFY (exchange_env%para_env, exchange_env%hfx_sections) ! only dissociated, not released
     384             : 
     385         132 :       IF (ASSOCIATED(exchange_env%x_data)) THEN
     386           2 :          CALL hfx_release(exchange_env%x_data)
     387           2 :          NULLIFY (exchange_env%x_data)
     388             :       END IF
     389             : 
     390         132 :       CALL dbcsr_release(exchange_env%work_ao)
     391             :       ! One association test guards all per-spin matrices; hfx_create_subgroup sets them up together
     392         132 :       IF (ASSOCIATED(exchange_env%dbcsr_Gamma_munu_P)) THEN
     393           4 :          DO ispin = 1, SIZE(exchange_env%mat_hfx, 1)
     394           2 :             CALL dbcsr_release(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     395           2 :             CALL dbcsr_release(exchange_env%mat_hfx(ispin)%matrix)
     396           2 :             CALL dbcsr_release(exchange_env%dbcsr_Gamma_inu_P(ispin))
     397           2 :             CALL dbcsr_release(exchange_env%mo_coeff_o(ispin))
     398           2 :             CALL dbcsr_release(exchange_env%mo_coeff_v(ispin))
     399           2 :             DEALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
     400           4 :             DEALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     401             :          END DO
     402           2 :          DEALLOCATE (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
     403           2 :          DEALLOCATE (exchange_env%dbcsr_Gamma_inu_P, exchange_env%mo_coeff_o, exchange_env%mo_coeff_v)
     404           2 :          NULLIFY (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
     405             :       END IF
     406         132 :       IF (ALLOCATED(exchange_env%struct_Gamma)) THEN
     407           4 :       DO ispin = 1, SIZE(exchange_env%struct_Gamma)
     408           4 :          CALL cp_fm_struct_release(exchange_env%struct_Gamma(ispin)%struct)
     409             :       END DO
     410           2 :       DEALLOCATE (exchange_env%struct_Gamma)
     411             :       END IF
     412         132 :    END SUBROUTINE hfx_release_subgroup
     413             : 
     414             : ! **************************************************************************************************
     415             : !> \brief Main driver for RPA-exchange energies: diagonalizes Q, scales the eigenvectors and calls compute_hfx or compute_fm
     416             : !> \param exchange_work provides the selected correction and the scratch matrices fm_mat_Q_tmp, fm_mat_U, fm_mat_R_half_gemm
     417             : !> \param fm_mat_Q matrix to be diagonalized; only read here (it is copied to fm_mat_Q_tmp first)
     418             : !> \param eig handed over unchanged to the backend
     419             : !> \param fm_mat_S one matrix per spin channel, handed over to the backend
     420             : !> \param omega handed over unchanged to the backend
     421             : !> \param e_exchange_corr exchange energy correction for a quadrature point
     422             : !> \param mp2_env its flag ri_rpa%use_hfx_implementation selects the backend; also handed over to compute_fm
     423             : !> \author Vladimir Rybkin, 07/2016
     424             : ! **************************************************************************************************
     425          12 :    SUBROUTINE rpa_exchange_work_compute(exchange_work, fm_mat_Q, eig, fm_mat_S, omega, &
     426             :                                         e_exchange_corr, mp2_env)
     427             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     428             :       TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat_Q
     429             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     430             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT)         :: fm_mat_S
     431             :       REAL(KIND=dp), INTENT(IN)                          :: omega
     432             :       REAL(KIND=dp), INTENT(INOUT)                       :: e_exchange_corr
     433             :       TYPE(mp2_type), INTENT(INOUT) :: mp2_env
     434             : 
     435             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rpa_exchange_work_compute'
     436             :       REAL(KIND=dp), PARAMETER                           :: thresh = 0.0000001_dp
     437             : 
     438             :       INTEGER :: handle, nspins, dimen_RI, iiB
     439          12 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: eigenval
     440             : ! Without an exchange correction there is nothing to do (the timer is not started in that case)
     441          12 :       IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
     442             : 
     443          12 :       CALL timeset(routineN, handle)
     444             : 
     445          12 :       CALL cp_fm_get_info(fm_mat_Q, ncol_global=dimen_RI)
     446             : 
     447          12 :       nspins = SIZE(fm_mat_S)
     448             : 
     449             :       ! Eigenvalues of Q (dimen_RI is the global number of columns of Q)
     450          36 :       ALLOCATE (eigenval(dimen_RI))
     451         986 :       eigenval = 0.0_dp
     452             : 
     453          12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_Q_tmp, alpha=0.0_dp)
     454          12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_U, alpha=0.0_dp)
     455             : 
     456             :       ! Copy Q to Q_tmp (fm_mat_Q is INTENT(IN), the solver is given the copy)
     457          12 :       CALL cp_fm_to_fm(fm_mat_Q, exchange_work%fm_mat_Q_tmp)
     458             :       ! Diagonalize Q (presumably eigenvectors are returned in fm_mat_U, eigenvalues in eigenval)
     459          12 :       CALL choose_eigv_solver(exchange_work%fm_mat_Q_tmp, exchange_work%fm_mat_U, eigenval)
     460             : 
     461             :       ! Calculate diagonal matrix for R_half
     462             :       ! AXK: f(x) = LOG(1+x)/x**2 - 1/(x*(x+1)),  SOSEX: f(x) = 1/x - LOG(1+x)/x**2
     463             :       ! Replace each eigenvalue x by SQRT(f(x)); for |x| < thresh use sqrthalf (f(x) -> 1/2 for x -> 0 in both formulas)
     464          12 :       IF (exchange_work%exchange_correction == rpa_exchange_axk) THEN
     465         818 :          DO iib = 1, dimen_RI
     466         818 :             IF (ABS(eigenval(iib)) .GE. thresh) THEN
     467             :                eigenval(iib) = &
     468             :                   SQRT((1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
     469         718 :                        - 1.0_dp/(eigenval(iib)*(eigenval(iib) + 1.0_dp)))
     470             :             ELSE
     471          90 :                eigenval(iib) = sqrthalf
     472             :             END IF
     473             :          END DO
     474           2 :       ELSE IF (exchange_work%exchange_correction == rpa_exchange_sosex) THEN
     475         168 :          DO iib = 1, dimen_RI
     476         168 :             IF (ABS(eigenval(iib)) .GE. thresh) THEN
     477             :                eigenval(iib) = &
     478             :                   SQRT(-(1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
     479         144 :                        + 1.0_dp/eigenval(iib))
     480             :             ELSE
     481          22 :                eigenval(iib) = sqrthalf
     482             :             END IF
     483             :          END DO
     484             :       ELSE
     485           0 :          CPABORT("Unknown RPA exchange correction")
     486             :       END IF
     487             : 
     488             :       ! fm_mat_U now contains some sqrt of the required matrix-valued function
     489          12 :       CALL cp_fm_column_scale(exchange_work%fm_mat_U, eigenval)
     490             : 
     491             :       ! Release memory (the transformed eigenvalues are not needed after the column scaling)
     492          12 :       DEALLOCATE (eigenval)
     493             : 
     494             :       ! Redistribute fm_mat_U for "rectangular" multiplication: ia*P P*P
     495          12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_R_half_gemm, alpha=0.0_dp)
     496             : 
     497             :       CALL cp_fm_to_fm_submat_general(exchange_work%fm_mat_U, exchange_work%fm_mat_R_half_gemm, dimen_RI, &
     498          12 :                                       dimen_RI, 1, 1, 1, 1, exchange_work%fm_mat_U%matrix_struct%context)
     499             : ! Hand over to the backend selected by mp2_env%ri_rpa%use_hfx_implementation
     500          12 :       IF (mp2_env%ri_rpa%use_hfx_implementation) THEN
     501           2 :          CALL exchange_work%compute_hfx(fm_mat_S, eig, omega, e_exchange_corr)
     502             :       ELSE
     503          10 :          CALL exchange_work%compute_fm(fm_mat_S, eig, omega, e_exchange_corr, mp2_env)
     504             :       END IF
     505             : 
     506          12 :       CALL timestop(handle)
     507             : 
     508          12 :    END SUBROUTINE rpa_exchange_work_compute
     509             : 
     510             : ! **************************************************************************************************
     511             : !> \brief Full-matrix backend of the RPA-exchange correction: local GEMMs over blocks of auxiliary functions
     512             : !> \param exchange_work provides homo, virtual, dimen_RI, block_size, aux_func_dist, color_sub, para_env_sub, fm_mat_R_half_gemm
     513             : !> \param fm_mat_S one matrix per spin; NOTE(review): handed to the (re)scaling routines below although it is INTENT(IN)
     514             : !> \param eig eig(:, ispin) is handed to the (re)scaling routines of fm_mat_S
     515             : !> \param omega handed to the (re)scaling routines of fm_mat_S
     516             : !> \param e_exchange_corr exchange energy correction for a quadrature point (zeroed on entry, scaled by 2 or 4 at the end)
     517             : !> \param mp2_env its local_gemm_ctx is created at the start and destroyed at the end of this routine
     518             : !> \author Frederick Stein, May-June 2024
     519             : ! **************************************************************************************************
     520          10 :    SUBROUTINE rpa_exchange_work_compute_fm(exchange_work, fm_mat_S, eig, omega, &
     521             :                                            e_exchange_corr, mp2_env)
     522             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     523             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN)         :: fm_mat_S
     524             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     525             :       REAL(KIND=dp), INTENT(IN)                          :: omega
     526             :       REAL(KIND=dp), INTENT(INOUT)                       :: e_exchange_corr
     527             :       TYPE(mp2_type), INTENT(INOUT) :: mp2_env
     528             : 
     529             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rpa_exchange_work_compute_fm'
     530             : 
     531             :       INTEGER :: handle, ispin, nspins, P, Q, L_size_Gamma, hom, virt, i, &
     532             :                  send_proc, recv_proc, recv_size, max_aux_size, proc_shift, dimen_ia, &
     533             :                  block_size, P_start, P_end, P_size, Q_start, Q_size, Q_end, handle2, my_aux_size, my_virt
     534          10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET :: mat_Gamma_3_3D
     535          10 :       REAL(KIND=dp), POINTER, DIMENSION(:), CONTIGUOUS :: mat_Gamma_3_1D
     536          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: mat_Gamma_3_2D
     537          10 :       REAL(KIND=dp), ALLOCATABLE, TARGET, DIMENSION(:) :: recv_buffer_1D
     538          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: recv_buffer_2D
     539          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :, :), CONTIGUOUS :: recv_buffer_3D
     540          10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: mat_B_iaP
     541          10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET :: product_matrix_1D
     542          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: product_matrix_2D
     543          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :, :, :), CONTIGUOUS :: product_matrix_4D
     544             :       TYPE(cp_fm_type)        :: fm_mat_Gamma_3
     545             :       TYPE(mp_para_env_type), POINTER :: para_env
     546          10 :       TYPE(group_dist_d1_type)                           :: virt_dist
     547             : 
     548          10 :       CALL timeset(routineN, handle)
     549             : 
     550          10 :       nspins = SIZE(fm_mat_S)
     551             : 
     552          10 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, sizes=my_aux_size)
     553             : 
     554          10 :       e_exchange_corr = 0.0_dp
     555          10 :       max_aux_size = maxsize(exchange_work%aux_func_dist)
     556             : 
     557             :       ! local_gemm_ctx has a very large footprint the first time this routine is
     558             :       ! called.
     559          10 :       CALL mp2_env%local_gemm_ctx%create(LOCAL_GEMM_PU_GPU)
     560          10 :       CALL mp2_env%local_gemm_ctx%set_op_threshold_gpu(128*128*128*2)
     561             : 
     562          22 :       DO ispin = 1, nspins
     563          12 :          hom = exchange_work%homo(ispin)
     564          12 :          virt = exchange_work%virtual(ispin)
     565          12 :          dimen_ia = hom*virt
     566          12 :          IF (hom < 1 .OR. virt < 1) CYCLE
     567             : 
     568          12 :          CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
     569             : 
     570          12 :          CALL cp_fm_create(fm_mat_Gamma_3, fm_mat_S(ispin)%matrix_struct)
     571          12 :          CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
     572             : 
     573             :          ! Update G with a new value of Omega: in practice, it is G*S
     574             : 
     575             :          ! Scale fm_work_iaP
     576             :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
     577          12 :                                 hom, omega, 0.0_dp)
     578             : 
     579             :          ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
     580             :          CALL parallel_gemm(transa="T", transb="N", m=exchange_work%dimen_RI, n=dimen_ia, k=exchange_work%dimen_RI, alpha=1.0_dp, &
     581             :                             matrix_a=exchange_work%fm_mat_R_half_gemm, matrix_b=fm_mat_S(ispin), beta=0.0_dp, &
     582          12 :                             matrix_c=fm_mat_Gamma_3)
     583             : 
     584          12 :          CALL create_group_dist(virt_dist, exchange_work%para_env_sub%num_pe, virt)
     585             : 
     586             :          ! Remove extra factor from S after the multiplication (to return to the original matrix)
     587          12 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     588             : 
     589          12 :          CALL exchange_work%redistribute_into_subgroups(fm_mat_Gamma_3, mat_Gamma_3_3D, ispin, virt_dist)
     590          12 :          CALL cp_fm_release(fm_mat_Gamma_3)
     591             : 
     592             :          ! We need only the pure matrix (second removal of the scaling factor, undone again after the redistribution)
     593          12 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     594             : 
     595             :          ! Reorder matrix from (P, i*a) -> (a, i, P) with P being distributed within subgroups
     596          12 :          CALL exchange_work%redistribute_into_subgroups(fm_mat_S(ispin), mat_B_iaP, ispin, virt_dist)
     597             : 
     598             :          ! Return to the original tensor
     599          12 :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), hom, omega, 0.0_dp)
     600             : ! Local extents of the redistributed array: first dimension virtual, third dimension auxiliary
     601          12 :          L_size_Gamma = SIZE(mat_Gamma_3_3D, 3)
     602          12 :          my_virt = SIZE(mat_Gamma_3_3D, 1)
     603          12 :          block_size = exchange_work%block_size
     604             : ! Rank-1 and rank-2 views of the rank-3 array (bounds remapping) to pass it to GEMM as (my_virt, hom*my_aux_size)
     605          12 :          mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size) => mat_Gamma_3_3D(:, :, 1:my_aux_size)
     606          12 :          mat_Gamma_3_2D(1:my_virt, 1:hom*my_aux_size) => mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size)
     607             : ! Work buffers; NOTE(review): recv_buffer_1D is sized with virt although its views below only address my_virt rows
     608           0 :          ALLOCATE (product_matrix_1D(INT(hom*MIN(block_size, L_size_gamma), KIND=int_8)* &
     609          36 :                                      INT(hom*MIN(block_size, max_aux_size), KIND=int_8)))
     610          36 :          ALLOCATE (recv_buffer_1D(INT(virt, KIND=int_8)*hom*max_aux_size))
     611          12 :          recv_buffer_2D(1:my_virt, 1:hom*max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
     612          12 :          recv_buffer_3D(1:my_virt, 1:hom, 1:max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
     613          36 :          DO proc_shift = 0, para_env%num_pe - 1, exchange_work%para_env_sub%num_pe
     614          24 :             send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
     615          24 :             recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
     616             : ! Number of auxiliary functions assigned to the subgroup with index recv_proc/para_env_sub%num_pe
     617          24 :             CALL get_group_dist(exchange_work%aux_func_dist, recv_proc/exchange_work%para_env_sub%num_pe, sizes=recv_size)
     618             : 
     619          24 :             IF (recv_size == 0) recv_proc = mp_proc_null
     620             : ! Send the own block to send_proc and receive the block of recv_proc (mp_proc_null if there is nothing to receive)
     621          24 :             CALL para_env%sendrecv(mat_B_iaP, send_proc, recv_buffer_3D(:, :, 1:recv_size), recv_proc)
     622             : 
     623          24 :             IF (recv_size == 0) CYCLE
     624             : 
     625        1038 :             DO P_start = 1, L_size_Gamma, block_size
     626        1002 :                P_end = MIN(L_size_Gamma, P_start + block_size - 1)
     627        1002 :                P_size = P_end - P_start + 1
     628       43875 :                DO Q_start = 1, recv_size, block_size
     629       42849 :                   Q_end = MIN(recv_size, Q_start + block_size - 1)
     630       42849 :                   Q_size = Q_end - Q_start + 1
     631             : 
     632             :                   ! Reassign product_matrix pointers to enforce contiguity of target array
     633             :                   product_matrix_2D(1:hom*P_size, 1:hom*Q_size) => &
     634       42849 :                      product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
     635             :                   product_matrix_4D(1:hom, 1:P_size, 1:hom, 1:Q_size) => &
     636       42849 :                      product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
     637             : ! product(i, P, j, Q) = SUM_a Gamma_3(a, i, P)*B(a, j, Q), assuming BLAS-like GEMM argument semantics
     638       42849 :                   CALL timeset(routineN//"_gemm", handle2)
     639             :                   CALL mp2_env%local_gemm_ctx%gemm("T", "N", hom*P_size, hom*Q_size, my_virt, 1.0_dp, &
     640             :                                                    mat_Gamma_3_2D(:, hom*(P_start - 1) + 1:hom*P_end), my_virt, &
     641             :                                                    recv_buffer_2D(:, hom*(Q_start - 1) + 1:hom*Q_end), my_virt, &
     642       42849 :                                                    0.0_dp, product_matrix_2D, hom*P_size)
     643       42849 :                   CALL timestop(handle2)
     644             : ! e_exchange_corr = e_exchange_corr + SUM_{i, j, P, Q} product(i, P, j, Q)*product(j, P, i, Q)
     645       42849 :                   CALL timeset(routineN//"_energy", handle2)
     646             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(P_size, Q_size, hom, product_matrix_4D) &
     647       42849 : !$OMP             COLLAPSE(3) REDUCTION(+: e_exchange_corr) PRIVATE(P, Q, i)
     648             :                   DO P = 1, P_size
     649             :                   DO Q = 1, Q_size
     650             :                   DO i = 1, hom
     651             :                      e_exchange_corr = e_exchange_corr + DOT_PRODUCT(product_matrix_4D(i, P, :, Q), product_matrix_4D(:, P, i, Q))
     652             :                   END DO
     653             :                   END DO
     654             :                   END DO
     655       86700 :                   CALL timestop(handle2)
     656             :                END DO
     657             :             END DO
     658             :          END DO
     659             : 
     660          12 :          CALL release_group_dist(virt_dist)
     661          12 :          IF (ALLOCATED(mat_B_iaP)) DEALLOCATE (mat_B_iaP)
     662          12 :          IF (ALLOCATED(mat_Gamma_3_3D)) DEALLOCATE (mat_Gamma_3_3D)
     663          12 :          IF (ALLOCATED(product_matrix_1D)) DEALLOCATE (product_matrix_1D)
     664          58 :          IF (ALLOCATED(recv_buffer_1D)) DEALLOCATE (recv_buffer_1D)
     665             :       END DO
     666             : 
     667          10 :       CALL mp2_env%local_gemm_ctx%destroy()
     668             : ! Spin-dependent prefactor of the accumulated sum
     669          10 :       IF (nspins == 2) e_exchange_corr = e_exchange_corr*2.0_dp
     670          10 :       IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
     671             : 
     672          10 :       CALL timestop(handle)
     673             : 
     674          20 :    END SUBROUTINE rpa_exchange_work_compute_fm
     675             : 
     676             : ! **************************************************************************************************
     677             : !> \brief Contract RPA-exchange density matrix with HF exchange integrals and evaluate the correction
     678             : !> \param exchange_work provides exchange_env (MO coefficients, HFX data, work matrices), aux_func_dist, fm_mat_R_half_gemm
     679             : !> \param fm_mat_S one matrix per spin; scaled before the multiplication with R_half and scaled back afterwards
     680             : !> \param eig eig(:, ispin) is handed to the (re)scaling routines of fm_mat_S
     681             : !> \param omega handed to the (re)scaling routines of fm_mat_S
     682             : !> \param e_exchange_corr zeroed on entry; minus the sum of the traces below, multiplied by 4 if nspins == 1
     683             : !> \author Vladimir Rybkin, 08/2016
     684             : ! **************************************************************************************************
     685           2 :    SUBROUTINE rpa_exchange_work_compute_hfx(exchange_work, fm_mat_S, eig, omega, e_exchange_corr)
     686             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     687             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT) :: fm_mat_S
     688             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     689             :       REAL(KIND=dp), INTENT(IN)                          :: omega
     690             :       REAL(KIND=dp), INTENT(OUT) :: e_exchange_corr
     691             : 
     692             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_hfx'
     693             : 
     694             :       INTEGER                                            :: handle, ispin, my_aux_start, my_aux_end, &
     695             :                                                             my_aux_size, nspins, L_counter, dimen_ia, hom, virt
     696             :       REAL(KIND=dp)                                      :: e_exchange_P
     697           2 :       TYPE(dbcsr_matrix_p_set), DIMENSION(:), ALLOCATABLE          :: dbcsr_Gamma_3
     698             :       TYPE(cp_fm_type) :: fm_mat_Gamma_3
     699             :       TYPE(mp_para_env_type), POINTER :: para_env
     700             : 
     701           2 :       CALL timeset(routineN, handle)
     702             : 
     703           2 :       e_exchange_corr = 0.0_dp
     704             : 
     705           2 :       nspins = SIZE(fm_mat_S)
     706             : 
     707           2 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
     708             : 
     709           8 :       ALLOCATE (dbcsr_Gamma_3(nspins))
     710           4 :       DO ispin = 1, nspins
     711           2 :          hom = exchange_work%homo(ispin)
     712           2 :          virt = exchange_work%virtual(ispin)
     713           2 :          dimen_ia = hom*virt
     714           2 :          IF (hom < 1 .OR. virt < 1) CYCLE
     715             : 
     716           2 :          CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
     717             : 
     718           2 :          CALL cp_fm_create(fm_mat_Gamma_3, exchange_work%exchange_env%struct_Gamma(ispin)%struct)
     719           2 :          CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
     720             : 
     721             :          ! Update G with a new value of Omega: in practice, it is G*S
     722             : 
     723             :          ! Scale fm_work_iaP
     724             :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
     725           2 :                                 hom, omega, 0.0_dp)
     726             : 
     727             :          ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
     728             :          CALL parallel_gemm(transa="T", transb="N", m=dimen_ia, n=exchange_work%dimen_RI, &
     729             :                             k=exchange_work%dimen_RI, alpha=1.0_dp, &
     730             :                             matrix_a=fm_mat_S(ispin), matrix_b=exchange_work%fm_mat_R_half_gemm, beta=0.0_dp, &
     731           2 :                             matrix_c=fm_mat_Gamma_3)
     732             : 
     733             :          ! Remove extra factor from S after the multiplication (to return to the original matrix)
     734           2 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     735             : 
     736             :          ! Copy Gamma_ia_P^3 to dbcsr matrix set
     737             :          CALL gamma_fm_to_dbcsr(fm_mat_Gamma_3, dbcsr_Gamma_3(ispin)%matrix_set, &
     738             :                                 para_env, exchange_work%para_env_sub, hom, virt, &
     739             :                                 exchange_work%exchange_env%mo_coeff_o(ispin), &
     740           6 :                                 exchange_work%ngroup, my_aux_start, my_aux_end, my_aux_size)
     741             :       END DO
     742             : ! NOTE(review): a spin channel skipped by the CYCLE above gets no matrix_set, yet the loop below runs over all spins
     743          85 :       DO L_counter = 1, my_aux_size
     744         166 :          DO ispin = 1, nspins
     745             :             ! Do dbcsr multiplication: transform the virtual index
     746             :             CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mo_coeff_v(ispin), &
     747             :                                 dbcsr_Gamma_3(ispin)%matrix_set(L_counter), &
     748             :                                 0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     749          83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     750             : 
     751          83 :             CALL dbcsr_release(dbcsr_Gamma_3(ispin)%matrix_set(L_counter))
     752             : 
     753             :             ! Do dbcsr multiplication: transform the occupied index (two products with factor 0.5, the second one is accumulated)
     754             :             CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     755             :                                 exchange_work%exchange_env%mo_coeff_o(ispin), &
     756             :                                 0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     757          83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     758             :             CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%mo_coeff_o(ispin), &
     759             :                                 exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     760             :                                 1.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     761          83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     762             : 
     763         166 :             CALL dbcsr_set(exchange_work%exchange_env%mat_hfx(ispin)%matrix, 0.0_dp)
     764             :          END DO
     765             : 
     766             :          CALL tddft_hfx_matrix(exchange_work%exchange_env%mat_hfx, exchange_work%exchange_env%dbcsr_Gamma_munu_P, &
     767             :                                exchange_work%exchange_env%qs_env, .FALSE., &
     768             :                                exchange_work%exchange_env%my_recalc_hfx_integrals, &
     769             :                                exchange_work%exchange_env%hfx_sections, exchange_work%exchange_env%x_data, &
     770          83 :                                exchange_work%exchange_env%para_env)
     771             : ! The recalculation flag is switched off after the first call of tddft_hfx_matrix
     772          83 :          exchange_work%exchange_env%my_recalc_hfx_integrals = .FALSE.
     773         168 :          DO ispin = 1, nspins
     774             :             CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mat_hfx(ispin)%matrix, &
     775             :                                 exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     776          83 :                                 0.0_dp, exchange_work%exchange_env%work_ao, filter_eps=exchange_work%exchange_env%eps_filter)
     777          83 :             CALL dbcsr_trace(exchange_work%exchange_env%work_ao, e_exchange_P)
     778         166 :             e_exchange_corr = e_exchange_corr - e_exchange_P
     779             :          END DO
     780             :       END DO
     781             : ! NOTE(review): the nspins == 2 statement below is a self-assignment, i.e. only the nspins == 1 result is rescaled
     782             :       IF (nspins == 2) e_exchange_corr = e_exchange_corr
     783           2 :       IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
     784             : 
     785           2 :       CALL timestop(handle)
     786             : 
     787           6 :    END SUBROUTINE rpa_exchange_work_compute_hfx
     788             : 
     789             : ! **************************************************************************************************
     790             : !> \brief ...
     791             : !> \param exchange_work ...
     792             : !> \param fm_mat ...
     793             : !> \param mat ...
     794             : !> \param ispin ...
     795             : !> \param virt_dist ...
     796             : ! **************************************************************************************************
     797          24 :    SUBROUTINE redistribute_into_subgroups(exchange_work, fm_mat, mat, ispin, virt_dist)
     798             :       CLASS(rpa_exchange_work_type), INTENT(IN) :: exchange_work
     799             :       TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat
     800             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
     801             :          INTENT(OUT)                                     :: mat
     802             :       INTEGER, INTENT(IN)                                :: ispin
     803             :       TYPE(group_dist_d1_type), INTENT(IN)               :: virt_dist
     804             : 
     805             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'redistribute_into_subgroups'
     806             : 
     807             :       INTEGER :: aux_counter, aux_global, aux_local, aux_proc, avirt, dimen_RI, handle, handle2, &
     808             :                  ia_global, ia_local, iocc, max_number_recv, max_number_send, my_aux_end, my_aux_size, &
     809             :                  my_aux_start, my_process_column, my_process_row, my_virt_end, my_virt_size, &
     810             :                  my_virt_start, proc, proc_shift, recv_proc, send_proc, virt_counter, virt_proc, group_size
     811          24 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: data2send, recv_col_indices, &
     812          24 :                                             recv_row_indices, send_aux_indices, send_virt_indices, virt2send
     813             :       INTEGER, DIMENSION(2)                              :: recv_shape
     814          24 :       INTEGER, DIMENSION(:), POINTER                     :: aux_distribution_fm, col_indices, &
     815          24 :                                                             ia_distribution_fm, row_indices
     816          24 :       INTEGER, DIMENSION(:, :), POINTER                  :: mpi2blacs
     817          24 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: recv_buffer, send_buffer
     818             :       REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
     819          24 :          POINTER                                         :: recv_ptr, send_ptr
     820             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     821             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     822             : 
     823          24 :       CALL timeset(routineN, handle)
     824             : 
     825             :       CALL cp_fm_get_info(matrix=fm_mat, &
     826             :                           nrow_locals=aux_distribution_fm, &
     827             :                           col_indices=col_indices, &
     828             :                           row_indices=row_indices, &
     829             :                           ncol_locals=ia_distribution_fm, &
     830             :                           context=context, &
     831             :                           nrow_global=dimen_RI, &
     832          24 :                           para_env=para_env)
     833             : 
     834          24 :       IF (exchange_work%homo(ispin) <= 0 .OR. exchange_work%virtual(ispin) <= 0) THEN
     835           0 :          CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
     836           0 :          ALLOCATE (mat(exchange_work%homo(ispin), my_virt_size, dimen_RI))
     837           0 :          CALL timestop(handle)
     838           0 :          RETURN
     839             :       END IF
     840             : 
     841          24 :       group_size = exchange_work%para_env_sub%num_pe
     842             : 
     843          24 :       CALL timeset(routineN//"_prep", handle2)
     844          24 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
     845          24 :       CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
     846          24 :       CALL context%get(my_process_column=my_process_column, my_process_row=my_process_row, mpi2blacs=mpi2blacs)
     847             : 
     848             :       ! Determine the number of columns to send
     849         120 :       ALLOCATE (send_aux_indices(MAXVAL(exchange_work%aux2send)))
     850          72 :       ALLOCATE (virt2send(0:group_size - 1))
     851          48 :       virt2send = 0
     852        1924 :       DO ia_local = 1, ia_distribution_fm(my_process_column)
     853        1900 :          ia_global = col_indices(ia_local)
     854        1900 :          avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     855        1900 :          proc = group_dist_proc(virt_dist, avirt)
     856        1924 :          virt2send(proc) = virt2send(proc) + 1
     857             :       END DO
     858             : 
     859          72 :       ALLOCATE (data2send(0:para_env%num_pe - 1))
     860          72 :       DO aux_proc = 0, exchange_work%ngroup - 1
     861         120 :       DO virt_proc = 0, group_size - 1
     862          96 :          data2send(aux_proc*group_size + virt_proc) = exchange_work%aux2send(aux_proc)*virt2send(virt_proc)
     863             :       END DO
     864             :       END DO
     865             : 
     866          96 :       ALLOCATE (send_virt_indices(MAXVAL(virt2send)))
     867          72 :       max_number_send = MAXVAL(data2send)
     868             : 
     869          72 :       ALLOCATE (send_buffer(INT(max_number_send, KIND=int_8)*exchange_work%homo(ispin)))
     870          24 :       max_number_recv = max_number_send
     871          24 :       CALL para_env%max(max_number_recv)
     872          72 :       ALLOCATE (recv_buffer(max_number_recv))
     873             : 
     874         120 :       ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))
     875             : 
     876          24 :       CALL timestop(handle2)
     877             : 
     878          24 :       CALL timeset(routineN//"_own", handle2)
     879             :       ! Start with own data
     880        1026 :       DO aux_local = 1, aux_distribution_fm(my_process_row)
     881        1002 :          aux_global = row_indices(aux_local)
     882        1002 :          IF (aux_global < my_aux_start .OR. aux_global > my_aux_end) CYCLE
     883       40848 :          DO ia_local = 1, ia_distribution_fm(my_process_column)
     884       40318 :             ia_global = fm_mat%matrix_struct%col_indices(ia_local)
     885             : 
     886       40318 :             iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
     887       40318 :             avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     888             : 
     889       40318 :             IF (my_virt_start > avirt .OR. my_virt_end < avirt) CYCLE
     890             : 
     891       41320 :             mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = fm_mat%local_data(aux_local, ia_local)
     892             :          END DO
     893             :       END DO
     894          24 :       CALL timestop(handle2)
     895             : 
     896          48 :       DO proc_shift = 1, para_env%num_pe - 1
     897          24 :          send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
     898          24 :          recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
     899             : 
     900          24 :          CALL timeset(routineN//"_pack_buffer", handle2)
     901             :          send_ptr(1:virt2send(MOD(send_proc, group_size)), &
     902             :                   1:exchange_work%aux2send(send_proc/group_size)) => &
     903             :             send_buffer(1:INT(virt2send(MOD(send_proc, group_size)), KIND=int_8)* &
     904          24 :                         exchange_work%aux2send(send_proc/group_size))
     905             : ! Pack send buffer
     906          24 :          aux_counter = 0
     907        1026 :          DO aux_local = 1, aux_distribution_fm(my_process_row)
     908        1002 :             aux_global = row_indices(aux_local)
     909        1002 :             proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
     910        1002 :             IF (proc /= send_proc/group_size) CYCLE
     911         496 :             aux_counter = aux_counter + 1
     912         496 :             virt_counter = 0
     913       40016 :             DO ia_local = 1, ia_distribution_fm(my_process_column)
     914       39520 :                ia_global = col_indices(ia_local)
     915       39520 :                avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     916             : 
     917       39520 :                proc = group_dist_proc(virt_dist, avirt)
     918       39520 :                IF (proc /= MOD(send_proc, group_size)) CYCLE
     919       39520 :                virt_counter = virt_counter + 1
     920       39520 :                send_ptr(virt_counter, aux_counter) = fm_mat%local_data(aux_local, ia_local)
     921       40016 :                send_virt_indices(virt_counter) = ia_global
     922             :             END DO
     923        1026 :             send_aux_indices(aux_counter) = aux_global
     924             :          END DO
     925          24 :          CALL timestop(handle2)
     926             : 
     927          24 :          CALL timeset(routineN//"_ex_size", handle2)
     928          24 :          recv_shape = [1, 1]
     929          72 :          CALL para_env%sendrecv(SHAPE(send_ptr), send_proc, recv_shape, recv_proc)
     930          24 :          CALL timestop(handle2)
     931             : 
     932          72 :          IF (SIZE(send_ptr) == 0) send_proc = mp_proc_null
     933          72 :          IF (PRODUCT(recv_shape) == 0) recv_proc = mp_proc_null
     934             : 
     935          24 :          CALL timeset(routineN//"_ex_idx", handle2)
     936         120 :          ALLOCATE (recv_row_indices(recv_shape(1)), recv_col_indices(recv_shape(2)))
     937          24 :          CALL para_env%sendrecv(send_virt_indices(1:virt_counter), send_proc, recv_row_indices, recv_proc)
     938          24 :          CALL para_env%sendrecv(send_aux_indices(1:aux_counter), send_proc, recv_col_indices, recv_proc)
     939          24 :          CALL timestop(handle2)
     940             : 
     941             :          ! Prepare pointer to recv buffer (consider transposition while packing the send buffer)
     942          24 :          recv_ptr(1:recv_shape(1), 1:MAX(1, recv_shape(2))) => recv_buffer(1:recv_shape(1)*MAX(1, recv_shape(2)))
     943             : 
     944          24 :          CALL timeset(routineN//"_sendrecv", handle2)
     945             : ! Perform communication
     946          24 :          CALL para_env%sendrecv(send_ptr, send_proc, recv_ptr, recv_proc)
     947          24 :          CALL timestop(handle2)
     948             : 
     949          24 :          IF (recv_proc == mp_proc_null) THEN
     950           0 :             DEALLOCATE (recv_row_indices, recv_col_indices)
     951           0 :             CYCLE
     952             :          END IF
     953             : 
     954          24 :          CALL timeset(routineN//"_unpack", handle2)
     955             : ! Unpack receive buffer
     956         520 :          DO aux_local = 1, SIZE(recv_col_indices)
     957         496 :             aux_global = recv_col_indices(aux_local)
     958             : 
     959       40040 :             DO ia_local = 1, SIZE(recv_row_indices)
     960       39520 :                ia_global = recv_row_indices(ia_local)
     961             : 
     962       39520 :                iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
     963       39520 :                avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     964             : 
     965       40016 :                mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = recv_ptr(ia_local, aux_local)
     966             :             END DO
     967             :          END DO
     968          24 :          CALL timestop(handle2)
     969             : 
     970          24 :          IF (ALLOCATED(recv_row_indices)) DEALLOCATE (recv_row_indices)
     971         168 :          IF (ALLOCATED(recv_col_indices)) DEALLOCATE (recv_col_indices)
     972             :       END DO
     973             : 
     974          24 :       DEALLOCATE (send_aux_indices, send_virt_indices)
     975             : 
     976          24 :       CALL timestop(handle)
     977             : 
     978          96 :    END SUBROUTINE redistribute_into_subgroups
     979             : 
     980           0 : END MODULE rpa_exchange

Generated by: LCOV version 1.15