Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Auxiliary routines needed for RPA-exchange
10 : !> given blacs_env to another
11 : !> \par History
12 : !> 09.2016 created [Vladimir Rybkin]
13 : !> 03.2019 Renamed [Frederick Stein]
14 : !> 03.2019 Moved Functions from rpa_ri_gpw.F [Frederick Stein]
15 : !> 04.2024 Added open-shell calculations, SOSEX [Frederick Stein]
16 : !> \author Vladimir Rybkin
17 : ! **************************************************************************************************
18 : MODULE rpa_exchange
19 : USE atomic_kind_types, ONLY: atomic_kind_type
20 : USE cell_types, ONLY: cell_type
21 : USE cp_blacs_env, ONLY: cp_blacs_env_type
22 : USE cp_control_types, ONLY: dft_control_type
23 : USE cp_dbcsr_api, ONLY: &
24 : dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_init_p, dbcsr_multiply, dbcsr_p_type, &
25 : dbcsr_release, dbcsr_set, dbcsr_type, dbcsr_type_no_symmetry
26 : USE cp_dbcsr_contrib, ONLY: dbcsr_trace
27 : USE cp_dbcsr_operations, ONLY: dbcsr_allocate_matrix_set
28 : USE cp_fm_basic_linalg, ONLY: cp_fm_column_scale
29 : USE cp_fm_diag, ONLY: choose_eigv_solver
30 : USE cp_fm_struct, ONLY: cp_fm_struct_create,&
31 : cp_fm_struct_p_type,&
32 : cp_fm_struct_release
33 : USE cp_fm_types, ONLY: cp_fm_create,&
34 : cp_fm_get_info,&
35 : cp_fm_release,&
36 : cp_fm_set_all,&
37 : cp_fm_to_fm,&
38 : cp_fm_to_fm_submat_general,&
39 : cp_fm_type
40 : USE group_dist_types, ONLY: create_group_dist,&
41 : get_group_dist,&
42 : group_dist_d1_type,&
43 : group_dist_proc,&
44 : maxsize,&
45 : release_group_dist
46 : USE hfx_admm_utils, ONLY: tddft_hfx_matrix
47 : USE hfx_types, ONLY: hfx_create,&
48 : hfx_release,&
49 : hfx_type
50 : USE input_constants, ONLY: rpa_exchange_axk,&
51 : rpa_exchange_none,&
52 : rpa_exchange_sosex
53 : USE input_section_types, ONLY: section_vals_get_subs_vals,&
54 : section_vals_type
55 : USE kinds, ONLY: dp,&
56 : int_8
57 : USE local_gemm_api, ONLY: LOCAL_GEMM_PU_GPU
58 : USE mathconstants, ONLY: sqrthalf
59 : USE message_passing, ONLY: mp_para_env_type,&
60 : mp_proc_null
61 : USE mp2_types, ONLY: mp2_type
62 : USE parallel_gemm_api, ONLY: parallel_gemm
63 : USE particle_types, ONLY: particle_type
64 : USE qs_environment_types, ONLY: get_qs_env,&
65 : qs_environment_type
66 : USE qs_kind_types, ONLY: qs_kind_type
67 : USE qs_subsys_types, ONLY: qs_subsys_get,&
68 : qs_subsys_type
69 : USE rpa_communication, ONLY: gamma_fm_to_dbcsr
70 : USE rpa_util, ONLY: calc_fm_mat_S_rpa,&
71 : remove_scaling_factor_rpa
72 : USE scf_control_types, ONLY: scf_control_type
73 : #include "./base/base_uses.f90"
74 :
75 : IMPLICIT NONE
76 :
77 : PRIVATE ! default accessibility: only the names in the PUBLIC statement below are exported
78 :
79 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rpa_exchange' ! name of this module as a string constant
80 :
81 : PUBLIC :: rpa_exchange_work_type, rpa_exchange_needed_mem ! work object type and memory estimator
82 :
83 : TYPE rpa_exchange_env_type ! state of the HFX-based implementation; filled by hfx_create_subgroup
84 : PRIVATE
85 : TYPE(qs_environment_type), POINTER :: qs_env => NULL() ! alias of the qs_env passed to create
86 : TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: mat_hfx => NULL() ! one per spin, created as copy of mat_munu
87 : TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: dbcsr_Gamma_munu_P => NULL() ! one per spin, copy of mat_munu set to 0
88 : TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: dbcsr_Gamma_inu_P ! one per spin, copy of mo_coeff_o set to 0
89 : ! Workaround GCC 8 (NOTE(review): presumably why mo_coeff_o/v are POINTER arrays - confirm)
90 : TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_o => NULL() ! taken over from qs_env%mp2_env%ri_rpa on create
91 : TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_v => NULL() ! taken over from qs_env%mp2_env%ri_rpa on create
92 : TYPE(dbcsr_type) :: work_ao ! created from the mat_munu template without symmetry
93 : TYPE(hfx_type), DIMENSION(:, :), POINTER :: x_data => NULL() ! built by hfx_create, freed by hfx_release
94 : TYPE(mp_para_env_type), POINTER :: para_env => NULL() ! alias of para_env_sub
95 : TYPE(section_vals_type), POINTER :: hfx_sections => NULL() ! input section DFT%XC%WF_CORRELATION%RI_RPA%HF
96 : LOGICAL :: my_recalc_hfx_integrals = .FALSE. ! set to .TRUE. by create
97 : REAL(KIND=dp) :: eps_filter = 0.0_dp ! copied from mp2_env%mp2_gpw%eps_filter
98 : TYPE(cp_fm_struct_p_type), DIMENSION(:), ALLOCATABLE :: struct_Gamma ! per spin: fm_mat_S struct, dims swapped
99 : CONTAINS
100 : PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: create => hfx_create_subgroup
101 : !PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: integrate => integrate_exchange
102 : PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: release => hfx_release_subgroup
103 : END TYPE
104 :
105 : TYPE dbcsr_matrix_p_set ! wrapper around an allocatable array of dbcsr matrices
106 : TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: matrix_set ! used for dbcsr_Gamma_3 in rpa_exchange_work_compute_hfx
107 : END TYPE
108 :
109 : TYPE rpa_exchange_work_type ! work object of the RPA exchange correction (create / compute / release)
110 : PRIVATE
111 : INTEGER :: exchange_correction = rpa_exchange_none ! copied from qs_env%mp2_env%ri_rpa%exchange_correction
112 : TYPE(rpa_exchange_env_type) :: exchange_env ! only created if ri_rpa%use_hfx_implementation is set
113 : INTEGER, DIMENSION(:), ALLOCATABLE :: homo, virtual, dimen_ia ! per spin; dimen_ia = homo*virtual
114 : TYPE(group_dist_d1_type) :: aux_func_dist = group_dist_d1_type() ! built from (ngroup, dimen_RI)
115 : INTEGER, DIMENSION(:), ALLOCATABLE :: aux2send ! (0:ngroup-1) local rows of fm_mat_S(1) per target group
116 : INTEGER :: dimen_RI = 0 ! copy of the dimen_RI argument of create
117 : INTEGER :: block_size = 0 ! ri_rpa%exchange_block_size, or dimen_RI if that is <= 0
118 : INTEGER :: color_sub = 0 ! para_env%mepos/para_env_sub%num_pe
119 : INTEGER :: ngroup = 0 ! para_env%num_pe/para_env_sub%num_pe
120 : TYPE(cp_fm_type) :: fm_mat_Q_tmp = cp_fm_type() ! same struct as fm_mat_Q; holds the copy that is diagonalized
121 : TYPE(cp_fm_type) :: fm_mat_R_half_gemm = cp_fm_type() ! same struct as fm_mat_Q_gemm; receives scaled fm_mat_U
122 : TYPE(cp_fm_type) :: fm_mat_U = cp_fm_type() ! same struct as fm_mat_Q; eigenvectors, later column-scaled
123 : TYPE(mp_para_env_type), POINTER :: para_env_sub => NULL() ! alias of the caller's subgroup, only nullified on release
124 : CONTAINS
125 : PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: create => rpa_exchange_work_create
126 : PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: compute => rpa_exchange_work_compute
127 : PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: release => rpa_exchange_work_release
128 : PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: redistribute_into_subgroups
129 : PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_fm => rpa_exchange_work_compute_fm
130 : PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_hfx => rpa_exchange_work_compute_hfx
131 : END TYPE
132 :
133 : CONTAINS
134 :
135 : ! **************************************************************************************************
136 : !> \brief Adds an estimate of the memory needed by the RPA exchange correction to two running totals
137 : !> \param mp2_env only ri_rpa%exchange_block_size is read (a value <= 0 means "not set")
138 : !> \param homo per-spin homo values (presumably the numbers of occupied orbitals)
139 : !> \param virtual per-spin numbers of virtual orbitals (same length as homo)
140 : !> \param dimen_RI number of auxiliary (RI) functions
141 : !> \param para_env only num_pe is read, to bound the block size if it is not set
142 : !> \param mem_per_rank incremented by the product-matrix size (presumably MiB: 8 bytes per value/1024**2)
143 : !> \param mem_per_repl incremented by the size of work arrays, Gamma copies and communication buffer
144 : ! **************************************************************************************************
145 128 : SUBROUTINE rpa_exchange_needed_mem(mp2_env, homo, virtual, dimen_RI, para_env, mem_per_rank, mem_per_repl)
146 : TYPE(mp2_type), INTENT(IN) :: mp2_env
147 : INTEGER, DIMENSION(:), INTENT(IN) :: homo, virtual
148 : INTEGER, INTENT(IN) :: dimen_RI
149 : TYPE(mp_para_env_type), INTENT(IN) :: para_env
150 : REAL(KIND=dp), INTENT(INOUT) :: mem_per_rank, mem_per_repl
151 :
152 : INTEGER :: block_size
153 :
154 : ! We need the block size and if it is unknown, an upper bound (ceiling of dimen_RI/num_pe, at least 1)
155 128 : block_size = mp2_env%ri_rpa%exchange_block_size
156 128 : IF (block_size <= 0) block_size = MAX(1, (dimen_RI + para_env%num_pe - 1)/para_env%num_pe)
157 : ! Products are formed in REAL(dp) (no default-INTEGER overflow); 1024.0_dp is real, 1024_dp would be an INTEGER of kind dp
158 : ! storage of product matrix (upper bound only as it depends on the square of the potential still unknown block size)
159 284 : mem_per_rank = mem_per_rank + REAL(MAXVAL(homo), KIND=dp)**2*REAL(block_size, KIND=dp)**2*8.0_dp/(1024.0_dp**2)
160 :
161 : ! work arrays R (2x) and U, copies of Gamma (2x), communication buffer (as expensive as Gamma)
162 : mem_per_repl = mem_per_repl + 3.0_dp*dimen_RI*dimen_RI*8.0_dp/(1024.0_dp**2) &
163 284 : + 3.0_dp*MAXVAL(REAL(homo, KIND=dp)*REAL(virtual, KIND=dp))*dimen_RI*8.0_dp/(1024.0_dp**2)
164 128 : END SUBROUTINE rpa_exchange_needed_mem
165 :
166 : ! **************************************************************************************************
167 : !> \brief Sets up the work object; if no exchange correction is requested only exchange_correction is set
168 : !> \param exchange_work work object to initialize
169 : !> \param qs_env source of mp2_env%ri_rpa (exchange_correction, exchange_block_size, use_hfx_implementation, mo_coeff_o/v)
170 : !> \param para_env_sub subgroup environment; stored as a pointer alias, not copied
171 : !> \param mat_munu its matrix is passed to the HFX environment (only if use_hfx_implementation is set)
172 : !> \param dimen_RI number of auxiliary functions to distribute over the groups
173 : !> \param fm_mat_S one matrix per spin; element 1 supplies para_env, local row indices and the blacs context
174 : !> \param fm_mat_Q its matrix_struct is the template of fm_mat_U and fm_mat_Q_tmp
175 : !> \param fm_mat_Q_gemm its matrix_struct is the template of fm_mat_R_half_gemm
176 : !> \param homo per-spin values, stored in the work object
177 : !> \param virtual per-spin values, stored in the work object (dimen_ia = homo*virtual is stored as well)
178 : ! **************************************************************************************************
179 132 : SUBROUTINE rpa_exchange_work_create(exchange_work, qs_env, para_env_sub, mat_munu, dimen_RI, &
180 132 : fm_mat_S, fm_mat_Q, fm_mat_Q_gemm, homo, virtual)
181 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
182 : TYPE(qs_environment_type), POINTER :: qs_env
183 : TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
184 : TYPE(dbcsr_p_type), INTENT(IN) :: mat_munu
185 : INTEGER, INTENT(IN) :: dimen_RI
186 : TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
187 : TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q, fm_mat_Q_gemm
188 : INTEGER, DIMENSION(SIZE(fm_mat_S)), INTENT(IN) :: homo, virtual
189 :
190 : INTEGER :: nspins, aux_global, aux_local, my_process_row, proc, ispin
191 132 : INTEGER, DIMENSION(:), POINTER :: row_indices, aux_distribution_fm
192 : TYPE(cp_blacs_env_type), POINTER :: context
193 :
194 132 : exchange_work%exchange_correction = qs_env%mp2_env%ri_rpa%exchange_correction
195 : ! Nothing else is set up without an exchange correction (ri_rpa%mo_coeff_o/v are then left untouched here)
196 132 : IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
197 :
198 : ASSOCIATE (para_env => fm_mat_S(1)%matrix_struct%para_env)
199 12 : exchange_work%para_env_sub => para_env_sub
200 12 : exchange_work%ngroup = para_env%num_pe/para_env_sub%num_pe
201 12 : exchange_work%color_sub = para_env%mepos/para_env_sub%num_pe
202 : END ASSOCIATE
203 : ! Count how many locally held rows of fm_mat_S(1) are assigned to each group by aux_func_dist
204 12 : CALL cp_fm_get_info(fm_mat_S(1), row_indices=row_indices, nrow_locals=aux_distribution_fm, context=context)
205 12 : CALL context%get(my_process_row=my_process_row)
206 :
207 12 : CALL create_group_dist(exchange_work%aux_func_dist, exchange_work%ngroup, dimen_RI)
208 36 : ALLOCATE (exchange_work%aux2send(0:exchange_work%ngroup - 1))
209 36 : exchange_work%aux2send = 0
210 499 : DO aux_local = 1, aux_distribution_fm(my_process_row)
211 487 : aux_global = row_indices(aux_local)
212 487 : proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
213 499 : exchange_work%aux2send(proc) = exchange_work%aux2send(proc) + 1
214 : END DO
215 :
216 12 : nspins = SIZE(fm_mat_S)
217 :
218 60 : ALLOCATE (exchange_work%homo(nspins), exchange_work%virtual(nspins), exchange_work%dimen_ia(nspins))
219 26 : exchange_work%homo(:) = homo
220 26 : exchange_work%virtual(:) = virtual
221 26 : exchange_work%dimen_ia(:) = homo*virtual
222 12 : exchange_work%dimen_RI = dimen_RI
223 : ! A block size <= 0 in the input means "not set": fall back to dimen_RI
224 12 : exchange_work%block_size = qs_env%mp2_env%ri_rpa%exchange_block_size
225 12 : IF (exchange_work%block_size <= 0) exchange_work%block_size = dimen_RI
226 :
227 12 : CALL cp_fm_create(exchange_work%fm_mat_U, fm_mat_Q%matrix_struct, name="fm_mat_U")
228 12 : CALL cp_fm_create(exchange_work%fm_mat_Q_tmp, fm_mat_Q%matrix_struct, name="fm_mat_Q_tmp")
229 12 : CALL cp_fm_create(exchange_work%fm_mat_R_half_gemm, fm_mat_Q_gemm%matrix_struct)
230 :
231 12 : IF (qs_env%mp2_env%ri_rpa%use_hfx_implementation) THEN
232 2 : CALL exchange_work%exchange_env%create(qs_env, mat_munu%matrix, para_env_sub, fm_mat_S)
233 : END IF
234 : ! hfx_create_subgroup takes over and nullifies ri_rpa%mo_coeff_o/v; otherwise they are released here
235 12 : IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_o)) THEN
236 22 : DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_o)
237 22 : CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_o(ispin))
238 : END DO
239 10 : DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_o)
240 : END IF
241 :
242 12 : IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_v)) THEN
243 22 : DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_v)
244 22 : CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_v(ispin))
245 : END DO
246 10 : DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_v)
247 : END IF
248 132 : END SUBROUTINE
249 :
250 : ! **************************************************************************************************
251 : !> \brief Initializes the HFX environment on a subgroup: x_data, per-spin work matrices and fm structs
252 : !> \param exchange_env environment to fill
253 : !> \param qs_env provides subsys, input and dft_control; its ri_rpa%mo_coeff_o/v are taken over and nullified
254 : !> \param mat_munu template (and copy source) of the AO work matrices
255 : !> \param para_env_sub subgroup environment passed to hfx_create and stored as alias
256 : !> \param fm_mat_S per-spin matrices; their structs with swapped global dimensions become struct_Gamma
257 : !> \author Vladimir Rybkin
258 : ! **************************************************************************************************
259 2 : SUBROUTINE hfx_create_subgroup(exchange_env, qs_env, mat_munu, para_env_sub, fm_mat_S)
260 : CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
261 : TYPE(dbcsr_type), INTENT(IN) :: mat_munu
262 : TYPE(qs_environment_type), POINTER :: qs_env
263 : TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
264 : TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
265 :
266 : CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_create_subgroup'
267 :
268 : INTEGER :: handle, nelectron_total, ispin, &
269 : number_of_aos, nspins, dimen_RI, dimen_ia
270 2 : TYPE(atomic_kind_type), DIMENSION(:), POINTER :: atomic_kind_set
271 : TYPE(cell_type), POINTER :: my_cell
272 : TYPE(dft_control_type), POINTER :: dft_control
273 2 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
274 2 : TYPE(qs_kind_type), DIMENSION(:), POINTER :: qs_kind_set
275 : TYPE(qs_subsys_type), POINTER :: subsys
276 : TYPE(scf_control_type), POINTER :: scf_control
277 : TYPE(section_vals_type), POINTER :: input
278 :
279 2 : CALL timeset(routineN, handle)
280 : ! Take ownership of the MO coefficient matrices; the ri_rpa pointers are nullified so they are not freed twice
281 2 : exchange_env%mo_coeff_o => qs_env%mp2_env%ri_rpa%mo_coeff_o
282 2 : exchange_env%mo_coeff_v => qs_env%mp2_env%ri_rpa%mo_coeff_v
283 2 : NULLIFY (qs_env%mp2_env%ri_rpa%mo_coeff_o, qs_env%mp2_env%ri_rpa%mo_coeff_v)
284 :
285 2 : nspins = SIZE(exchange_env%mo_coeff_o)
286 :
287 2 : exchange_env%qs_env => qs_env
288 2 : exchange_env%para_env => para_env_sub
289 2 : exchange_env%eps_filter = qs_env%mp2_env%mp2_gpw%eps_filter
290 :
291 2 : NULLIFY (my_cell, atomic_kind_set, particle_set, dft_control, qs_kind_set, scf_control)
292 : ! NOTE(review): scf_control is fetched below but not used anywhere else in this routine
293 : CALL get_qs_env(qs_env, &
294 : subsys=subsys, &
295 : input=input, &
296 : scf_control=scf_control, &
297 2 : nelectron_total=nelectron_total)
298 :
299 : CALL qs_subsys_get(subsys, &
300 : cell=my_cell, &
301 : atomic_kind_set=atomic_kind_set, &
302 : qs_kind_set=qs_kind_set, &
303 2 : particle_set=particle_set)
304 :
305 2 : exchange_env%hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%WF_CORRELATION%RI_RPA%HF")
306 2 : CALL get_qs_env(qs_env, dft_control=dft_control)
307 :
308 : ! Create x_data on the subgroup from the RI_RPA%HF section and the sets retrieved above
309 : CALL hfx_create(exchange_env%x_data, para_env_sub, exchange_env%hfx_sections, atomic_kind_set, &
310 : qs_kind_set, particle_set, dft_control, my_cell, orb_basis='ORB', &
311 2 : nelectron_total=nelectron_total)
312 :
313 2 : exchange_env%my_recalc_hfx_integrals = .TRUE.
314 : ! Per spin: a non-symmetric copy of mat_munu
315 2 : CALL dbcsr_allocate_matrix_set(exchange_env%mat_hfx, nspins)
316 4 : DO ispin = 1, nspins
317 2 : ALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
318 2 : CALL dbcsr_init_p(exchange_env%mat_hfx(ispin)%matrix)
319 : CALL dbcsr_create(exchange_env%mat_hfx(ispin)%matrix, template=mat_munu, &
320 2 : matrix_type=dbcsr_type_no_symmetry)
321 4 : CALL dbcsr_copy(exchange_env%mat_hfx(ispin)%matrix, mat_munu)
322 : END DO
323 : ! NOTE(review): number_of_aos is queried here but not used anywhere else in this routine
324 2 : CALL dbcsr_get_info(mat_munu, nfullcols_total=number_of_aos)
325 :
326 : CALL dbcsr_create(exchange_env%work_ao, template=mat_munu, &
327 2 : matrix_type=dbcsr_type_no_symmetry)
328 : ! Per spin: copies of mat_munu and of mo_coeff_o, both set to zero afterwards
329 8 : ALLOCATE (exchange_env%dbcsr_Gamma_inu_P(nspins))
330 2 : CALL dbcsr_allocate_matrix_set(exchange_env%dbcsr_Gamma_munu_P, nspins)
331 4 : DO ispin = 1, nspins
332 2 : ALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
333 : CALL dbcsr_create(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, template=mat_munu, &
334 2 : matrix_type=dbcsr_type_no_symmetry)
335 2 : CALL dbcsr_copy(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, mat_munu)
336 2 : CALL dbcsr_set(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, 0.0_dp)
337 :
338 2 : CALL dbcsr_create(exchange_env%dbcsr_Gamma_inu_P(ispin), template=exchange_env%mo_coeff_o(ispin))
339 2 : CALL dbcsr_copy(exchange_env%dbcsr_Gamma_inu_P(ispin), exchange_env%mo_coeff_o(ispin))
340 4 : CALL dbcsr_set(exchange_env%dbcsr_Gamma_inu_P(ispin), 0.0_dp)
341 : END DO
342 : ! Per spin: fm struct of fm_mat_S with the global row and column dimensions exchanged
343 8 : ALLOCATE (exchange_env%struct_Gamma(nspins))
344 4 : DO ispin = 1, nspins
345 2 : CALL cp_fm_get_info(fm_mat_S(ispin), nrow_global=dimen_RI, ncol_global=dimen_ia)
346 : CALL cp_fm_struct_create(exchange_env%struct_Gamma(ispin)%struct, template_fmstruct=fm_mat_S(ispin)%matrix_struct, &
347 4 : nrow_global=dimen_ia, ncol_global=dimen_RI)
348 : END DO
349 :
350 2 : CALL timestop(handle)
351 :
352 4 : END SUBROUTINE hfx_create_subgroup
353 :
354 : ! **************************************************************************************************
355 : !> \brief Frees the components of the work object and then releases the embedded exchange environment
356 : !> \param exchange_work work object to release; allocatable arrays are only deallocated if allocated
357 : ! **************************************************************************************************
358 132 : SUBROUTINE rpa_exchange_work_release(exchange_work)
359 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
360 :
361 132 : IF (ALLOCATED(exchange_work%homo)) DEALLOCATE (exchange_work%homo)
362 132 : IF (ALLOCATED(exchange_work%virtual)) DEALLOCATE (exchange_work%virtual)
363 132 : IF (ALLOCATED(exchange_work%dimen_ia)) DEALLOCATE (exchange_work%dimen_ia)
364 132 : NULLIFY (exchange_work%para_env_sub) ! alias only: the subgroup itself is not released here
365 132 : CALL release_group_dist(exchange_work%aux_func_dist)
366 132 : IF (ALLOCATED(exchange_work%aux2send)) DEALLOCATE (exchange_work%aux2send)
367 132 : CALL cp_fm_release(exchange_work%fm_mat_Q_tmp)
368 132 : CALL cp_fm_release(exchange_work%fm_mat_U)
369 132 : CALL cp_fm_release(exchange_work%fm_mat_R_half_gemm)
370 :
371 132 : CALL exchange_work%exchange_env%release()
372 132 : END SUBROUTINE
373 :
374 : ! **************************************************************************************************
375 : !> \brief Releases what hfx_create_subgroup set up: x_data, the dbcsr work matrices and the fm structs
376 : !> \param exchange_env environment to release; the qs_env alias is not reset by this routine
377 : ! **************************************************************************************************
378 132 : SUBROUTINE hfx_release_subgroup(exchange_env)
379 : CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
380 :
381 : INTEGER :: ispin
382 : ! Aliases only, nothing to free
383 132 : NULLIFY (exchange_env%para_env, exchange_env%hfx_sections)
384 :
385 132 : IF (ASSOCIATED(exchange_env%x_data)) THEN
386 2 : CALL hfx_release(exchange_env%x_data)
387 2 : NULLIFY (exchange_env%x_data)
388 : END IF
389 :
390 132 : CALL dbcsr_release(exchange_env%work_ao)
391 : ! Guard tests dbcsr_Gamma_munu_P, loop bound uses mat_hfx: both get nspins entries in hfx_create_subgroup
392 132 : IF (ASSOCIATED(exchange_env%dbcsr_Gamma_munu_P)) THEN
393 4 : DO ispin = 1, SIZE(exchange_env%mat_hfx, 1)
394 2 : CALL dbcsr_release(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
395 2 : CALL dbcsr_release(exchange_env%mat_hfx(ispin)%matrix)
396 2 : CALL dbcsr_release(exchange_env%dbcsr_Gamma_inu_P(ispin))
397 2 : CALL dbcsr_release(exchange_env%mo_coeff_o(ispin))
398 2 : CALL dbcsr_release(exchange_env%mo_coeff_v(ispin))
399 2 : DEALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
400 4 : DEALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
401 : END DO
402 2 : DEALLOCATE (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
403 2 : DEALLOCATE (exchange_env%dbcsr_Gamma_inu_P, exchange_env%mo_coeff_o, exchange_env%mo_coeff_v)
404 2 : NULLIFY (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
405 : END IF
406 132 : IF (ALLOCATED(exchange_env%struct_Gamma)) THEN
407 4 : DO ispin = 1, SIZE(exchange_env%struct_Gamma)
408 4 : CALL cp_fm_struct_release(exchange_env%struct_Gamma(ispin)%struct)
409 : END DO
410 2 : DEALLOCATE (exchange_env%struct_Gamma)
411 : END IF
412 132 : END SUBROUTINE hfx_release_subgroup
413 :
414 : ! **************************************************************************************************
415 : !> \brief Main driver for RPA-exchange energies: diagonalizes Q, builds R_half and calls compute_hfx or compute_fm
416 : !> \param exchange_work work object set up by create; nothing is done if its correction is rpa_exchange_none
417 : !> \param fm_mat_Q matrix that is copied and diagonalized; its global column count is used as dimen_RI
418 : !> \param eig passed on unchanged to compute_hfx/compute_fm
419 : !> \param fm_mat_S per-spin matrices passed on to compute_hfx/compute_fm
420 : !> \param omega passed on unchanged to compute_hfx/compute_fm
421 : !> \param e_exchange_corr exchange energy correction for a quadrature point (reset to zero by both implementations)
422 : !> \param mp2_env ri_rpa%use_hfx_implementation selects the implementation; passed on to compute_fm
423 : !> \author Vladimir Rybkin, 07/2016
424 : ! **************************************************************************************************
425 12 : SUBROUTINE rpa_exchange_work_compute(exchange_work, fm_mat_Q, eig, fm_mat_S, omega, &
426 : e_exchange_corr, mp2_env)
427 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
428 : TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q
429 : REAL(KIND=dp), DIMENSION(:, :), INTENT(IN) :: eig
430 : TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT) :: fm_mat_S
431 : REAL(KIND=dp), INTENT(IN) :: omega
432 : REAL(KIND=dp), INTENT(INOUT) :: e_exchange_corr
433 : TYPE(mp2_type), INTENT(INOUT) :: mp2_env
434 :
435 : CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute'
436 : REAL(KIND=dp), PARAMETER :: thresh = 0.0000001_dp
437 :
438 : INTEGER :: handle, nspins, dimen_RI, iiB
439 12 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: eigenval
440 :
441 12 : IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
442 :
443 12 : CALL timeset(routineN, handle)
444 :
445 12 : CALL cp_fm_get_info(fm_mat_Q, ncol_global=dimen_RI)
446 :
447 12 : nspins = SIZE(fm_mat_S)
448 :
449 : ! Eigenvalues
450 36 : ALLOCATE (eigenval(dimen_RI))
451 986 : eigenval = 0.0_dp
452 :
453 12 : CALL cp_fm_set_all(matrix=exchange_work%fm_mat_Q_tmp, alpha=0.0_dp)
454 12 : CALL cp_fm_set_all(matrix=exchange_work%fm_mat_U, alpha=0.0_dp)
455 :
456 : ! Copy Q to Q_tmp
457 12 : CALL cp_fm_to_fm(fm_mat_Q, exchange_work%fm_mat_Q_tmp)
458 : ! Diagonalize Q
459 12 : CALL choose_eigv_solver(exchange_work%fm_mat_Q_tmp, exchange_work%fm_mat_U, eigenval)
460 :
461 : ! Calculate diagonal matrix for R_half: each eigenvalue x is replaced by SQRT(f(x)) with
462 : ! AXK: f(x) = LOG(1+x)/x**2 - 1/(x*(1+x)), SOSEX: f(x) = -LOG(1+x)/x**2 + 1/x
463 : ! Manipulate eigenvalues to get diagonal matrix; for ABS(x) < thresh sqrthalf is used (both f(x) tend to 1/2 for x -> 0)
464 12 : IF (exchange_work%exchange_correction == rpa_exchange_axk) THEN
465 818 : DO iib = 1, dimen_RI
466 818 : IF (ABS(eigenval(iib)) .GE. thresh) THEN
467 : eigenval(iib) = &
468 : SQRT((1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
469 718 : - 1.0_dp/(eigenval(iib)*(eigenval(iib) + 1.0_dp)))
470 : ELSE
471 90 : eigenval(iib) = sqrthalf
472 : END IF
473 : END DO
474 2 : ELSE IF (exchange_work%exchange_correction == rpa_exchange_sosex) THEN
475 168 : DO iib = 1, dimen_RI
476 168 : IF (ABS(eigenval(iib)) .GE. thresh) THEN
477 : eigenval(iib) = &
478 : SQRT(-(1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
479 144 : + 1.0_dp/eigenval(iib))
480 : ELSE
481 22 : eigenval(iib) = sqrthalf
482 : END IF
483 : END DO
484 : ELSE
485 0 : CPABORT("Unknown RPA exchange correction")
486 : END IF
487 :
488 : ! fm_mat_U now contains some sqrt of the required matrix-valued function (eigenvectors scaled column-wise)
489 12 : CALL cp_fm_column_scale(exchange_work%fm_mat_U, eigenval)
490 :
491 : ! Release memory
492 12 : DEALLOCATE (eigenval)
493 :
494 : ! Redistribute fm_mat_U for "rectangular" multiplication: ia*P P*P
495 12 : CALL cp_fm_set_all(matrix=exchange_work%fm_mat_R_half_gemm, alpha=0.0_dp)
496 :
497 : CALL cp_fm_to_fm_submat_general(exchange_work%fm_mat_U, exchange_work%fm_mat_R_half_gemm, dimen_RI, &
498 12 : dimen_RI, 1, 1, 1, 1, exchange_work%fm_mat_U%matrix_struct%context)
499 : ! Dispatch to the HFX-based or the full-matrix implementation
500 12 : IF (mp2_env%ri_rpa%use_hfx_implementation) THEN
501 2 : CALL exchange_work%compute_hfx(fm_mat_S, eig, omega, e_exchange_corr)
502 : ELSE
503 10 : CALL exchange_work%compute_fm(fm_mat_S, eig, omega, e_exchange_corr, mp2_env)
504 : END IF
505 :
506 12 : CALL timestop(handle)
507 :
508 12 : END SUBROUTINE rpa_exchange_work_compute
509 :
510 : ! **************************************************************************************************
511 : !> \brief Full-matrix implementation of the exchange correction using blocked local GEMMs within subgroups
512 : !> \param exchange_work work object; fm_mat_R_half_gemm must have been filled by rpa_exchange_work_compute
513 : !> \param fm_mat_S per-spin matrices; rescaled and restored through calc_fm_mat_S_rpa/remove_scaling_factor_rpa
514 : !> \param eig column ispin is passed to the scaling routines
515 : !> \param omega passed to the scaling routines
516 : !> \param e_exchange_corr exchange energy correction for a quadrature point (reset to zero, then accumulated)
517 : !> \param mp2_env provides local_gemm_ctx, which is created at entry and destroyed at exit
518 : !> \author Frederick Stein, May-June 2024
519 : ! **************************************************************************************************
520 10 : SUBROUTINE rpa_exchange_work_compute_fm(exchange_work, fm_mat_S, eig, omega, &
521 : e_exchange_corr, mp2_env)
522 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
523 : TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
524 : REAL(KIND=dp), DIMENSION(:, :), INTENT(IN) :: eig
525 : REAL(KIND=dp), INTENT(IN) :: omega
526 : REAL(KIND=dp), INTENT(INOUT) :: e_exchange_corr
527 : TYPE(mp2_type), INTENT(INOUT) :: mp2_env
528 :
529 : CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_fm'
530 :
531 : INTEGER :: handle, ispin, nspins, P, Q, L_size_Gamma, hom, virt, i, &
532 : send_proc, recv_proc, recv_size, max_aux_size, proc_shift, dimen_ia, &
533 : block_size, P_start, P_end, P_size, Q_start, Q_size, Q_end, handle2, my_aux_size, my_virt
534 10 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET :: mat_Gamma_3_3D
535 10 : REAL(KIND=dp), POINTER, DIMENSION(:), CONTIGUOUS :: mat_Gamma_3_1D
536 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: mat_Gamma_3_2D
537 10 : REAL(KIND=dp), ALLOCATABLE, TARGET, DIMENSION(:) :: recv_buffer_1D
538 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: recv_buffer_2D
539 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :, :), CONTIGUOUS :: recv_buffer_3D
540 10 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: mat_B_iaP
541 10 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET :: product_matrix_1D
542 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: product_matrix_2D
543 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :, :, :), CONTIGUOUS :: product_matrix_4D
544 : TYPE(cp_fm_type) :: fm_mat_Gamma_3
545 : TYPE(mp_para_env_type), POINTER :: para_env
546 10 : TYPE(group_dist_d1_type) :: virt_dist
547 :
548 10 : CALL timeset(routineN, handle)
549 :
550 10 : nspins = SIZE(fm_mat_S)
551 :
552 10 : CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, sizes=my_aux_size)
553 : ! NOTE(review): the sum below is local to this rank; no reduction over ranks is done in this routine
554 10 : e_exchange_corr = 0.0_dp
555 10 : max_aux_size = maxsize(exchange_work%aux_func_dist)
556 :
557 : ! local_gemm_ctx has a very large footprint the first time this routine is
558 : ! called.
559 10 : CALL mp2_env%local_gemm_ctx%create(LOCAL_GEMM_PU_GPU)
560 10 : CALL mp2_env%local_gemm_ctx%set_op_threshold_gpu(128*128*128*2)
561 :
562 22 : DO ispin = 1, nspins
563 12 : hom = exchange_work%homo(ispin)
564 12 : virt = exchange_work%virtual(ispin)
565 12 : dimen_ia = hom*virt
566 12 : IF (hom < 1 .OR. virt < 1) CYCLE
567 :
568 12 : CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
569 :
570 12 : CALL cp_fm_create(fm_mat_Gamma_3, fm_mat_S(ispin)%matrix_struct)
571 12 : CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
572 :
573 : ! Update G with a new value of Omega: in practice, it is G*S
574 :
575 : ! Scale fm_work_iaP
576 : CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
577 12 : hom, omega, 0.0_dp)
578 :
579 : ! Calculate Gamma_3 = R_half^T * (scaled S), i.e. a (dimen_RI x dimen_ia) matrix
580 : CALL parallel_gemm(transa="T", transb="N", m=exchange_work%dimen_RI, n=dimen_ia, k=exchange_work%dimen_RI, alpha=1.0_dp, &
581 : matrix_a=exchange_work%fm_mat_R_half_gemm, matrix_b=fm_mat_S(ispin), beta=0.0_dp, &
582 12 : matrix_c=fm_mat_Gamma_3)
583 :
584 12 : CALL create_group_dist(virt_dist, exchange_work%para_env_sub%num_pe, virt)
585 :
586 : ! Remove extra factor from S after the multiplication (to return to the original matrix)
587 12 : CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
588 :
589 12 : CALL exchange_work%redistribute_into_subgroups(fm_mat_Gamma_3, mat_Gamma_3_3D, ispin, virt_dist)
590 12 : CALL cp_fm_release(fm_mat_Gamma_3)
591 :
592 : ! We need only the pure matrix (second removal; undone again by calc_fm_mat_S_rpa after the redistribution)
593 12 : CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
594 :
595 : ! Reorder matrix from (P, i*a) -> (a, i, P) with P being distributed within subgroups
596 12 : CALL exchange_work%redistribute_into_subgroups(fm_mat_S(ispin), mat_B_iaP, ispin, virt_dist)
597 :
598 : ! Return to the original tensor
599 12 : CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), hom, omega, 0.0_dp)
600 :
601 12 : L_size_Gamma = SIZE(mat_Gamma_3_3D, 3)
602 12 : my_virt = SIZE(mat_Gamma_3_3D, 1)
603 12 : block_size = exchange_work%block_size
604 : ! View the first my_aux_size slices of Gamma_3 as a (my_virt, hom*my_aux_size) matrix for the GEMM
605 12 : mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size) => mat_Gamma_3_3D(:, :, 1:my_aux_size)
606 12 : mat_Gamma_3_2D(1:my_virt, 1:hom*my_aux_size) => mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size)
607 : ! NOTE(review): recv_buffer_1D is allocated with virt rows but only remapped/used with my_virt rows - confirm intent
608 0 : ALLOCATE (product_matrix_1D(INT(hom*MIN(block_size, L_size_gamma), KIND=int_8)* &
609 36 : INT(hom*MIN(block_size, max_aux_size), KIND=int_8)))
610 36 : ALLOCATE (recv_buffer_1D(INT(virt, KIND=int_8)*hom*max_aux_size))
611 12 : recv_buffer_2D(1:my_virt, 1:hom*max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
612 12 : recv_buffer_3D(1:my_virt, 1:hom, 1:max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
613 36 : DO proc_shift = 0, para_env%num_pe - 1, exchange_work%para_env_sub%num_pe
614 24 : send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
615 24 : recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
616 :
617 24 : CALL get_group_dist(exchange_work%aux_func_dist, recv_proc/exchange_work%para_env_sub%num_pe, sizes=recv_size)
618 : ! Nothing to receive from a group without auxiliary functions; the send is still posted
619 24 : IF (recv_size == 0) recv_proc = mp_proc_null
620 :
621 24 : CALL para_env%sendrecv(mat_B_iaP, send_proc, recv_buffer_3D(:, :, 1:recv_size), recv_proc)
622 :
623 24 : IF (recv_size == 0) CYCLE
624 : ! Blocked contraction over the local (P) and the received (Q) auxiliary indices
625 1038 : DO P_start = 1, L_size_Gamma, block_size
626 1002 : P_end = MIN(L_size_Gamma, P_start + block_size - 1)
627 1002 : P_size = P_end - P_start + 1
628 43875 : DO Q_start = 1, recv_size, block_size
629 42849 : Q_end = MIN(recv_size, Q_start + block_size - 1)
630 42849 : Q_size = Q_end - Q_start + 1
631 :
632 : ! Reassign product_matrix pointers to enforce contiguity of target array
633 : product_matrix_2D(1:hom*P_size, 1:hom*Q_size) => &
634 42849 : product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
635 : product_matrix_4D(1:hom, 1:P_size, 1:hom, 1:Q_size) => &
636 42849 : product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
637 :
638 42849 : CALL timeset(routineN//"_gemm", handle2)
639 : CALL mp2_env%local_gemm_ctx%gemm("T", "N", hom*P_size, hom*Q_size, my_virt, 1.0_dp, &
640 : mat_Gamma_3_2D(:, hom*(P_start - 1) + 1:hom*P_end), my_virt, &
641 : recv_buffer_2D(:, hom*(Q_start - 1) + 1:hom*Q_end), my_virt, &
642 42849 : 0.0_dp, product_matrix_2D, hom*P_size)
643 42849 : CALL timestop(handle2)
644 : ! Accumulate SUM over i, j of M(i, P, j, Q)*M(j, P, i, Q) for every (P, Q) pair of the block
645 42849 : CALL timeset(routineN//"_energy", handle2)
646 : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(P_size, Q_size, hom, product_matrix_4D) &
647 42849 : !$OMP COLLAPSE(3) REDUCTION(+: e_exchange_corr) PRIVATE(P, Q, i)
648 : DO P = 1, P_size
649 : DO Q = 1, Q_size
650 : DO i = 1, hom
651 : e_exchange_corr = e_exchange_corr + DOT_PRODUCT(product_matrix_4D(i, P, :, Q), product_matrix_4D(:, P, i, Q))
652 : END DO
653 : END DO
654 : END DO
655 86700 : CALL timestop(handle2)
656 : END DO
657 : END DO
658 : END DO
659 : ! Free the per-spin buffers before the next spin
660 12 : CALL release_group_dist(virt_dist)
661 12 : IF (ALLOCATED(mat_B_iaP)) DEALLOCATE (mat_B_iaP)
662 12 : IF (ALLOCATED(mat_Gamma_3_3D)) DEALLOCATE (mat_Gamma_3_3D)
663 12 : IF (ALLOCATED(product_matrix_1D)) DEALLOCATE (product_matrix_1D)
664 58 : IF (ALLOCATED(recv_buffer_1D)) DEALLOCATE (recv_buffer_1D)
665 : END DO
666 :
667 10 : CALL mp2_env%local_gemm_ctx%destroy()
668 : ! Spin prefactor: 2 for two spin channels, 4 for a single (closed-shell) channel
669 10 : IF (nspins == 2) e_exchange_corr = e_exchange_corr*2.0_dp
670 10 : IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
671 :
672 10 : CALL timestop(handle)
673 :
674 20 : END SUBROUTINE rpa_exchange_work_compute_fm
675 :
! **************************************************************************************************
!> \brief Contract RPA-exchange density matrix with HF exchange integrals and evaluate the correction
!> \param exchange_work work environment; its exchange_env component supplies the MO coefficient
!>                      matrices, the dbcsr work matrices and the HFX data used here. The flag
!>                      exchange_env%my_recalc_hfx_integrals is reset to .FALSE. after the first
!>                      call of tddft_hfx_matrix
!> \param fm_mat_S one full matrix per spin channel; each is scaled with calc_fm_mat_S_rpa for the
!>                 multiplication and restored afterwards with remove_scaling_factor_rpa
!> \param eig values handed column-wise (second index = spin) to the scaling routines
!>            (presumably the orbital eigenvalues - confirm against the caller)
!> \param omega value of Omega handed to the scaling routines
!> \param e_exchange_corr on exit: accumulated exchange correction, multiplied by 4 if there is
!>                        only a single spin channel
!> \author Vladimir Rybkin, 08/2016
! **************************************************************************************************
   SUBROUTINE rpa_exchange_work_compute_hfx(exchange_work, fm_mat_S, eig, omega, e_exchange_corr)
      CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
      TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT)      :: fm_mat_S
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
      REAL(KIND=dp), INTENT(IN)                          :: omega
      REAL(KIND=dp), INTENT(OUT)                         :: e_exchange_corr

      CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_hfx'

      INTEGER :: handle, ispin, my_aux_start, my_aux_end, &
                 my_aux_size, nspins, L_counter, dimen_ia, hom, virt
      REAL(KIND=dp) :: e_exchange_P
      TYPE(dbcsr_matrix_p_set), DIMENSION(:), ALLOCATABLE :: dbcsr_Gamma_3
      TYPE(cp_fm_type) :: fm_mat_Gamma_3
      TYPE(mp_para_env_type), POINTER :: para_env

      CALL timeset(routineN, handle)

      e_exchange_corr = 0.0_dp

      ! One matrix in fm_mat_S per spin channel
      nspins = SIZE(fm_mat_S)

      ! Range of auxiliary functions assigned to the subgroup of this process
      CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)

      ! Step 1: build Gamma_3 for every spin channel and store it as a set of dbcsr matrices
      ALLOCATE (dbcsr_Gamma_3(nspins))
      DO ispin = 1, nspins
         hom = exchange_work%homo(ispin)
         virt = exchange_work%virtual(ispin)
         dimen_ia = hom*virt
         ! NOTE(review): a spin channel skipped here never gets its matrix_set filled by
         ! gamma_fm_to_dbcsr, while the loops below access matrix_set for every spin channel.
         ! Confirm that callers never reach this routine with an empty spin channel.
         IF (hom < 1 .OR. virt < 1) CYCLE

         CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)

         ! NOTE(review): fm_mat_Gamma_3 is not released in this routine; presumably it is
         ! consumed by gamma_fm_to_dbcsr - confirm.
         CALL cp_fm_create(fm_mat_Gamma_3, exchange_work%exchange_env%struct_Gamma(ispin)%struct)
         CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)

         ! Update G with a new value of Omega: in practice, it is G*S

         ! Scale fm_mat_S in place
         CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
                                hom, omega, 0.0_dp)

         ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2)
         CALL parallel_gemm(transa="T", transb="N", m=dimen_ia, n=exchange_work%dimen_RI, &
                            k=exchange_work%dimen_RI, alpha=1.0_dp, &
                            matrix_a=fm_mat_S(ispin), matrix_b=exchange_work%fm_mat_R_half_gemm, beta=0.0_dp, &
                            matrix_c=fm_mat_Gamma_3)

         ! Remove extra factor from S after the multiplication (to return to the original matrix)
         CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)

         ! Copy Gamma_ia_P^3 to dbcsr matrix set
         CALL gamma_fm_to_dbcsr(fm_mat_Gamma_3, dbcsr_Gamma_3(ispin)%matrix_set, &
                                para_env, exchange_work%para_env_sub, hom, virt, &
                                exchange_work%exchange_env%mo_coeff_o(ispin), &
                                exchange_work%ngroup, my_aux_start, my_aux_end, my_aux_size)
      END DO

      ! Step 2: for every local auxiliary function, contract with the HF exchange integrals
      DO L_counter = 1, my_aux_size
         DO ispin = 1, nspins
            ! Do dbcsr multiplication: transform the virtual index
            CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mo_coeff_v(ispin), &
                                dbcsr_Gamma_3(ispin)%matrix_set(L_counter), &
                                0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
                                filter_eps=exchange_work%exchange_env%eps_filter)

            ! The matrix of this auxiliary function is not needed anymore
            CALL dbcsr_release(dbcsr_Gamma_3(ispin)%matrix_set(L_counter))

            ! Do dbcsr multiplication: transform the occupied index.
            ! The two products below add up to 0.5*(A*B^T + B*A^T), i.e. the symmetrized result.
            CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
                                exchange_work%exchange_env%mo_coeff_o(ispin), &
                                0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
                                filter_eps=exchange_work%exchange_env%eps_filter)
            CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%mo_coeff_o(ispin), &
                                exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
                                1.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
                                filter_eps=exchange_work%exchange_env%eps_filter)

            ! Start from a zeroed result matrix before calling tddft_hfx_matrix
            CALL dbcsr_set(exchange_work%exchange_env%mat_hfx(ispin)%matrix, 0.0_dp)
         END DO

         CALL tddft_hfx_matrix(exchange_work%exchange_env%mat_hfx, exchange_work%exchange_env%dbcsr_Gamma_munu_P, &
                               exchange_work%exchange_env%qs_env, .FALSE., &
                               exchange_work%exchange_env%my_recalc_hfx_integrals, &
                               exchange_work%exchange_env%hfx_sections, exchange_work%exchange_env%x_data, &
                               exchange_work%exchange_env%para_env)

         ! The recalculation flag is only passed as set on the first call
         exchange_work%exchange_env%my_recalc_hfx_integrals = .FALSE.
         DO ispin = 1, nspins
            ! Energy contribution of this auxiliary function: -Tr(mat_hfx * Gamma_munu_P^T)
            CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mat_hfx(ispin)%matrix, &
                                exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
                                0.0_dp, exchange_work%exchange_env%work_ao, filter_eps=exchange_work%exchange_env%eps_filter)
            CALL dbcsr_trace(exchange_work%exchange_env%work_ao, e_exchange_P)
            e_exchange_corr = e_exchange_corr - e_exchange_P
         END DO
      END DO

      ! Only the result of a single spin channel (nspins == 1) is rescaled; with two spin
      ! channels the accumulated sum is returned unchanged (this replaces a former no-op
      ! self-assignment for nspins == 2).
      ! NOTE(review): the unit prefactor for nspins == 2 presumably relies on the spin
      ! handling inside tddft_hfx_matrix - confirm.
      IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp

      CALL timestop(handle)

   END SUBROUTINE rpa_exchange_work_compute_hfx
