LCOV - code coverage report
Current view: top level - src - rpa_exchange.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b8e0b09) Lines: 364 373 97.6 %
Date: 2024-08-31 06:31:37 Functions: 9 15 60.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Auxiliary routines needed for RPA-exchange
      10             : !>        given blacs_env to another
      11             : !> \par History
      12             : !>      09.2016 created [Vladimir Rybkin]
      13             : !>      03.2019 Renamed [Frederick Stein]
      14             : !>      03.2019 Moved Functions from rpa_ri_gpw.F [Frederick Stein]
      15             : !>      04.2024 Added open-shell calculations, SOSEX [Frederick Stein]
      16             : !> \author Vladimir Rybkin
      17             : ! **************************************************************************************************
      18             : MODULE rpa_exchange
      19             :    USE atomic_kind_types,               ONLY: atomic_kind_type
      20             :    USE cell_types,                      ONLY: cell_type
      21             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_type
      22             :    USE cp_control_types,                ONLY: dft_control_type
      23             :    USE cp_dbcsr_api,                    ONLY: &
      24             :         dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_init_p, dbcsr_multiply, dbcsr_p_type, &
      25             :         dbcsr_release, dbcsr_set, dbcsr_trace, dbcsr_type, dbcsr_type_no_symmetry
      26             :    USE cp_dbcsr_operations,             ONLY: dbcsr_allocate_matrix_set
      27             :    USE cp_fm_basic_linalg,              ONLY: cp_fm_column_scale
      28             :    USE cp_fm_diag,                      ONLY: choose_eigv_solver
      29             :    USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
      30             :                                               cp_fm_struct_p_type,&
      31             :                                               cp_fm_struct_release
      32             :    USE cp_fm_types,                     ONLY: cp_fm_create,&
      33             :                                               cp_fm_get_info,&
      34             :                                               cp_fm_release,&
      35             :                                               cp_fm_set_all,&
      36             :                                               cp_fm_to_fm,&
      37             :                                               cp_fm_to_fm_submat_general,&
      38             :                                               cp_fm_type
      39             :    USE group_dist_types,                ONLY: create_group_dist,&
      40             :                                               get_group_dist,&
      41             :                                               group_dist_d1_type,&
      42             :                                               group_dist_proc,&
      43             :                                               maxsize,&
      44             :                                               release_group_dist
      45             :    USE hfx_admm_utils,                  ONLY: tddft_hfx_matrix
      46             :    USE hfx_types,                       ONLY: hfx_create,&
      47             :                                               hfx_release,&
      48             :                                               hfx_type
      49             :    USE input_constants,                 ONLY: rpa_exchange_axk,&
      50             :                                               rpa_exchange_none,&
      51             :                                               rpa_exchange_sosex
      52             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      53             :                                               section_vals_type
      54             :    USE kinds,                           ONLY: dp,&
      55             :                                               int_8
      56             :    USE local_gemm_api,                  ONLY: LOCAL_GEMM_PU_GPU
      57             :    USE mathconstants,                   ONLY: sqrthalf
      58             :    USE message_passing,                 ONLY: mp_para_env_type,&
      59             :                                               mp_proc_null
      60             :    USE mp2_types,                       ONLY: mp2_type
      61             :    USE parallel_gemm_api,               ONLY: parallel_gemm
      62             :    USE particle_types,                  ONLY: particle_type
      63             :    USE qs_environment_types,            ONLY: get_qs_env,&
      64             :                                               qs_environment_type
      65             :    USE qs_kind_types,                   ONLY: qs_kind_type
      66             :    USE qs_subsys_types,                 ONLY: qs_subsys_get,&
      67             :                                               qs_subsys_type
      68             :    USE rpa_communication,               ONLY: gamma_fm_to_dbcsr
      69             :    USE rpa_util,                        ONLY: calc_fm_mat_S_rpa,&
      70             :                                               remove_scaling_factor_rpa
      71             :    USE scf_control_types,               ONLY: scf_control_type
      72             : #include "./base/base_uses.f90"
      73             : 
      74             :    IMPLICIT NONE
      75             : 
      76             :    PRIVATE
      77             : 
      78             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rpa_exchange'
      79             : 
      80             :    PUBLIC :: rpa_exchange_work_type, rpa_exchange_needed_mem
      81             : 
      82             :    TYPE rpa_exchange_env_type
      83             :       PRIVATE
      84             :       TYPE(qs_environment_type), POINTER             :: qs_env => NULL()
      85             :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER      :: mat_hfx => NULL()
      86             :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER      :: dbcsr_Gamma_munu_P => NULL()
      87             :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)    :: dbcsr_Gamma_inu_P
      88             :       ! Workaround GCC 8
      89             :       TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_o => NULL()
      90             :       TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_v => NULL()
      91             :       TYPE(dbcsr_type)                               :: work_ao
      92             :       TYPE(hfx_type), DIMENSION(:, :), POINTER       :: x_data => NULL()
      93             :       TYPE(mp_para_env_type), POINTER                :: para_env => NULL()
      94             :       TYPE(section_vals_type), POINTER               :: hfx_sections => NULL()
      95             :       LOGICAL :: my_recalc_hfx_integrals = .FALSE.
      96             :       REAL(KIND=dp) :: eps_filter = 0.0_dp
      97             :       TYPE(cp_fm_struct_p_type), DIMENSION(:), ALLOCATABLE :: struct_Gamma
      98             :    CONTAINS
      99             :       PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: create => hfx_create_subgroup
     100             :       !PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: integrate => integrate_exchange
     101             :       PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: release => hfx_release_subgroup
     102             :    END TYPE
     103             : 
     104             :    TYPE dbcsr_matrix_p_set
     105             :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: matrix_set
     106             :    END TYPE
     107             : 
     108             :    TYPE rpa_exchange_work_type
     109             :       PRIVATE
     110             :       INTEGER :: exchange_correction = rpa_exchange_none
     111             :       TYPE(rpa_exchange_env_type) :: exchange_env
     112             :       INTEGER, DIMENSION(:), ALLOCATABLE :: homo, virtual, dimen_ia
     113             :       TYPE(group_dist_d1_type) :: aux_func_dist = group_dist_d1_type()
     114             :       INTEGER, DIMENSION(:), ALLOCATABLE :: aux2send
     115             :       INTEGER :: dimen_RI = 0
     116             :       INTEGER :: block_size = 0
     117             :       INTEGER :: color_sub = 0
     118             :       INTEGER :: ngroup = 0
     119             :       TYPE(cp_fm_type) :: fm_mat_Q_tmp = cp_fm_type()
     120             :       TYPE(cp_fm_type) :: fm_mat_R_half_gemm = cp_fm_type()
     121             :       TYPE(cp_fm_type) :: fm_mat_U = cp_fm_type()
     122             :       TYPE(mp_para_env_type), POINTER :: para_env_sub => NULL()
     123             :    CONTAINS
     124             :       PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: create => rpa_exchange_work_create
     125             :       PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: compute => rpa_exchange_work_compute
     126             :       PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: release => rpa_exchange_work_release
     127             :       PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: redistribute_into_subgroups
     128             :       PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_fm => rpa_exchange_work_compute_fm
     129             :       PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_hfx => rpa_exchange_work_compute_hfx
     130             :    END TYPE
     131             : 
     132             : CONTAINS
     133             : 
     134             : ! **************************************************************************************************
     135             : !> \brief ...
     136             : !> \param mp2_env ...
     137             : !> \param homo ...
     138             : !> \param virtual ...
     139             : !> \param dimen_RI ...
     140             : !> \param para_env ...
     141             : !> \param mem_per_rank ...
     142             : !> \param mem_per_repl ...
     143             : ! **************************************************************************************************
     144         118 :    SUBROUTINE rpa_exchange_needed_mem(mp2_env, homo, virtual, dimen_RI, para_env, mem_per_rank, mem_per_repl)
     145             :       TYPE(mp2_type), INTENT(IN)                         :: mp2_env
     146             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: homo, virtual
     147             :       INTEGER, INTENT(IN)                                :: dimen_RI
     148             :       TYPE(mp_para_env_type), INTENT(IN)                 :: para_env
     149             :       REAL(KIND=dp), INTENT(INOUT)                       :: mem_per_rank, mem_per_repl
     150             : 
     151             :       INTEGER                                            :: block_size
     152             : 
     153             :       ! We need the block size and if it is unknown, an upper bound
     154         118 :       block_size = mp2_env%ri_rpa%exchange_block_size
     155         118 :       IF (block_size <= 0) block_size = MAX(1, (dimen_RI + para_env%num_pe - 1)/para_env%num_pe)
     156             : 
     157             :       ! storage of product matrix (upper bound only as it depends on the square of the potential still unknown block size)
     158         260 :       mem_per_rank = mem_per_rank + REAL(MAXVAL(homo), KIND=dp)**2*block_size**2*8.0_dp/(1024_dp**2)
     159             : 
     160             :       ! work arrays R (2x) and U, copies of Gamma (2x), communication buffer (as expensive as Gamma)
     161             :       mem_per_repl = mem_per_repl + 3.0_dp*dimen_RI*dimen_RI*8.0_dp/(1024_dp**2) &
     162         260 :                      + 3.0_dp*MAXVAL(homo*virtual)*dimen_RI*8.0_dp/(1024_dp**2)
     163         118 :    END SUBROUTINE rpa_exchange_needed_mem
     164             : 
     165             : ! **************************************************************************************************
     166             : !> \brief ...
     167             : !> \param exchange_work ...
     168             : !> \param qs_env ...
     169             : !> \param para_env_sub ...
     170             : !> \param mat_munu ...
     171             : !> \param dimen_RI ...
     172             : !> \param fm_mat_S ...
     173             : !> \param fm_mat_Q ...
     174             : !> \param fm_mat_Q_gemm ...
     175             : !> \param homo ...
     176             : !> \param virtual ...
     177             : ! **************************************************************************************************
     178         110 :    SUBROUTINE rpa_exchange_work_create(exchange_work, qs_env, para_env_sub, mat_munu, dimen_RI, &
     179         110 :                                        fm_mat_S, fm_mat_Q, fm_mat_Q_gemm, homo, virtual)
     180             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     181             :       TYPE(qs_environment_type), POINTER :: qs_env
     182             :       TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
     183             :       TYPE(dbcsr_p_type), INTENT(IN) :: mat_munu
     184             :       INTEGER, INTENT(IN) :: dimen_RI
     185             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
     186             :       TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q, fm_mat_Q_gemm
     187             :       INTEGER, DIMENSION(SIZE(fm_mat_S)), INTENT(IN) :: homo, virtual
     188             : 
     189             :       INTEGER :: nspins, aux_global, aux_local, my_process_row, proc, ispin
     190         110 :       INTEGER, DIMENSION(:), POINTER :: row_indices, aux_distribution_fm
     191             :       TYPE(cp_blacs_env_type), POINTER :: context
     192             : 
     193         110 :       exchange_work%exchange_correction = qs_env%mp2_env%ri_rpa%exchange_correction
     194             : 
     195         110 :       IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
     196             : 
     197             :       ASSOCIATE (para_env => fm_mat_S(1)%matrix_struct%para_env)
     198          12 :          exchange_work%para_env_sub => para_env_sub
     199          12 :          exchange_work%ngroup = para_env%num_pe/para_env_sub%num_pe
     200          12 :          exchange_work%color_sub = para_env%mepos/para_env_sub%num_pe
     201             :       END ASSOCIATE
     202             : 
     203          12 :       CALL cp_fm_get_info(fm_mat_S(1), row_indices=row_indices, nrow_locals=aux_distribution_fm, context=context)
     204          12 :       CALL context%get(my_process_row=my_process_row)
     205             : 
     206          12 :       CALL create_group_dist(exchange_work%aux_func_dist, exchange_work%ngroup, dimen_RI)
     207          36 :       ALLOCATE (exchange_work%aux2send(0:exchange_work%ngroup - 1))
     208          36 :       exchange_work%aux2send = 0
     209         499 :       DO aux_local = 1, aux_distribution_fm(my_process_row)
     210         487 :          aux_global = row_indices(aux_local)
     211         487 :          proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
     212         499 :          exchange_work%aux2send(proc) = exchange_work%aux2send(proc) + 1
     213             :       END DO
     214             : 
     215          12 :       nspins = SIZE(fm_mat_S)
     216             : 
     217          60 :       ALLOCATE (exchange_work%homo(nspins), exchange_work%virtual(nspins), exchange_work%dimen_ia(nspins))
     218          26 :       exchange_work%homo(:) = homo
     219          26 :       exchange_work%virtual(:) = virtual
     220          26 :       exchange_work%dimen_ia(:) = homo*virtual
     221          12 :       exchange_work%dimen_RI = dimen_RI
     222             : 
     223          12 :       exchange_work%block_size = qs_env%mp2_env%ri_rpa%exchange_block_size
     224          12 :       IF (exchange_work%block_size <= 0) exchange_work%block_size = dimen_RI
     225             : 
     226          12 :       CALL cp_fm_create(exchange_work%fm_mat_U, fm_mat_Q%matrix_struct, name="fm_mat_U")
     227          12 :       CALL cp_fm_create(exchange_work%fm_mat_Q_tmp, fm_mat_Q%matrix_struct, name="fm_mat_Q_tmp")
     228          12 :       CALL cp_fm_create(exchange_work%fm_mat_R_half_gemm, fm_mat_Q_gemm%matrix_struct)
     229             : 
     230          12 :       IF (qs_env%mp2_env%ri_rpa%use_hfx_implementation) THEN
     231           2 :          CALL exchange_work%exchange_env%create(qs_env, mat_munu%matrix, para_env_sub, fm_mat_S)
     232             :       END IF
     233             : 
     234          12 :       IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_o)) THEN
     235          22 :          DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_o)
     236          22 :             CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_o(ispin))
     237             :          END DO
     238          10 :          DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_o)
     239             :       END IF
     240             : 
     241          12 :       IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_v)) THEN
     242          22 :          DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_v)
     243          22 :             CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_v(ispin))
     244             :          END DO
     245          10 :          DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_v)
     246             :       END IF
     247         110 :    END SUBROUTINE
     248             : 
     249             : ! **************************************************************************************************
     250             : !> \brief ... Initializes x_data on a subgroup
     251             : !> \param exchange_env ...
     252             : !> \param qs_env ...
     253             : !> \param mat_munu ...
     254             : !> \param para_env_sub ...
     255             : !> \param fm_mat_S ...
     256             : !> \author Vladimir Rybkin
     257             : ! **************************************************************************************************
     258           2 :    SUBROUTINE hfx_create_subgroup(exchange_env, qs_env, mat_munu, para_env_sub, fm_mat_S)
     259             :       CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
     260             :       TYPE(dbcsr_type), INTENT(IN) :: mat_munu
     261             :       TYPE(qs_environment_type), POINTER   :: qs_env
     262             :       TYPE(mp_para_env_type), POINTER, INTENT(IN)            :: para_env_sub
     263             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
     264             : 
     265             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_create_subgroup'
     266             : 
     267             :       INTEGER                                            :: handle, nelectron_total, ispin, &
     268             :                                                             number_of_aos, nspins, dimen_RI, dimen_ia
     269           2 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     270             :       TYPE(cell_type), POINTER                           :: my_cell
     271             :       TYPE(dft_control_type), POINTER                    :: dft_control
     272           2 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     273           2 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     274             :       TYPE(qs_subsys_type), POINTER                      :: subsys
     275             :       TYPE(scf_control_type), POINTER                    :: scf_control
     276             :       TYPE(section_vals_type), POINTER                   :: input
     277             : 
     278           2 :       CALL timeset(routineN, handle)
     279             : 
     280           2 :       exchange_env%mo_coeff_o => qs_env%mp2_env%ri_rpa%mo_coeff_o
     281           2 :       exchange_env%mo_coeff_v => qs_env%mp2_env%ri_rpa%mo_coeff_v
     282           2 :       NULLIFY (qs_env%mp2_env%ri_rpa%mo_coeff_o, qs_env%mp2_env%ri_rpa%mo_coeff_v)
     283             : 
     284           2 :       nspins = SIZE(exchange_env%mo_coeff_o)
     285             : 
     286           2 :       exchange_env%qs_env => qs_env
     287           2 :       exchange_env%para_env => para_env_sub
     288           2 :       exchange_env%eps_filter = qs_env%mp2_env%mp2_gpw%eps_filter
     289             : 
     290           2 :       NULLIFY (my_cell, atomic_kind_set, particle_set, dft_control, qs_kind_set, scf_control)
     291             : 
     292             :       CALL get_qs_env(qs_env, &
     293             :                       subsys=subsys, &
     294             :                       input=input, &
     295             :                       scf_control=scf_control, &
     296           2 :                       nelectron_total=nelectron_total)
     297             : 
     298             :       CALL qs_subsys_get(subsys, &
     299             :                          cell=my_cell, &
     300             :                          atomic_kind_set=atomic_kind_set, &
     301             :                          qs_kind_set=qs_kind_set, &
     302           2 :                          particle_set=particle_set)
     303             : 
     304           2 :       exchange_env%hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%WF_CORRELATION%RI_RPA%HF")
     305           2 :       CALL get_qs_env(qs_env, dft_control=dft_control)
     306             : 
     307             :       ! Retrieve particle_set and atomic_kind_set
     308             :       CALL hfx_create(exchange_env%x_data, para_env_sub, exchange_env%hfx_sections, atomic_kind_set, &
     309             :                       qs_kind_set, particle_set, dft_control, my_cell, orb_basis='ORB', &
     310           2 :                       nelectron_total=nelectron_total)
     311             : 
     312           2 :       exchange_env%my_recalc_hfx_integrals = .TRUE.
     313             : 
     314           2 :       CALL dbcsr_allocate_matrix_set(exchange_env%mat_hfx, nspins)
     315           4 :       DO ispin = 1, nspins
     316           2 :          ALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
     317           2 :          CALL dbcsr_init_p(exchange_env%mat_hfx(ispin)%matrix)
     318             :          CALL dbcsr_create(exchange_env%mat_hfx(ispin)%matrix, template=mat_munu, &
     319           2 :                            matrix_type=dbcsr_type_no_symmetry)
     320           4 :          CALL dbcsr_copy(exchange_env%mat_hfx(ispin)%matrix, mat_munu)
     321             :       END DO
     322             : 
     323           2 :       CALL dbcsr_get_info(mat_munu, nfullcols_total=number_of_aos)
     324             : 
     325             :       CALL dbcsr_create(exchange_env%work_ao, template=mat_munu, &
     326           2 :                         matrix_type=dbcsr_type_no_symmetry)
     327             : 
     328           8 :       ALLOCATE (exchange_env%dbcsr_Gamma_inu_P(nspins))
     329           2 :       CALL dbcsr_allocate_matrix_set(exchange_env%dbcsr_Gamma_munu_P, nspins)
     330           4 :       DO ispin = 1, nspins
     331           2 :          ALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     332             :          CALL dbcsr_create(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, template=mat_munu, &
     333           2 :                            matrix_type=dbcsr_type_no_symmetry)
     334           2 :          CALL dbcsr_copy(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, mat_munu)
     335           2 :          CALL dbcsr_set(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, 0.0_dp)
     336             : 
     337           2 :          CALL dbcsr_create(exchange_env%dbcsr_Gamma_inu_P(ispin), template=exchange_env%mo_coeff_o(ispin))
     338           2 :          CALL dbcsr_copy(exchange_env%dbcsr_Gamma_inu_P(ispin), exchange_env%mo_coeff_o(ispin))
     339           4 :          CALL dbcsr_set(exchange_env%dbcsr_Gamma_inu_P(ispin), 0.0_dp)
     340             :       END DO
     341             : 
     342           8 :       ALLOCATE (exchange_env%struct_Gamma(nspins))
     343           4 :       DO ispin = 1, nspins
     344           2 :          CALL cp_fm_get_info(fm_mat_S(ispin), nrow_global=dimen_RI, ncol_global=dimen_ia)
     345             :          CALL cp_fm_struct_create(exchange_env%struct_Gamma(ispin)%struct, template_fmstruct=fm_mat_S(ispin)%matrix_struct, &
     346           4 :                                   nrow_global=dimen_ia, ncol_global=dimen_RI)
     347             :       END DO
     348             : 
     349           2 :       CALL timestop(handle)
     350             : 
     351           4 :    END SUBROUTINE hfx_create_subgroup
     352             : 
     353             : ! **************************************************************************************************
     354             : !> \brief ...
     355             : !> \param exchange_work ...
     356             : ! **************************************************************************************************
     357         110 :    SUBROUTINE rpa_exchange_work_release(exchange_work)
     358             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     359             : 
     360         110 :       IF (ALLOCATED(exchange_work%homo)) DEALLOCATE (exchange_work%homo)
     361         110 :       IF (ALLOCATED(exchange_work%virtual)) DEALLOCATE (exchange_work%virtual)
     362         110 :       IF (ALLOCATED(exchange_work%dimen_ia)) DEALLOCATE (exchange_work%dimen_ia)
     363         110 :       NULLIFY (exchange_work%para_env_sub)
     364         110 :       CALL release_group_dist(exchange_work%aux_func_dist)
     365         110 :       IF (ALLOCATED(exchange_work%aux2send)) DEALLOCATE (exchange_work%aux2send)
     366         110 :       CALL cp_fm_release(exchange_work%fm_mat_Q_tmp)
     367         110 :       CALL cp_fm_release(exchange_work%fm_mat_U)
     368         110 :       CALL cp_fm_release(exchange_work%fm_mat_R_half_gemm)
     369             : 
     370         110 :       CALL exchange_work%exchange_env%release()
     371         110 :    END SUBROUTINE
     372             : 
     373             : ! **************************************************************************************************
     374             : !> \brief ...
     375             : !> \param exchange_env ...
     376             : ! **************************************************************************************************
     377         110 :    SUBROUTINE hfx_release_subgroup(exchange_env)
     378             :       CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
     379             : 
     380             :       INTEGER :: ispin
     381             : 
     382         110 :       NULLIFY (exchange_env%para_env, exchange_env%hfx_sections)
     383             : 
     384         110 :       IF (ASSOCIATED(exchange_env%x_data)) THEN
     385           2 :          CALL hfx_release(exchange_env%x_data)
     386           2 :          NULLIFY (exchange_env%x_data)
     387             :       END IF
     388             : 
     389         110 :       CALL dbcsr_release(exchange_env%work_ao)
     390             : 
     391         110 :       IF (ASSOCIATED(exchange_env%dbcsr_Gamma_munu_P)) THEN
     392           4 :          DO ispin = 1, SIZE(exchange_env%mat_hfx, 1)
     393           2 :             CALL dbcsr_release(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     394           2 :             CALL dbcsr_release(exchange_env%mat_hfx(ispin)%matrix)
     395           2 :             CALL dbcsr_release(exchange_env%dbcsr_Gamma_inu_P(ispin))
     396           2 :             CALL dbcsr_release(exchange_env%mo_coeff_o(ispin))
     397           2 :             CALL dbcsr_release(exchange_env%mo_coeff_v(ispin))
     398           2 :             DEALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
     399           4 :             DEALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
     400             :          END DO
     401           2 :          DEALLOCATE (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
     402           2 :          DEALLOCATE (exchange_env%dbcsr_Gamma_inu_P, exchange_env%mo_coeff_o, exchange_env%mo_coeff_v)
     403           2 :          NULLIFY (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
     404             :       END IF
     405         110 :       IF (ALLOCATED(exchange_env%struct_Gamma)) THEN
     406           4 :       DO ispin = 1, SIZE(exchange_env%struct_Gamma)
     407           4 :          CALL cp_fm_struct_release(exchange_env%struct_Gamma(ispin)%struct)
     408             :       END DO
     409           2 :       DEALLOCATE (exchange_env%struct_Gamma)
     410             :       END IF
     411         110 :    END SUBROUTINE hfx_release_subgroup
     412             : 
     413             : ! **************************************************************************************************
     414             : !> \brief Main driver for RPA-exchange energies
     415             : !> \param exchange_work ...
     416             : !> \param fm_mat_Q ...
     417             : !> \param eig ...
     418             : !> \param fm_mat_S ...
     419             : !> \param omega ...
     420             : !> \param e_exchange_corr exchange energy correction for a quadrature point
     421             : !> \param mp2_env ...
     422             : !> \author Vladimir Rybkin, 07/2016
     423             : ! **************************************************************************************************
     424          12 :    SUBROUTINE rpa_exchange_work_compute(exchange_work, fm_mat_Q, eig, fm_mat_S, omega, &
     425             :                                         e_exchange_corr, mp2_env)
     426             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     427             :       TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat_Q
     428             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     429             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT)         :: fm_mat_S
     430             :       REAL(KIND=dp), INTENT(IN)                          :: omega
     431             :       REAL(KIND=dp), INTENT(INOUT)                       :: e_exchange_corr
     432             :       TYPE(mp2_type), INTENT(INOUT) :: mp2_env
     433             : 
     434             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rpa_exchange_work_compute'
     435             :       REAL(KIND=dp), PARAMETER                           :: thresh = 0.0000001_dp
     436             : 
     437             :       INTEGER :: handle, nspins, dimen_RI, iiB
     438          12 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: eigenval
     439             : 
     440          12 :       IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
     441             : 
     442          12 :       CALL timeset(routineN, handle)
     443             : 
     444          12 :       CALL cp_fm_get_info(fm_mat_Q, ncol_global=dimen_RI)
     445             : 
     446          12 :       nspins = SIZE(fm_mat_S)
     447             : 
     448             :       ! Eigenvalues
     449          36 :       ALLOCATE (eigenval(dimen_RI))
     450         986 :       eigenval = 0.0_dp
     451             : 
     452          12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_Q_tmp, alpha=0.0_dp)
     453          12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_U, alpha=0.0_dp)
     454             : 
     455             :       ! Copy Q to Q_tmp
     456          12 :       CALL cp_fm_to_fm(fm_mat_Q, exchange_work%fm_mat_Q_tmp)
     457             :       ! Diagonalize Q
     458          12 :       CALL choose_eigv_solver(exchange_work%fm_mat_Q_tmp, exchange_work%fm_mat_U, eigenval)
     459             : 
     460             :       ! Calculate diagonal matrix for R_half
     461             : 
     462             :       ! Manipulate eigenvalues to get diagonal matrix
     463          12 :       IF (exchange_work%exchange_correction == rpa_exchange_axk) THEN
     464         818 :          DO iib = 1, dimen_RI
     465         818 :             IF (ABS(eigenval(iib)) .GE. thresh) THEN
     466             :                eigenval(iib) = &
     467             :                   SQRT((1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
     468         718 :                        - 1.0_dp/(eigenval(iib)*(eigenval(iib) + 1.0_dp)))
     469             :             ELSE
     470          90 :                eigenval(iib) = sqrthalf
     471             :             END IF
     472             :          END DO
     473           2 :       ELSE IF (exchange_work%exchange_correction == rpa_exchange_sosex) THEN
     474         168 :          DO iib = 1, dimen_RI
     475         168 :             IF (ABS(eigenval(iib)) .GE. thresh) THEN
     476             :                eigenval(iib) = &
     477             :                   SQRT(-(1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
     478         144 :                        + 1.0_dp/eigenval(iib))
     479             :             ELSE
     480          22 :                eigenval(iib) = sqrthalf
     481             :             END IF
     482             :          END DO
     483             :       ELSE
     484           0 :          CPABORT("Unknown RPA exchange correction")
     485             :       END IF
     486             : 
     487             :       ! fm_mat_U now contains some sqrt of the required matrix-valued function
     488          12 :       CALL cp_fm_column_scale(exchange_work%fm_mat_U, eigenval)
     489             : 
     490             :       ! Release memory
     491          12 :       DEALLOCATE (eigenval)
     492             : 
     493             :       ! Redistribute fm_mat_U for "rectangular" multiplication: ia*P P*P
     494          12 :       CALL cp_fm_set_all(matrix=exchange_work%fm_mat_R_half_gemm, alpha=0.0_dp)
     495             : 
     496             :       CALL cp_fm_to_fm_submat_general(exchange_work%fm_mat_U, exchange_work%fm_mat_R_half_gemm, dimen_RI, &
     497          12 :                                       dimen_RI, 1, 1, 1, 1, exchange_work%fm_mat_U%matrix_struct%context)
     498             : 
     499          12 :       IF (mp2_env%ri_rpa%use_hfx_implementation) THEN
     500           2 :          CALL exchange_work%compute_hfx(fm_mat_S, eig, omega, e_exchange_corr)
     501             :       ELSE
     502          10 :          CALL exchange_work%compute_fm(fm_mat_S, eig, omega, e_exchange_corr, mp2_env)
     503             :       END IF
     504             : 
     505          12 :       CALL timestop(handle)
     506             : 
     507          12 :    END SUBROUTINE rpa_exchange_work_compute
     508             : 
     509             : ! **************************************************************************************************
     510             : !> \brief Main driver for RPA-exchange energies
     511             : !> \param exchange_work ...
     512             : !> \param fm_mat_S ...
     513             : !> \param eig ...
     514             : !> \param omega ...
     515             : !> \param e_exchange_corr exchange energy correction for a quadrature point
     516             : !> \param mp2_env ...
     517             : !> \author Frederick Stein, May-June 2024
     518             : ! **************************************************************************************************
     519          10 :    SUBROUTINE rpa_exchange_work_compute_fm(exchange_work, fm_mat_S, eig, omega, &
     520             :                                            e_exchange_corr, mp2_env)
     521             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     522             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(IN)         :: fm_mat_S
     523             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     524             :       REAL(KIND=dp), INTENT(IN)                          :: omega
     525             :       REAL(KIND=dp), INTENT(INOUT)                       :: e_exchange_corr
     526             :       TYPE(mp2_type), INTENT(INOUT) :: mp2_env
     527             : 
     528             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'rpa_exchange_work_compute_fm'
     529             : 
     530             :       INTEGER :: handle, ispin, nspins, P, Q, L_size_Gamma, hom, virt, i, &
     531             :                  send_proc, recv_proc, recv_size, max_aux_size, proc_shift, dimen_ia, &
     532             :                  block_size, P_start, P_end, P_size, Q_start, Q_size, Q_end, handle2, my_aux_size, my_virt
     533          10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET :: mat_Gamma_3_3D
     534          10 :       REAL(KIND=dp), POINTER, DIMENSION(:), CONTIGUOUS :: mat_Gamma_3_1D
     535          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: mat_Gamma_3_2D
     536          10 :       REAL(KIND=dp), ALLOCATABLE, TARGET, DIMENSION(:) :: recv_buffer_1D
     537          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: recv_buffer_2D
     538          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :, :), CONTIGUOUS :: recv_buffer_3D
     539          10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: mat_B_iaP
     540          10 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET :: product_matrix_1D
     541          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: product_matrix_2D
     542          10 :       REAL(KIND=dp), POINTER, DIMENSION(:, :, :, :), CONTIGUOUS :: product_matrix_4D
     543             :       TYPE(cp_fm_type)        :: fm_mat_Gamma_3
     544             :       TYPE(mp_para_env_type), POINTER :: para_env
     545          10 :       TYPE(group_dist_d1_type)                           :: virt_dist
     546             : 
     547          10 :       CALL timeset(routineN, handle)
     548             : 
     549          10 :       nspins = SIZE(fm_mat_S)
     550             : 
     551          10 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, sizes=my_aux_size)
     552             : 
     553          10 :       e_exchange_corr = 0.0_dp
     554          10 :       max_aux_size = maxsize(exchange_work%aux_func_dist)
     555             : 
     556             :       ! local_gemm_ctx has a very large footprint the first time this routine is
     557             :       ! called.
     558          10 :       CALL mp2_env%local_gemm_ctx%create(LOCAL_GEMM_PU_GPU)
     559          10 :       CALL mp2_env%local_gemm_ctx%set_op_threshold_gpu(128*128*128*2)
     560             : 
     561          22 :       DO ispin = 1, nspins
     562          12 :          hom = exchange_work%homo(ispin)
     563          12 :          virt = exchange_work%virtual(ispin)
     564          12 :          dimen_ia = hom*virt
     565          12 :          IF (hom < 1 .OR. virt < 1) CYCLE
     566             : 
     567          12 :          CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
     568             : 
     569          12 :          CALL cp_fm_create(fm_mat_Gamma_3, fm_mat_S(ispin)%matrix_struct)
     570          12 :          CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
     571             : 
     572             :          ! Update G with a new value of Omega: in practice, it is G*S
     573             : 
     574             :          ! Scale fm_work_iaP
     575             :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
     576          12 :                                 hom, omega, 0.0_dp)
     577             : 
     578             :          ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
     579             :          CALL parallel_gemm(transa="T", transb="N", m=exchange_work%dimen_RI, n=dimen_ia, k=exchange_work%dimen_RI, alpha=1.0_dp, &
     580             :                             matrix_a=exchange_work%fm_mat_R_half_gemm, matrix_b=fm_mat_S(ispin), beta=0.0_dp, &
     581          12 :                             matrix_c=fm_mat_Gamma_3)
     582             : 
     583          12 :          CALL create_group_dist(virt_dist, exchange_work%para_env_sub%num_pe, virt)
     584             : 
     585             :          ! Remove extra factor from S after the multiplication (to return to the original matrix)
     586          12 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     587             : 
     588          12 :          CALL exchange_work%redistribute_into_subgroups(fm_mat_Gamma_3, mat_Gamma_3_3D, ispin, virt_dist)
     589          12 :          CALL cp_fm_release(fm_mat_Gamma_3)
     590             : 
     591             :          ! We need only the pure matrix
     592          12 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     593             : 
     594             :          ! Reorder matrix from (P, i*a) -> (a, i, P) with P being distributed within subgroups
     595          12 :          CALL exchange_work%redistribute_into_subgroups(fm_mat_S(ispin), mat_B_iaP, ispin, virt_dist)
     596             : 
     597             :          ! Return to the original tensor
     598          12 :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), hom, omega, 0.0_dp)
     599             : 
     600          12 :          L_size_Gamma = SIZE(mat_Gamma_3_3D, 3)
     601          12 :          my_virt = SIZE(mat_Gamma_3_3D, 1)
     602          12 :          block_size = exchange_work%block_size
     603             : 
     604          12 :          mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size) => mat_Gamma_3_3D(:, :, 1:my_aux_size)
     605          12 :          mat_Gamma_3_2D(1:my_virt, 1:hom*my_aux_size) => mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size)
     606             : 
     607           0 :          ALLOCATE (product_matrix_1D(INT(hom*MIN(block_size, L_size_gamma), KIND=int_8)* &
     608          36 :                                      INT(hom*MIN(block_size, max_aux_size), KIND=int_8)))
     609          36 :          ALLOCATE (recv_buffer_1D(INT(virt, KIND=int_8)*hom*max_aux_size))
     610          12 :          recv_buffer_2D(1:my_virt, 1:hom*max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
     611          12 :          recv_buffer_3D(1:my_virt, 1:hom, 1:max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
     612          36 :          DO proc_shift = 0, para_env%num_pe - 1, exchange_work%para_env_sub%num_pe
     613          24 :             send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
     614          24 :             recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
     615             : 
     616          24 :             CALL get_group_dist(exchange_work%aux_func_dist, recv_proc/exchange_work%para_env_sub%num_pe, sizes=recv_size)
     617             : 
     618          24 :             IF (recv_size == 0) recv_proc = mp_proc_null
     619             : 
     620          24 :             CALL para_env%sendrecv(mat_B_iaP, send_proc, recv_buffer_3D(:, :, 1:recv_size), recv_proc)
     621             : 
     622          24 :             IF (recv_size == 0) CYCLE
     623             : 
     624        1038 :             DO P_start = 1, L_size_Gamma, block_size
     625        1002 :                P_end = MIN(L_size_Gamma, P_start + block_size - 1)
     626        1002 :                P_size = P_end - P_start + 1
     627       43875 :                DO Q_start = 1, recv_size, block_size
     628       42849 :                   Q_end = MIN(recv_size, Q_start + block_size - 1)
     629       42849 :                   Q_size = Q_end - Q_start + 1
     630             : 
     631             :                   ! Reassign product_matrix pointers to enforce contiguity of target array
     632             :                   product_matrix_2D(1:hom*P_size, 1:hom*Q_size) => &
     633       42849 :                      product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
     634             :                   product_matrix_4D(1:hom, 1:P_size, 1:hom, 1:Q_size) => &
     635       42849 :                      product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
     636             : 
     637       42849 :                   CALL timeset(routineN//"_gemm", handle2)
     638             :                   CALL mp2_env%local_gemm_ctx%gemm("T", "N", hom*P_size, hom*Q_size, my_virt, 1.0_dp, &
     639             :                                                    mat_Gamma_3_2D(:, hom*(P_start - 1) + 1:hom*P_end), my_virt, &
     640             :                                                    recv_buffer_2D(:, hom*(Q_start - 1) + 1:hom*Q_end), my_virt, &
     641       42849 :                                                    0.0_dp, product_matrix_2D, hom*P_size)
     642       42849 :                   CALL timestop(handle2)
     643             : 
     644       42849 :                   CALL timeset(routineN//"_energy", handle2)
     645             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(P_size, Q_size, hom, product_matrix_4D) &
     646       42849 : !$OMP             COLLAPSE(3) REDUCTION(+: e_exchange_corr) PRIVATE(P, Q, i)
     647             :                   DO P = 1, P_size
     648             :                   DO Q = 1, Q_size
     649             :                   DO i = 1, hom
     650             :                      e_exchange_corr = e_exchange_corr + DOT_PRODUCT(product_matrix_4D(i, P, :, Q), product_matrix_4D(:, P, i, Q))
     651             :                   END DO
     652             :                   END DO
     653             :                   END DO
     654       86700 :                   CALL timestop(handle2)
     655             :                END DO
     656             :             END DO
     657             :          END DO
     658             : 
     659          12 :          CALL release_group_dist(virt_dist)
     660          12 :          IF (ALLOCATED(mat_B_iaP)) DEALLOCATE (mat_B_iaP)
     661          12 :          IF (ALLOCATED(mat_Gamma_3_3D)) DEALLOCATE (mat_Gamma_3_3D)
     662          12 :          IF (ALLOCATED(product_matrix_1D)) DEALLOCATE (product_matrix_1D)
     663          58 :          IF (ALLOCATED(recv_buffer_1D)) DEALLOCATE (recv_buffer_1D)
     664             :       END DO
     665             : 
     666          10 :       CALL mp2_env%local_gemm_ctx%destroy()
     667             : 
     668          10 :       IF (nspins == 2) e_exchange_corr = e_exchange_corr*2.0_dp
     669          10 :       IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
     670             : 
     671          10 :       CALL timestop(handle)
     672             : 
     673          20 :    END SUBROUTINE rpa_exchange_work_compute_fm
     674             : 
     675             : ! **************************************************************************************************
     676             : !> \brief Contract RPA-exchange density matrix with HF exchange integrals and evaluate the correction
     677             : !> \param exchange_work ...
     678             : !> \param fm_mat_S ...
     679             : !> \param eig ...
     680             : !> \param omega ...
     681             : !> \param e_exchange_corr ...
     682             : !> \author Vladimir Rybkin, 08/2016
     683             : ! **************************************************************************************************
     684           2 :    SUBROUTINE rpa_exchange_work_compute_hfx(exchange_work, fm_mat_S, eig, omega, e_exchange_corr)
     685             :       CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
     686             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT) :: fm_mat_S
     687             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
     688             :       REAL(KIND=dp), INTENT(IN)                          :: omega
     689             :       REAL(KIND=dp), INTENT(OUT) :: e_exchange_corr
     690             : 
     691             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_hfx'
     692             : 
     693             :       INTEGER                                            :: handle, ispin, my_aux_start, my_aux_end, &
     694             :                                                             my_aux_size, nspins, L_counter, dimen_ia, hom, virt
     695             :       REAL(KIND=dp)                                      :: e_exchange_P
     696           2 :       TYPE(dbcsr_matrix_p_set), DIMENSION(:), ALLOCATABLE          :: dbcsr_Gamma_3
     697             :       TYPE(cp_fm_type) :: fm_mat_Gamma_3
     698             :       TYPE(mp_para_env_type), POINTER :: para_env
     699             : 
     700           2 :       CALL timeset(routineN, handle)
     701             : 
     702           2 :       e_exchange_corr = 0.0_dp
     703             : 
     704           2 :       nspins = SIZE(fm_mat_S)
     705             : 
     706           2 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
     707             : 
     708           8 :       ALLOCATE (dbcsr_Gamma_3(nspins))
     709           4 :       DO ispin = 1, nspins
     710           2 :          hom = exchange_work%homo(ispin)
     711           2 :          virt = exchange_work%virtual(ispin)
     712           2 :          dimen_ia = hom*virt
     713           2 :          IF (hom < 1 .OR. virt < 1) CYCLE
     714             : 
     715           2 :          CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
     716             : 
     717           2 :          CALL cp_fm_create(fm_mat_Gamma_3, exchange_work%exchange_env%struct_Gamma(ispin)%struct)
     718           2 :          CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
     719             : 
     720             :          ! Update G with a new value of Omega: in practice, it is G*S
     721             : 
     722             :          ! Scale fm_work_iaP
     723             :          CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
     724           2 :                                 hom, omega, 0.0_dp)
     725             : 
     726             :          ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
     727             :          CALL parallel_gemm(transa="T", transb="N", m=dimen_ia, n=exchange_work%dimen_RI, &
     728             :                             k=exchange_work%dimen_RI, alpha=1.0_dp, &
     729             :                             matrix_a=fm_mat_S(ispin), matrix_b=exchange_work%fm_mat_R_half_gemm, beta=0.0_dp, &
     730           2 :                             matrix_c=fm_mat_Gamma_3)
     731             : 
     732             :          ! Remove extra factor from S after the multiplication (to return to the original matrix)
     733           2 :          CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
     734             : 
     735             :          ! Copy Gamma_ia_P^3 to dbcsr matrix set
     736             :          CALL gamma_fm_to_dbcsr(fm_mat_Gamma_3, dbcsr_Gamma_3(ispin)%matrix_set, &
     737             :                                 para_env, exchange_work%para_env_sub, hom, virt, &
     738             :                                 exchange_work%exchange_env%mo_coeff_o(ispin), &
     739           6 :                                 exchange_work%ngroup, my_aux_start, my_aux_end, my_aux_size)
     740             :       END DO
     741             : 
     742          85 :       DO L_counter = 1, my_aux_size
     743         166 :          DO ispin = 1, nspins
     744             :             ! Do dbcsr multiplication: transform the virtual index
     745             :             CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mo_coeff_v(ispin), &
     746             :                                 dbcsr_Gamma_3(ispin)%matrix_set(L_counter), &
     747             :                                 0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     748          83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     749             : 
     750          83 :             CALL dbcsr_release(dbcsr_Gamma_3(ispin)%matrix_set(L_counter))
     751             : 
     752             :             ! Do dbcsr multiplication: transform the occupied index
     753             :             CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     754             :                                 exchange_work%exchange_env%mo_coeff_o(ispin), &
     755             :                                 0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     756          83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     757             :             CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%mo_coeff_o(ispin), &
     758             :                                 exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
     759             :                                 1.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     760          83 :                                 filter_eps=exchange_work%exchange_env%eps_filter)
     761             : 
     762         166 :             CALL dbcsr_set(exchange_work%exchange_env%mat_hfx(ispin)%matrix, 0.0_dp)
     763             :          END DO
     764             : 
     765             :          CALL tddft_hfx_matrix(exchange_work%exchange_env%mat_hfx, exchange_work%exchange_env%dbcsr_Gamma_munu_P, &
     766             :                                exchange_work%exchange_env%qs_env, .FALSE., &
     767             :                                exchange_work%exchange_env%my_recalc_hfx_integrals, &
     768             :                                exchange_work%exchange_env%hfx_sections, exchange_work%exchange_env%x_data, &
     769          83 :                                exchange_work%exchange_env%para_env)
     770             : 
     771          83 :          exchange_work%exchange_env%my_recalc_hfx_integrals = .FALSE.
     772         168 :          DO ispin = 1, nspins
     773             :             CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mat_hfx(ispin)%matrix, &
     774             :                                 exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
     775          83 :                                 0.0_dp, exchange_work%exchange_env%work_ao, filter_eps=exchange_work%exchange_env%eps_filter)
     776          83 :             CALL dbcsr_trace(exchange_work%exchange_env%work_ao, e_exchange_P)
     777         166 :             e_exchange_corr = e_exchange_corr - e_exchange_P
     778             :          END DO
     779             :       END DO
     780             : 
     781             :       IF (nspins == 2) e_exchange_corr = e_exchange_corr
     782           2 :       IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
     783             : 
     784           2 :       CALL timestop(handle)
     785             : 
     786           6 :    END SUBROUTINE rpa_exchange_work_compute_hfx
     787             : 
     788             : ! **************************************************************************************************
     789             : !> \brief ...
     790             : !> \param exchange_work ...
     791             : !> \param fm_mat ...
     792             : !> \param mat ...
     793             : !> \param ispin ...
     794             : !> \param virt_dist ...
     795             : ! **************************************************************************************************
     796          24 :    SUBROUTINE redistribute_into_subgroups(exchange_work, fm_mat, mat, ispin, virt_dist)
     797             :       CLASS(rpa_exchange_work_type), INTENT(IN) :: exchange_work
     798             :       TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat
     799             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
     800             :          INTENT(OUT)                                     :: mat
     801             :       INTEGER, INTENT(IN)                                :: ispin
     802             :       TYPE(group_dist_d1_type), INTENT(IN)               :: virt_dist
     803             : 
     804             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'redistribute_into_subgroups'
     805             : 
     806             :       INTEGER :: aux_counter, aux_global, aux_local, aux_proc, avirt, dimen_RI, handle, handle2, &
     807             :                  ia_global, ia_local, iocc, max_number_recv, max_number_send, my_aux_end, my_aux_size, &
     808             :                  my_aux_start, my_process_column, my_process_row, my_virt_end, my_virt_size, &
     809             :                  my_virt_start, proc, proc_shift, recv_proc, send_proc, virt_counter, virt_proc, group_size
     810          24 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: data2send, recv_col_indices, &
     811          24 :                                             recv_row_indices, send_aux_indices, send_virt_indices, virt2send
     812             :       INTEGER, DIMENSION(2)                              :: recv_shape
     813          24 :       INTEGER, DIMENSION(:), POINTER                     :: aux_distribution_fm, col_indices, &
     814          24 :                                                             ia_distribution_fm, row_indices
     815          24 :       INTEGER, DIMENSION(:, :), POINTER                  :: mpi2blacs
     816          24 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: recv_buffer, send_buffer
     817             :       REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
     818          24 :          POINTER                                         :: recv_ptr, send_ptr
     819             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     820             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     821             : 
     822          24 :       CALL timeset(routineN, handle)
     823             : 
     824             :       CALL cp_fm_get_info(matrix=fm_mat, &
     825             :                           nrow_locals=aux_distribution_fm, &
     826             :                           col_indices=col_indices, &
     827             :                           row_indices=row_indices, &
     828             :                           ncol_locals=ia_distribution_fm, &
     829             :                           context=context, &
     830             :                           nrow_global=dimen_RI, &
     831          24 :                           para_env=para_env)
     832             : 
     833          24 :       IF (exchange_work%homo(ispin) <= 0 .OR. exchange_work%virtual(ispin) <= 0) THEN
     834           0 :          CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
     835           0 :          ALLOCATE (mat(exchange_work%homo(ispin), my_virt_size, dimen_RI))
     836           0 :          CALL timestop(handle)
     837           0 :          RETURN
     838             :       END IF
     839             : 
     840          24 :       group_size = exchange_work%para_env_sub%num_pe
     841             : 
     842          24 :       CALL timeset(routineN//"_prep", handle2)
     843          24 :       CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
     844          24 :       CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
     845          24 :       CALL context%get(my_process_column=my_process_column, my_process_row=my_process_row, mpi2blacs=mpi2blacs)
     846             : 
     847             :       ! Determine the number of columns to send
     848         120 :       ALLOCATE (send_aux_indices(MAXVAL(exchange_work%aux2send)))
     849          72 :       ALLOCATE (virt2send(0:group_size - 1))
     850          48 :       virt2send = 0
     851        1924 :       DO ia_local = 1, ia_distribution_fm(my_process_column)
     852        1900 :          ia_global = col_indices(ia_local)
     853        1900 :          avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     854        1900 :          proc = group_dist_proc(virt_dist, avirt)
     855        1924 :          virt2send(proc) = virt2send(proc) + 1
     856             :       END DO
     857             : 
     858          72 :       ALLOCATE (data2send(0:para_env%num_pe - 1))
     859          72 :       DO aux_proc = 0, exchange_work%ngroup - 1
     860         120 :       DO virt_proc = 0, group_size - 1
     861          96 :          data2send(aux_proc*group_size + virt_proc) = exchange_work%aux2send(aux_proc)*virt2send(virt_proc)
     862             :       END DO
     863             :       END DO
     864             : 
     865          96 :       ALLOCATE (send_virt_indices(MAXVAL(virt2send)))
     866          72 :       max_number_send = MAXVAL(data2send)
     867             : 
     868          72 :       ALLOCATE (send_buffer(INT(max_number_send, KIND=int_8)*exchange_work%homo(ispin)))
     869          24 :       max_number_recv = max_number_send
     870          24 :       CALL para_env%max(max_number_recv)
     871          72 :       ALLOCATE (recv_buffer(max_number_recv))
     872             : 
     873         120 :       ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))
     874             : 
     875          24 :       CALL timestop(handle2)
     876             : 
     877          24 :       CALL timeset(routineN//"_own", handle2)
     878             :       ! Start with own data
     879        1026 :       DO aux_local = 1, aux_distribution_fm(my_process_row)
     880        1002 :          aux_global = row_indices(aux_local)
     881        1002 :          IF (aux_global < my_aux_start .OR. aux_global > my_aux_end) CYCLE
     882       40848 :          DO ia_local = 1, ia_distribution_fm(my_process_column)
     883       40318 :             ia_global = fm_mat%matrix_struct%col_indices(ia_local)
     884             : 
     885       40318 :             iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
     886       40318 :             avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     887             : 
     888       40318 :             IF (my_virt_start > avirt .OR. my_virt_end < avirt) CYCLE
     889             : 
     890       41320 :             mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = fm_mat%local_data(aux_local, ia_local)
     891             :          END DO
     892             :       END DO
     893          24 :       CALL timestop(handle2)
     894             : 
     895          48 :       DO proc_shift = 1, para_env%num_pe - 1
     896          24 :          send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
     897          24 :          recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
     898             : 
     899          24 :          CALL timeset(routineN//"_pack_buffer", handle2)
     900             :          send_ptr(1:virt2send(MOD(send_proc, group_size)), &
     901             :                   1:exchange_work%aux2send(send_proc/group_size)) => &
     902             :             send_buffer(1:INT(virt2send(MOD(send_proc, group_size)), KIND=int_8)* &
     903          24 :                         exchange_work%aux2send(send_proc/group_size))
     904             : ! Pack send buffer
     905          24 :          aux_counter = 0
     906        1026 :          DO aux_local = 1, aux_distribution_fm(my_process_row)
     907        1002 :             aux_global = row_indices(aux_local)
     908        1002 :             proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
     909        1002 :             IF (proc /= send_proc/group_size) CYCLE
     910         496 :             aux_counter = aux_counter + 1
     911         496 :             virt_counter = 0
     912       40016 :             DO ia_local = 1, ia_distribution_fm(my_process_column)
     913       39520 :                ia_global = col_indices(ia_local)
     914       39520 :                avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     915             : 
     916       39520 :                proc = group_dist_proc(virt_dist, avirt)
     917       39520 :                IF (proc /= MOD(send_proc, group_size)) CYCLE
     918       39520 :                virt_counter = virt_counter + 1
     919       39520 :                send_ptr(virt_counter, aux_counter) = fm_mat%local_data(aux_local, ia_local)
     920       40016 :                send_virt_indices(virt_counter) = ia_global
     921             :             END DO
     922        1026 :             send_aux_indices(aux_counter) = aux_global
     923             :          END DO
     924          24 :          CALL timestop(handle2)
     925             : 
     926          24 :          CALL timeset(routineN//"_ex_size", handle2)
     927          24 :          recv_shape = [1, 1]
     928          72 :          CALL para_env%sendrecv(SHAPE(send_ptr), send_proc, recv_shape, recv_proc)
     929          24 :          CALL timestop(handle2)
     930             : 
     931          72 :          IF (SIZE(send_ptr) == 0) send_proc = mp_proc_null
     932          72 :          IF (PRODUCT(recv_shape) == 0) recv_proc = mp_proc_null
     933             : 
     934          24 :          CALL timeset(routineN//"_ex_idx", handle2)
     935         120 :          ALLOCATE (recv_row_indices(recv_shape(1)), recv_col_indices(recv_shape(2)))
     936          24 :          CALL para_env%sendrecv(send_virt_indices(1:virt_counter), send_proc, recv_row_indices, recv_proc)
     937          24 :          CALL para_env%sendrecv(send_aux_indices(1:aux_counter), send_proc, recv_col_indices, recv_proc)
     938          24 :          CALL timestop(handle2)
     939             : 
     940             :          ! Prepare pointer to recv buffer (consider transposition while packing the send buffer)
     941          24 :          recv_ptr(1:recv_shape(1), 1:MAX(1, recv_shape(2))) => recv_buffer(1:recv_shape(1)*MAX(1, recv_shape(2)))
     942             : 
     943          24 :          CALL timeset(routineN//"_sendrecv", handle2)
     944             : ! Perform communication
     945          24 :          CALL para_env%sendrecv(send_ptr, send_proc, recv_ptr, recv_proc)
     946          24 :          CALL timestop(handle2)
     947             : 
     948          24 :          IF (recv_proc == mp_proc_null) THEN
     949           0 :             DEALLOCATE (recv_row_indices, recv_col_indices)
     950           0 :             CYCLE
     951             :          END IF
     952             : 
     953          24 :          CALL timeset(routineN//"_unpack", handle2)
     954             : ! Unpack receive buffer
     955         520 :          DO aux_local = 1, SIZE(recv_col_indices)
     956         496 :             aux_global = recv_col_indices(aux_local)
     957             : 
     958       40040 :             DO ia_local = 1, SIZE(recv_row_indices)
     959       39520 :                ia_global = recv_row_indices(ia_local)
     960             : 
     961       39520 :                iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
     962       39520 :                avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
     963             : 
     964       40016 :                mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = recv_ptr(ia_local, aux_local)
     965             :             END DO
     966             :          END DO
     967          24 :          CALL timestop(handle2)
     968             : 
     969          24 :          IF (ALLOCATED(recv_row_indices)) DEALLOCATE (recv_row_indices)
     970         168 :          IF (ALLOCATED(recv_col_indices)) DEALLOCATE (recv_col_indices)
     971             :       END DO
     972             : 
     973          24 :       DEALLOCATE (send_aux_indices, send_virt_indices)
     974             : 
     975          24 :       CALL timestop(handle)
     976             : 
     977          96 :    END SUBROUTINE redistribute_into_subgroups
     978             : 
     979           0 : END MODULE rpa_exchange

Generated by: LCOV version 1.15