788 :
! **************************************************************************************************
!> \brief Redistribute a full matrix, whose rows are auxiliary functions and whose columns are
!>        combined occupied-virtual indices ia = (i-1)*virtual + a, into a local three-index
!>        array mat(a, i, P) holding the virtual range of this process (from virt_dist) and the
!>        auxiliary range of its subgroup (from exchange_work%aux_func_dist)
!> \param exchange_work work environment; homo, virtual, aux_func_dist, aux2send, ngroup,
!>                      color_sub and para_env_sub are read
!> \param fm_mat full matrix to be redistributed (not modified)
!> \param mat on exit: array of shape (my_virt_size, homo(ispin), my_aux_size); it has zero size
!>            if the spin channel has no occupied or no virtual orbitals
!> \param ispin spin channel selecting the entries of homo and virtual
!> \param virt_dist distribution of the virtual index among the processes of para_env_sub
! **************************************************************************************************
   SUBROUTINE redistribute_into_subgroups(exchange_work, fm_mat, mat, ispin, virt_dist)
      CLASS(rpa_exchange_work_type), INTENT(IN)          :: exchange_work
      TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: mat
      INTEGER, INTENT(IN)                                :: ispin
      TYPE(group_dist_d1_type), INTENT(IN)               :: virt_dist

      CHARACTER(LEN=*), PARAMETER :: routineN = 'redistribute_into_subgroups'

      INTEGER :: aux_counter, aux_global, aux_local, aux_proc, avirt, dimen_RI, handle, handle2, &
         ia_global, ia_local, iocc, max_number_recv, max_number_send, my_aux_end, my_aux_size, &
         my_aux_start, my_process_column, my_process_row, my_virt_end, my_virt_size, &
         my_virt_start, proc, proc_shift, recv_proc, send_proc, virt_counter, virt_proc, group_size
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: data2send, recv_col_indices, &
                                                            recv_row_indices, send_aux_indices, &
                                                            send_virt_indices, virt2send
      INTEGER, DIMENSION(2)                              :: recv_shape
      INTEGER, DIMENSION(:), POINTER                     :: aux_distribution_fm, col_indices, &
                                                            ia_distribution_fm, row_indices
      INTEGER, DIMENSION(:, :), POINTER                  :: mpi2blacs
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: recv_buffer, send_buffer
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: recv_ptr, send_ptr
      TYPE(cp_blacs_env_type), POINTER                   :: context
      TYPE(mp_para_env_type), POINTER                    :: para_env

      CALL timeset(routineN, handle)

      ! Rows of fm_mat carry the auxiliary index, columns the combined ia index
      CALL cp_fm_get_info(matrix=fm_mat, &
                          nrow_locals=aux_distribution_fm, &
                          col_indices=col_indices, &
                          row_indices=row_indices, &
                          ncol_locals=ia_distribution_fm, &
                          context=context, &
                          nrow_global=dimen_RI, &
                          para_env=para_env)

      ! Nothing to redistribute without occupied or virtual orbitals: return an empty array
      ! with the same index order (virtual, occupied, auxiliary) as in the regular case
      IF (exchange_work%homo(ispin) <= 0 .OR. exchange_work%virtual(ispin) <= 0) THEN
         CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, &
                             my_aux_start, my_aux_end, my_aux_size)
         CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, &
                             my_virt_start, my_virt_end, my_virt_size)
         ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))
         CALL timestop(handle)
         RETURN
      END IF

      group_size = exchange_work%para_env_sub%num_pe

      CALL timeset(routineN//"_prep", handle2)
      CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
      CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)
      CALL context%get(my_process_column=my_process_column, my_process_row=my_process_row, mpi2blacs=mpi2blacs)

      ! Determine the number of columns to send to each process of a subgroup
      ALLOCATE (send_aux_indices(MAXVAL(exchange_work%aux2send)))
      ALLOCATE (virt2send(0:group_size - 1))
      virt2send = 0
      DO ia_local = 1, ia_distribution_fm(my_process_column)
         ia_global = col_indices(ia_local)
         avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
         proc = group_dist_proc(virt_dist, avirt)
         virt2send(proc) = virt2send(proc) + 1
      END DO

      ! Number of matrix elements for each target process; the rank of a target process is
      ! decomposed below as (subgroup index)*group_size + (position within the subgroup)
      ALLOCATE (data2send(0:para_env%num_pe - 1))
      DO aux_proc = 0, exchange_work%ngroup - 1
         DO virt_proc = 0, group_size - 1
            data2send(aux_proc*group_size + virt_proc) = exchange_work%aux2send(aux_proc)*virt2send(virt_proc)
         END DO
      END DO

      ALLOCATE (send_virt_indices(MAXVAL(virt2send)))
      max_number_send = MAXVAL(data2send)

      ALLOCATE (send_buffer(INT(max_number_send, KIND=int_8)*exchange_work%homo(ispin)))
      ! The receive buffer has to hold the largest message sent by any process
      max_number_recv = max_number_send
      CALL para_env%max(max_number_recv)
      ALLOCATE (recv_buffer(max_number_recv))

      ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))

      CALL timestop(handle2)

      CALL timeset(routineN//"_own", handle2)
      ! Start with own data
      DO aux_local = 1, aux_distribution_fm(my_process_row)
         aux_global = row_indices(aux_local)
         IF (aux_global < my_aux_start .OR. aux_global > my_aux_end) CYCLE
         DO ia_local = 1, ia_distribution_fm(my_process_column)
            ia_global = col_indices(ia_local)

            ! Split the combined index into the occupied and the virtual index
            iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
            avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1

            IF (my_virt_start > avirt .OR. my_virt_end < avirt) CYCLE

            mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = fm_mat%local_data(aux_local, ia_local)
         END DO
      END DO
      CALL timestop(handle2)

      ! Exchange the remaining data: in every step, send to the process proc_shift ranks
      ! ahead and receive from the process proc_shift ranks behind
      DO proc_shift = 1, para_env%num_pe - 1
         send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
         recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)

         CALL timeset(routineN//"_pack_buffer", handle2)
         ! 2D view (ia entries, auxiliary functions) on the leading part of the send buffer
         send_ptr(1:virt2send(MOD(send_proc, group_size)), &
                  1:exchange_work%aux2send(send_proc/group_size)) => &
            send_buffer(1:INT(virt2send(MOD(send_proc, group_size)), KIND=int_8)* &
                        exchange_work%aux2send(send_proc/group_size))
         ! Pack send buffer
         aux_counter = 0
         ! Keep virt_counter defined even if no local auxiliary function belongs to the target
         ! subgroup; it determines the extent of the index list sent below
         virt_counter = 0
         DO aux_local = 1, aux_distribution_fm(my_process_row)
            aux_global = row_indices(aux_local)
            proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
            IF (proc /= send_proc/group_size) CYCLE
            aux_counter = aux_counter + 1
            virt_counter = 0
            DO ia_local = 1, ia_distribution_fm(my_process_column)
               ia_global = col_indices(ia_local)
               avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1

               proc = group_dist_proc(virt_dist, avirt)
               IF (proc /= MOD(send_proc, group_size)) CYCLE
               virt_counter = virt_counter + 1
               send_ptr(virt_counter, aux_counter) = fm_mat%local_data(aux_local, ia_local)
               send_virt_indices(virt_counter) = ia_global
            END DO
            send_aux_indices(aux_counter) = aux_global
         END DO
         CALL timestop(handle2)

         CALL timeset(routineN//"_ex_size", handle2)
         recv_shape = [1, 1]
         CALL para_env%sendrecv(SHAPE(send_ptr), send_proc, recv_shape, recv_proc)
         CALL timestop(handle2)

         ! Empty messages are not exchanged
         IF (SIZE(send_ptr) == 0) send_proc = mp_proc_null
         IF (PRODUCT(recv_shape) == 0) recv_proc = mp_proc_null

         CALL timeset(routineN//"_ex_idx", handle2)
         ! Global indices of the received data (rows: combined ia index, columns: auxiliary index)
         ALLOCATE (recv_row_indices(recv_shape(1)), recv_col_indices(recv_shape(2)))
         CALL para_env%sendrecv(send_virt_indices(1:virt_counter), send_proc, recv_row_indices, recv_proc)
         CALL para_env%sendrecv(send_aux_indices(1:aux_counter), send_proc, recv_col_indices, recv_proc)
         CALL timestop(handle2)

         ! Prepare pointer to recv buffer (consider transposition while packing the send buffer)
         recv_ptr(1:recv_shape(1), 1:MAX(1, recv_shape(2))) => recv_buffer(1:recv_shape(1)*MAX(1, recv_shape(2)))

         CALL timeset(routineN//"_sendrecv", handle2)
         ! Perform communication
         CALL para_env%sendrecv(send_ptr, send_proc, recv_ptr, recv_proc)
         CALL timestop(handle2)

         ! Nothing was received in this step
         IF (recv_proc == mp_proc_null) THEN
            DEALLOCATE (recv_row_indices, recv_col_indices)
            CYCLE
         END IF

         CALL timeset(routineN//"_unpack", handle2)
         ! Unpack receive buffer
         DO aux_local = 1, SIZE(recv_col_indices)
            aux_global = recv_col_indices(aux_local)

            DO ia_local = 1, SIZE(recv_row_indices)
               ia_global = recv_row_indices(ia_local)

               iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
               avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1

               mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = recv_ptr(ia_local, aux_local)
            END DO
         END DO
         CALL timestop(handle2)

         IF (ALLOCATED(recv_row_indices)) DEALLOCATE (recv_row_indices)
         IF (ALLOCATED(recv_col_indices)) DEALLOCATE (recv_col_indices)
      END DO

      DEALLOCATE (send_aux_indices, send_virt_indices)

      CALL timestop(handle)

   END SUBROUTINE redistribute_into_subgroups
979 :
980 0 : END MODULE rpa_exchange
|