Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Auxiliary routines needed for RPA-exchange
10 : !> given blacs_env to another
11 : !> \par History
12 : !> 09.2016 created [Vladimir Rybkin]
13 : !> 03.2019 Renamed [Frederick Stein]
14 : !> 03.2019 Moved Functions from rpa_ri_gpw.F [Frederick Stein]
15 : !> 04.2024 Added open-shell calculations, SOSEX [Frederick Stein]
16 : !> \author Vladimir Rybkin
17 : ! **************************************************************************************************
18 : MODULE rpa_exchange
19 : USE atomic_kind_types, ONLY: atomic_kind_type
20 : USE cell_types, ONLY: cell_type
21 : USE cp_blacs_env, ONLY: cp_blacs_env_type
22 : USE cp_control_types, ONLY: dft_control_type
23 : USE cp_dbcsr_api, ONLY: &
24 : dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_init_p, dbcsr_multiply, dbcsr_p_type, &
25 : dbcsr_release, dbcsr_set, dbcsr_trace, dbcsr_type, dbcsr_type_no_symmetry
26 : USE cp_dbcsr_operations, ONLY: dbcsr_allocate_matrix_set
27 : USE cp_fm_basic_linalg, ONLY: cp_fm_column_scale
28 : USE cp_fm_diag, ONLY: choose_eigv_solver
29 : USE cp_fm_struct, ONLY: cp_fm_struct_create,&
30 : cp_fm_struct_p_type,&
31 : cp_fm_struct_release
32 : USE cp_fm_types, ONLY: cp_fm_create,&
33 : cp_fm_get_info,&
34 : cp_fm_release,&
35 : cp_fm_set_all,&
36 : cp_fm_to_fm,&
37 : cp_fm_to_fm_submat_general,&
38 : cp_fm_type
39 : USE group_dist_types, ONLY: create_group_dist,&
40 : get_group_dist,&
41 : group_dist_d1_type,&
42 : group_dist_proc,&
43 : maxsize,&
44 : release_group_dist
45 : USE hfx_admm_utils, ONLY: tddft_hfx_matrix
46 : USE hfx_types, ONLY: hfx_create,&
47 : hfx_release,&
48 : hfx_type
49 : USE input_constants, ONLY: rpa_exchange_axk,&
50 : rpa_exchange_none,&
51 : rpa_exchange_sosex
52 : USE input_section_types, ONLY: section_vals_get_subs_vals,&
53 : section_vals_type
54 : USE kinds, ONLY: dp,&
55 : int_8
56 : USE local_gemm_api, ONLY: LOCAL_GEMM_PU_GPU
57 : USE mathconstants, ONLY: sqrthalf
58 : USE message_passing, ONLY: mp_para_env_type,&
59 : mp_proc_null
60 : USE mp2_types, ONLY: mp2_type
61 : USE parallel_gemm_api, ONLY: parallel_gemm
62 : USE particle_types, ONLY: particle_type
63 : USE qs_environment_types, ONLY: get_qs_env,&
64 : qs_environment_type
65 : USE qs_kind_types, ONLY: qs_kind_type
66 : USE qs_subsys_types, ONLY: qs_subsys_get,&
67 : qs_subsys_type
68 : USE rpa_communication, ONLY: gamma_fm_to_dbcsr
69 : USE rpa_util, ONLY: calc_fm_mat_S_rpa,&
70 : remove_scaling_factor_rpa
71 : USE scf_control_types, ONLY: scf_control_type
72 : #include "./base/base_uses.f90"
73 :
74 : IMPLICIT NONE
75 :
76 : PRIVATE
77 :
78 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rpa_exchange'
79 :
80 : PUBLIC :: rpa_exchange_work_type, rpa_exchange_needed_mem
81 :
82 : TYPE rpa_exchange_env_type
83 : PRIVATE
84 : TYPE(qs_environment_type), POINTER :: qs_env => NULL()
85 : TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: mat_hfx => NULL()
86 : TYPE(dbcsr_p_type), DIMENSION(:), POINTER :: dbcsr_Gamma_munu_P => NULL()
87 : TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: dbcsr_Gamma_inu_P
88 : ! Workaround GCC 8
89 : TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_o => NULL()
90 : TYPE(dbcsr_type), DIMENSION(:), POINTER :: mo_coeff_v => NULL()
91 : TYPE(dbcsr_type) :: work_ao
92 : TYPE(hfx_type), DIMENSION(:, :), POINTER :: x_data => NULL()
93 : TYPE(mp_para_env_type), POINTER :: para_env => NULL()
94 : TYPE(section_vals_type), POINTER :: hfx_sections => NULL()
95 : LOGICAL :: my_recalc_hfx_integrals = .FALSE.
96 : REAL(KIND=dp) :: eps_filter = 0.0_dp
97 : TYPE(cp_fm_struct_p_type), DIMENSION(:), ALLOCATABLE :: struct_Gamma
98 : CONTAINS
99 : PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: create => hfx_create_subgroup
100 : !PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: integrate => integrate_exchange
101 : PROCEDURE, PASS(exchange_env), NON_OVERRIDABLE :: release => hfx_release_subgroup
102 : END TYPE
103 :
104 : TYPE dbcsr_matrix_p_set
105 : TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:) :: matrix_set
106 : END TYPE
107 :
108 : TYPE rpa_exchange_work_type
109 : PRIVATE
110 : INTEGER :: exchange_correction = rpa_exchange_none
111 : TYPE(rpa_exchange_env_type) :: exchange_env
112 : INTEGER, DIMENSION(:), ALLOCATABLE :: homo, virtual, dimen_ia
113 : TYPE(group_dist_d1_type) :: aux_func_dist = group_dist_d1_type()
114 : INTEGER, DIMENSION(:), ALLOCATABLE :: aux2send
115 : INTEGER :: dimen_RI = 0
116 : INTEGER :: block_size = 0
117 : INTEGER :: color_sub = 0
118 : INTEGER :: ngroup = 0
119 : TYPE(cp_fm_type) :: fm_mat_Q_tmp = cp_fm_type()
120 : TYPE(cp_fm_type) :: fm_mat_R_half_gemm = cp_fm_type()
121 : TYPE(cp_fm_type) :: fm_mat_U = cp_fm_type()
122 : TYPE(mp_para_env_type), POINTER :: para_env_sub => NULL()
123 : CONTAINS
124 : PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: create => rpa_exchange_work_create
125 : PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: compute => rpa_exchange_work_compute
126 : PROCEDURE, PUBLIC, PASS(exchange_work), NON_OVERRIDABLE :: release => rpa_exchange_work_release
127 : PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: redistribute_into_subgroups
128 : PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_fm => rpa_exchange_work_compute_fm
129 : PROCEDURE, PRIVATE, PASS(exchange_work), NON_OVERRIDABLE :: compute_hfx => rpa_exchange_work_compute_hfx
130 : END TYPE
131 :
132 : CONTAINS
133 :
134 : ! **************************************************************************************************
135 : !> \brief ...
136 : !> \param mp2_env ...
137 : !> \param homo ...
138 : !> \param virtual ...
139 : !> \param dimen_RI ...
140 : !> \param para_env ...
141 : !> \param mem_per_rank ...
142 : !> \param mem_per_repl ...
143 : ! **************************************************************************************************
144 128 : SUBROUTINE rpa_exchange_needed_mem(mp2_env, homo, virtual, dimen_RI, para_env, mem_per_rank, mem_per_repl)
145 : TYPE(mp2_type), INTENT(IN) :: mp2_env
146 : INTEGER, DIMENSION(:), INTENT(IN) :: homo, virtual
147 : INTEGER, INTENT(IN) :: dimen_RI
148 : TYPE(mp_para_env_type), INTENT(IN) :: para_env
149 : REAL(KIND=dp), INTENT(INOUT) :: mem_per_rank, mem_per_repl
150 :
151 : INTEGER :: block_size
152 :
153 : ! We need the block size and if it is unknown, an upper bound
154 128 : block_size = mp2_env%ri_rpa%exchange_block_size
155 128 : IF (block_size <= 0) block_size = MAX(1, (dimen_RI + para_env%num_pe - 1)/para_env%num_pe)
156 :
157 : ! storage of product matrix (upper bound only as it depends on the square of the potential still unknown block size)
158 284 : mem_per_rank = mem_per_rank + REAL(MAXVAL(homo), KIND=dp)**2*block_size**2*8.0_dp/(1024_dp**2)
159 :
160 : ! work arrays R (2x) and U, copies of Gamma (2x), communication buffer (as expensive as Gamma)
161 : mem_per_repl = mem_per_repl + 3.0_dp*dimen_RI*dimen_RI*8.0_dp/(1024_dp**2) &
162 284 : + 3.0_dp*MAXVAL(homo*virtual)*dimen_RI*8.0_dp/(1024_dp**2)
163 128 : END SUBROUTINE rpa_exchange_needed_mem
164 :
165 : ! **************************************************************************************************
166 : !> \brief ...
167 : !> \param exchange_work ...
168 : !> \param qs_env ...
169 : !> \param para_env_sub ...
170 : !> \param mat_munu ...
171 : !> \param dimen_RI ...
172 : !> \param fm_mat_S ...
173 : !> \param fm_mat_Q ...
174 : !> \param fm_mat_Q_gemm ...
175 : !> \param homo ...
176 : !> \param virtual ...
177 : ! **************************************************************************************************
178 134 : SUBROUTINE rpa_exchange_work_create(exchange_work, qs_env, para_env_sub, mat_munu, dimen_RI, &
179 134 : fm_mat_S, fm_mat_Q, fm_mat_Q_gemm, homo, virtual)
180 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
181 : TYPE(qs_environment_type), POINTER :: qs_env
182 : TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
183 : TYPE(dbcsr_p_type), INTENT(IN) :: mat_munu
184 : INTEGER, INTENT(IN) :: dimen_RI
185 : TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
186 : TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q, fm_mat_Q_gemm
187 : INTEGER, DIMENSION(SIZE(fm_mat_S)), INTENT(IN) :: homo, virtual
188 :
189 : INTEGER :: nspins, aux_global, aux_local, my_process_row, proc, ispin
190 134 : INTEGER, DIMENSION(:), POINTER :: row_indices, aux_distribution_fm
191 : TYPE(cp_blacs_env_type), POINTER :: context
192 :
193 134 : exchange_work%exchange_correction = qs_env%mp2_env%ri_rpa%exchange_correction
194 :
195 134 : IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
196 :
197 : ASSOCIATE (para_env => fm_mat_S(1)%matrix_struct%para_env)
198 12 : exchange_work%para_env_sub => para_env_sub
199 12 : exchange_work%ngroup = para_env%num_pe/para_env_sub%num_pe
200 12 : exchange_work%color_sub = para_env%mepos/para_env_sub%num_pe
201 : END ASSOCIATE
202 :
203 12 : CALL cp_fm_get_info(fm_mat_S(1), row_indices=row_indices, nrow_locals=aux_distribution_fm, context=context)
204 12 : CALL context%get(my_process_row=my_process_row)
205 :
206 12 : CALL create_group_dist(exchange_work%aux_func_dist, exchange_work%ngroup, dimen_RI)
207 36 : ALLOCATE (exchange_work%aux2send(0:exchange_work%ngroup - 1))
208 36 : exchange_work%aux2send = 0
209 499 : DO aux_local = 1, aux_distribution_fm(my_process_row)
210 487 : aux_global = row_indices(aux_local)
211 487 : proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
212 499 : exchange_work%aux2send(proc) = exchange_work%aux2send(proc) + 1
213 : END DO
214 :
215 12 : nspins = SIZE(fm_mat_S)
216 :
217 60 : ALLOCATE (exchange_work%homo(nspins), exchange_work%virtual(nspins), exchange_work%dimen_ia(nspins))
218 26 : exchange_work%homo(:) = homo
219 26 : exchange_work%virtual(:) = virtual
220 26 : exchange_work%dimen_ia(:) = homo*virtual
221 12 : exchange_work%dimen_RI = dimen_RI
222 :
223 12 : exchange_work%block_size = qs_env%mp2_env%ri_rpa%exchange_block_size
224 12 : IF (exchange_work%block_size <= 0) exchange_work%block_size = dimen_RI
225 :
226 12 : CALL cp_fm_create(exchange_work%fm_mat_U, fm_mat_Q%matrix_struct, name="fm_mat_U")
227 12 : CALL cp_fm_create(exchange_work%fm_mat_Q_tmp, fm_mat_Q%matrix_struct, name="fm_mat_Q_tmp")
228 12 : CALL cp_fm_create(exchange_work%fm_mat_R_half_gemm, fm_mat_Q_gemm%matrix_struct)
229 :
230 12 : IF (qs_env%mp2_env%ri_rpa%use_hfx_implementation) THEN
231 2 : CALL exchange_work%exchange_env%create(qs_env, mat_munu%matrix, para_env_sub, fm_mat_S)
232 : END IF
233 :
234 12 : IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_o)) THEN
235 22 : DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_o)
236 22 : CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_o(ispin))
237 : END DO
238 10 : DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_o)
239 : END IF
240 :
241 12 : IF (ASSOCIATED(qs_env%mp2_env%ri_rpa%mo_coeff_v)) THEN
242 22 : DO ispin = 1, SIZE(qs_env%mp2_env%ri_rpa%mo_coeff_v)
243 22 : CALL dbcsr_release(qs_env%mp2_env%ri_rpa%mo_coeff_v(ispin))
244 : END DO
245 10 : DEALLOCATE (qs_env%mp2_env%ri_rpa%mo_coeff_v)
246 : END IF
247 134 : END SUBROUTINE
248 :
249 : ! **************************************************************************************************
250 : !> \brief ... Initializes x_data on a subgroup
251 : !> \param exchange_env ...
252 : !> \param qs_env ...
253 : !> \param mat_munu ...
254 : !> \param para_env_sub ...
255 : !> \param fm_mat_S ...
256 : !> \author Vladimir Rybkin
257 : ! **************************************************************************************************
258 2 : SUBROUTINE hfx_create_subgroup(exchange_env, qs_env, mat_munu, para_env_sub, fm_mat_S)
259 : CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
260 : TYPE(dbcsr_type), INTENT(IN) :: mat_munu
261 : TYPE(qs_environment_type), POINTER :: qs_env
262 : TYPE(mp_para_env_type), POINTER, INTENT(IN) :: para_env_sub
263 : TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
264 :
265 : CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_create_subgroup'
266 :
267 : INTEGER :: handle, nelectron_total, ispin, &
268 : number_of_aos, nspins, dimen_RI, dimen_ia
269 2 : TYPE(atomic_kind_type), DIMENSION(:), POINTER :: atomic_kind_set
270 : TYPE(cell_type), POINTER :: my_cell
271 : TYPE(dft_control_type), POINTER :: dft_control
272 2 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
273 2 : TYPE(qs_kind_type), DIMENSION(:), POINTER :: qs_kind_set
274 : TYPE(qs_subsys_type), POINTER :: subsys
275 : TYPE(scf_control_type), POINTER :: scf_control
276 : TYPE(section_vals_type), POINTER :: input
277 :
278 2 : CALL timeset(routineN, handle)
279 :
280 2 : exchange_env%mo_coeff_o => qs_env%mp2_env%ri_rpa%mo_coeff_o
281 2 : exchange_env%mo_coeff_v => qs_env%mp2_env%ri_rpa%mo_coeff_v
282 2 : NULLIFY (qs_env%mp2_env%ri_rpa%mo_coeff_o, qs_env%mp2_env%ri_rpa%mo_coeff_v)
283 :
284 2 : nspins = SIZE(exchange_env%mo_coeff_o)
285 :
286 2 : exchange_env%qs_env => qs_env
287 2 : exchange_env%para_env => para_env_sub
288 2 : exchange_env%eps_filter = qs_env%mp2_env%mp2_gpw%eps_filter
289 :
290 2 : NULLIFY (my_cell, atomic_kind_set, particle_set, dft_control, qs_kind_set, scf_control)
291 :
292 : CALL get_qs_env(qs_env, &
293 : subsys=subsys, &
294 : input=input, &
295 : scf_control=scf_control, &
296 2 : nelectron_total=nelectron_total)
297 :
298 : CALL qs_subsys_get(subsys, &
299 : cell=my_cell, &
300 : atomic_kind_set=atomic_kind_set, &
301 : qs_kind_set=qs_kind_set, &
302 2 : particle_set=particle_set)
303 :
304 2 : exchange_env%hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%WF_CORRELATION%RI_RPA%HF")
305 2 : CALL get_qs_env(qs_env, dft_control=dft_control)
306 :
307 : ! Retrieve particle_set and atomic_kind_set
308 : CALL hfx_create(exchange_env%x_data, para_env_sub, exchange_env%hfx_sections, atomic_kind_set, &
309 : qs_kind_set, particle_set, dft_control, my_cell, orb_basis='ORB', &
310 2 : nelectron_total=nelectron_total)
311 :
312 2 : exchange_env%my_recalc_hfx_integrals = .TRUE.
313 :
314 2 : CALL dbcsr_allocate_matrix_set(exchange_env%mat_hfx, nspins)
315 4 : DO ispin = 1, nspins
316 2 : ALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
317 2 : CALL dbcsr_init_p(exchange_env%mat_hfx(ispin)%matrix)
318 : CALL dbcsr_create(exchange_env%mat_hfx(ispin)%matrix, template=mat_munu, &
319 2 : matrix_type=dbcsr_type_no_symmetry)
320 4 : CALL dbcsr_copy(exchange_env%mat_hfx(ispin)%matrix, mat_munu)
321 : END DO
322 :
323 2 : CALL dbcsr_get_info(mat_munu, nfullcols_total=number_of_aos)
324 :
325 : CALL dbcsr_create(exchange_env%work_ao, template=mat_munu, &
326 2 : matrix_type=dbcsr_type_no_symmetry)
327 :
328 8 : ALLOCATE (exchange_env%dbcsr_Gamma_inu_P(nspins))
329 2 : CALL dbcsr_allocate_matrix_set(exchange_env%dbcsr_Gamma_munu_P, nspins)
330 4 : DO ispin = 1, nspins
331 2 : ALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
332 : CALL dbcsr_create(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, template=mat_munu, &
333 2 : matrix_type=dbcsr_type_no_symmetry)
334 2 : CALL dbcsr_copy(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, mat_munu)
335 2 : CALL dbcsr_set(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, 0.0_dp)
336 :
337 2 : CALL dbcsr_create(exchange_env%dbcsr_Gamma_inu_P(ispin), template=exchange_env%mo_coeff_o(ispin))
338 2 : CALL dbcsr_copy(exchange_env%dbcsr_Gamma_inu_P(ispin), exchange_env%mo_coeff_o(ispin))
339 4 : CALL dbcsr_set(exchange_env%dbcsr_Gamma_inu_P(ispin), 0.0_dp)
340 : END DO
341 :
342 8 : ALLOCATE (exchange_env%struct_Gamma(nspins))
343 4 : DO ispin = 1, nspins
344 2 : CALL cp_fm_get_info(fm_mat_S(ispin), nrow_global=dimen_RI, ncol_global=dimen_ia)
345 : CALL cp_fm_struct_create(exchange_env%struct_Gamma(ispin)%struct, template_fmstruct=fm_mat_S(ispin)%matrix_struct, &
346 4 : nrow_global=dimen_ia, ncol_global=dimen_RI)
347 : END DO
348 :
349 2 : CALL timestop(handle)
350 :
351 4 : END SUBROUTINE hfx_create_subgroup
352 :
353 : ! **************************************************************************************************
354 : !> \brief ...
355 : !> \param exchange_work ...
356 : ! **************************************************************************************************
357 134 : SUBROUTINE rpa_exchange_work_release(exchange_work)
358 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
359 :
360 134 : IF (ALLOCATED(exchange_work%homo)) DEALLOCATE (exchange_work%homo)
361 134 : IF (ALLOCATED(exchange_work%virtual)) DEALLOCATE (exchange_work%virtual)
362 134 : IF (ALLOCATED(exchange_work%dimen_ia)) DEALLOCATE (exchange_work%dimen_ia)
363 134 : NULLIFY (exchange_work%para_env_sub)
364 134 : CALL release_group_dist(exchange_work%aux_func_dist)
365 134 : IF (ALLOCATED(exchange_work%aux2send)) DEALLOCATE (exchange_work%aux2send)
366 134 : CALL cp_fm_release(exchange_work%fm_mat_Q_tmp)
367 134 : CALL cp_fm_release(exchange_work%fm_mat_U)
368 134 : CALL cp_fm_release(exchange_work%fm_mat_R_half_gemm)
369 :
370 134 : CALL exchange_work%exchange_env%release()
371 134 : END SUBROUTINE
372 :
373 : ! **************************************************************************************************
374 : !> \brief ...
375 : !> \param exchange_env ...
376 : ! **************************************************************************************************
377 134 : SUBROUTINE hfx_release_subgroup(exchange_env)
378 : CLASS(rpa_exchange_env_type), INTENT(INOUT) :: exchange_env
379 :
380 : INTEGER :: ispin
381 :
382 134 : NULLIFY (exchange_env%para_env, exchange_env%hfx_sections)
383 :
384 134 : IF (ASSOCIATED(exchange_env%x_data)) THEN
385 2 : CALL hfx_release(exchange_env%x_data)
386 2 : NULLIFY (exchange_env%x_data)
387 : END IF
388 :
389 134 : CALL dbcsr_release(exchange_env%work_ao)
390 :
391 134 : IF (ASSOCIATED(exchange_env%dbcsr_Gamma_munu_P)) THEN
392 4 : DO ispin = 1, SIZE(exchange_env%mat_hfx, 1)
393 2 : CALL dbcsr_release(exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
394 2 : CALL dbcsr_release(exchange_env%mat_hfx(ispin)%matrix)
395 2 : CALL dbcsr_release(exchange_env%dbcsr_Gamma_inu_P(ispin))
396 2 : CALL dbcsr_release(exchange_env%mo_coeff_o(ispin))
397 2 : CALL dbcsr_release(exchange_env%mo_coeff_v(ispin))
398 2 : DEALLOCATE (exchange_env%mat_hfx(ispin)%matrix)
399 4 : DEALLOCATE (exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix)
400 : END DO
401 2 : DEALLOCATE (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
402 2 : DEALLOCATE (exchange_env%dbcsr_Gamma_inu_P, exchange_env%mo_coeff_o, exchange_env%mo_coeff_v)
403 2 : NULLIFY (exchange_env%mat_hfx, exchange_env%dbcsr_Gamma_munu_P)
404 : END IF
405 134 : IF (ALLOCATED(exchange_env%struct_Gamma)) THEN
406 4 : DO ispin = 1, SIZE(exchange_env%struct_Gamma)
407 4 : CALL cp_fm_struct_release(exchange_env%struct_Gamma(ispin)%struct)
408 : END DO
409 2 : DEALLOCATE (exchange_env%struct_Gamma)
410 : END IF
411 134 : END SUBROUTINE hfx_release_subgroup
412 :
413 : ! **************************************************************************************************
414 : !> \brief Main driver for RPA-exchange energies
415 : !> \param exchange_work ...
416 : !> \param fm_mat_Q ...
417 : !> \param eig ...
418 : !> \param fm_mat_S ...
419 : !> \param omega ...
420 : !> \param e_exchange_corr exchange energy correction for a quadrature point
421 : !> \param mp2_env ...
422 : !> \author Vladimir Rybkin, 07/2016
423 : ! **************************************************************************************************
424 12 : SUBROUTINE rpa_exchange_work_compute(exchange_work, fm_mat_Q, eig, fm_mat_S, omega, &
425 : e_exchange_corr, mp2_env)
426 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
427 : TYPE(cp_fm_type), INTENT(IN) :: fm_mat_Q
428 : REAL(KIND=dp), DIMENSION(:, :), INTENT(IN) :: eig
429 : TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT) :: fm_mat_S
430 : REAL(KIND=dp), INTENT(IN) :: omega
431 : REAL(KIND=dp), INTENT(INOUT) :: e_exchange_corr
432 : TYPE(mp2_type), INTENT(INOUT) :: mp2_env
433 :
434 : CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute'
435 : REAL(KIND=dp), PARAMETER :: thresh = 0.0000001_dp
436 :
437 : INTEGER :: handle, nspins, dimen_RI, iiB
438 12 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: eigenval
439 :
440 12 : IF (exchange_work%exchange_correction == rpa_exchange_none) RETURN
441 :
442 12 : CALL timeset(routineN, handle)
443 :
444 12 : CALL cp_fm_get_info(fm_mat_Q, ncol_global=dimen_RI)
445 :
446 12 : nspins = SIZE(fm_mat_S)
447 :
448 : ! Eigenvalues
449 36 : ALLOCATE (eigenval(dimen_RI))
450 986 : eigenval = 0.0_dp
451 :
452 12 : CALL cp_fm_set_all(matrix=exchange_work%fm_mat_Q_tmp, alpha=0.0_dp)
453 12 : CALL cp_fm_set_all(matrix=exchange_work%fm_mat_U, alpha=0.0_dp)
454 :
455 : ! Copy Q to Q_tmp
456 12 : CALL cp_fm_to_fm(fm_mat_Q, exchange_work%fm_mat_Q_tmp)
457 : ! Diagonalize Q
458 12 : CALL choose_eigv_solver(exchange_work%fm_mat_Q_tmp, exchange_work%fm_mat_U, eigenval)
459 :
460 : ! Calculate diagonal matrix for R_half
461 :
462 : ! Manipulate eigenvalues to get diagonal matrix
463 12 : IF (exchange_work%exchange_correction == rpa_exchange_axk) THEN
464 818 : DO iib = 1, dimen_RI
465 818 : IF (ABS(eigenval(iib)) .GE. thresh) THEN
466 : eigenval(iib) = &
467 : SQRT((1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
468 718 : - 1.0_dp/(eigenval(iib)*(eigenval(iib) + 1.0_dp)))
469 : ELSE
470 90 : eigenval(iib) = sqrthalf
471 : END IF
472 : END DO
473 2 : ELSE IF (exchange_work%exchange_correction == rpa_exchange_sosex) THEN
474 168 : DO iib = 1, dimen_RI
475 168 : IF (ABS(eigenval(iib)) .GE. thresh) THEN
476 : eigenval(iib) = &
477 : SQRT(-(1.0_dp/(eigenval(iib)**2))*LOG(1.0_dp + eigenval(iib)) &
478 144 : + 1.0_dp/eigenval(iib))
479 : ELSE
480 22 : eigenval(iib) = sqrthalf
481 : END IF
482 : END DO
483 : ELSE
484 0 : CPABORT("Unknown RPA exchange correction")
485 : END IF
486 :
487 : ! fm_mat_U now contains some sqrt of the required matrix-valued function
488 12 : CALL cp_fm_column_scale(exchange_work%fm_mat_U, eigenval)
489 :
490 : ! Release memory
491 12 : DEALLOCATE (eigenval)
492 :
493 : ! Redistribute fm_mat_U for "rectangular" multiplication: ia*P P*P
494 12 : CALL cp_fm_set_all(matrix=exchange_work%fm_mat_R_half_gemm, alpha=0.0_dp)
495 :
496 : CALL cp_fm_to_fm_submat_general(exchange_work%fm_mat_U, exchange_work%fm_mat_R_half_gemm, dimen_RI, &
497 12 : dimen_RI, 1, 1, 1, 1, exchange_work%fm_mat_U%matrix_struct%context)
498 :
499 12 : IF (mp2_env%ri_rpa%use_hfx_implementation) THEN
500 2 : CALL exchange_work%compute_hfx(fm_mat_S, eig, omega, e_exchange_corr)
501 : ELSE
502 10 : CALL exchange_work%compute_fm(fm_mat_S, eig, omega, e_exchange_corr, mp2_env)
503 : END IF
504 :
505 12 : CALL timestop(handle)
506 :
507 12 : END SUBROUTINE rpa_exchange_work_compute
508 :
509 : ! **************************************************************************************************
510 : !> \brief Main driver for RPA-exchange energies
511 : !> \param exchange_work ...
512 : !> \param fm_mat_S ...
513 : !> \param eig ...
514 : !> \param omega ...
515 : !> \param e_exchange_corr exchange energy correction for a quadrature point
516 : !> \param mp2_env ...
517 : !> \author Frederick Stein, May-June 2024
518 : ! **************************************************************************************************
519 10 : SUBROUTINE rpa_exchange_work_compute_fm(exchange_work, fm_mat_S, eig, omega, &
520 : e_exchange_corr, mp2_env)
521 : CLASS(rpa_exchange_work_type), INTENT(INOUT) :: exchange_work
522 : TYPE(cp_fm_type), DIMENSION(:), INTENT(IN) :: fm_mat_S
523 : REAL(KIND=dp), DIMENSION(:, :), INTENT(IN) :: eig
524 : REAL(KIND=dp), INTENT(IN) :: omega
525 : REAL(KIND=dp), INTENT(INOUT) :: e_exchange_corr
526 : TYPE(mp2_type), INTENT(INOUT) :: mp2_env
527 :
528 : CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_fm'
529 :
530 : INTEGER :: handle, ispin, nspins, P, Q, L_size_Gamma, hom, virt, i, &
531 : send_proc, recv_proc, recv_size, max_aux_size, proc_shift, dimen_ia, &
532 : block_size, P_start, P_end, P_size, Q_start, Q_size, Q_end, handle2, my_aux_size, my_virt
533 10 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET :: mat_Gamma_3_3D
534 10 : REAL(KIND=dp), POINTER, DIMENSION(:), CONTIGUOUS :: mat_Gamma_3_1D
535 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: mat_Gamma_3_2D
536 10 : REAL(KIND=dp), ALLOCATABLE, TARGET, DIMENSION(:) :: recv_buffer_1D
537 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: recv_buffer_2D
538 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :, :), CONTIGUOUS :: recv_buffer_3D
539 10 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: mat_B_iaP
540 10 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET :: product_matrix_1D
541 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :), CONTIGUOUS :: product_matrix_2D
542 10 : REAL(KIND=dp), POINTER, DIMENSION(:, :, :, :), CONTIGUOUS :: product_matrix_4D
543 : TYPE(cp_fm_type) :: fm_mat_Gamma_3
544 : TYPE(mp_para_env_type), POINTER :: para_env
545 10 : TYPE(group_dist_d1_type) :: virt_dist
546 :
547 10 : CALL timeset(routineN, handle)
548 :
549 10 : nspins = SIZE(fm_mat_S)
550 :
551 10 : CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, sizes=my_aux_size)
552 :
553 10 : e_exchange_corr = 0.0_dp
554 10 : max_aux_size = maxsize(exchange_work%aux_func_dist)
555 :
556 : ! local_gemm_ctx has a very large footprint the first time this routine is
557 : ! called.
558 10 : CALL mp2_env%local_gemm_ctx%create(LOCAL_GEMM_PU_GPU)
559 10 : CALL mp2_env%local_gemm_ctx%set_op_threshold_gpu(128*128*128*2)
560 :
561 22 : DO ispin = 1, nspins
562 12 : hom = exchange_work%homo(ispin)
563 12 : virt = exchange_work%virtual(ispin)
564 12 : dimen_ia = hom*virt
565 12 : IF (hom < 1 .OR. virt < 1) CYCLE
566 :
567 12 : CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)
568 :
569 12 : CALL cp_fm_create(fm_mat_Gamma_3, fm_mat_S(ispin)%matrix_struct)
570 12 : CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)
571 :
572 : ! Update G with a new value of Omega: in practice, it is G*S
573 :
574 : ! Scale fm_work_iaP
575 : CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
576 12 : hom, omega, 0.0_dp)
577 :
578 : ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2) = G*S*R^(1/2)
579 : CALL parallel_gemm(transa="T", transb="N", m=exchange_work%dimen_RI, n=dimen_ia, k=exchange_work%dimen_RI, alpha=1.0_dp, &
580 : matrix_a=exchange_work%fm_mat_R_half_gemm, matrix_b=fm_mat_S(ispin), beta=0.0_dp, &
581 12 : matrix_c=fm_mat_Gamma_3)
582 :
583 12 : CALL create_group_dist(virt_dist, exchange_work%para_env_sub%num_pe, virt)
584 :
585 : ! Remove extra factor from S after the multiplication (to return to the original matrix)
586 12 : CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
587 :
588 12 : CALL exchange_work%redistribute_into_subgroups(fm_mat_Gamma_3, mat_Gamma_3_3D, ispin, virt_dist)
589 12 : CALL cp_fm_release(fm_mat_Gamma_3)
590 :
591 : ! We need only the pure matrix
592 12 : CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)
593 :
594 : ! Reorder matrix from (P, i*a) -> (a, i, P) with P being distributed within subgroups
595 12 : CALL exchange_work%redistribute_into_subgroups(fm_mat_S(ispin), mat_B_iaP, ispin, virt_dist)
596 :
597 : ! Return to the original tensor
598 12 : CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), hom, omega, 0.0_dp)
599 :
600 12 : L_size_Gamma = SIZE(mat_Gamma_3_3D, 3)
601 12 : my_virt = SIZE(mat_Gamma_3_3D, 1)
602 12 : block_size = exchange_work%block_size
603 :
604 12 : mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size) => mat_Gamma_3_3D(:, :, 1:my_aux_size)
605 12 : mat_Gamma_3_2D(1:my_virt, 1:hom*my_aux_size) => mat_Gamma_3_1D(1:INT(my_virt, KIND=int_8)*hom*my_aux_size)
606 :
607 0 : ALLOCATE (product_matrix_1D(INT(hom*MIN(block_size, L_size_gamma), KIND=int_8)* &
608 36 : INT(hom*MIN(block_size, max_aux_size), KIND=int_8)))
609 36 : ALLOCATE (recv_buffer_1D(INT(virt, KIND=int_8)*hom*max_aux_size))
610 12 : recv_buffer_2D(1:my_virt, 1:hom*max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
611 12 : recv_buffer_3D(1:my_virt, 1:hom, 1:max_aux_size) => recv_buffer_1D(1:INT(virt, KIND=int_8)*hom*max_aux_size)
612 36 : DO proc_shift = 0, para_env%num_pe - 1, exchange_work%para_env_sub%num_pe
613 24 : send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
614 24 : recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)
615 :
616 24 : CALL get_group_dist(exchange_work%aux_func_dist, recv_proc/exchange_work%para_env_sub%num_pe, sizes=recv_size)
617 :
618 24 : IF (recv_size == 0) recv_proc = mp_proc_null
619 :
620 24 : CALL para_env%sendrecv(mat_B_iaP, send_proc, recv_buffer_3D(:, :, 1:recv_size), recv_proc)
621 :
622 24 : IF (recv_size == 0) CYCLE
623 :
624 1038 : DO P_start = 1, L_size_Gamma, block_size
625 1002 : P_end = MIN(L_size_Gamma, P_start + block_size - 1)
626 1002 : P_size = P_end - P_start + 1
627 43875 : DO Q_start = 1, recv_size, block_size
628 42849 : Q_end = MIN(recv_size, Q_start + block_size - 1)
629 42849 : Q_size = Q_end - Q_start + 1
630 :
631 : ! Reassign product_matrix pointers to enforce contiguity of target array
632 : product_matrix_2D(1:hom*P_size, 1:hom*Q_size) => &
633 42849 : product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
634 : product_matrix_4D(1:hom, 1:P_size, 1:hom, 1:Q_size) => &
635 42849 : product_matrix_1D(1:INT(hom*P_size, KIND=int_8)*INT(hom*Q_size, KIND=int_8))
636 :
637 42849 : CALL timeset(routineN//"_gemm", handle2)
638 : CALL mp2_env%local_gemm_ctx%gemm("T", "N", hom*P_size, hom*Q_size, my_virt, 1.0_dp, &
639 : mat_Gamma_3_2D(:, hom*(P_start - 1) + 1:hom*P_end), my_virt, &
640 : recv_buffer_2D(:, hom*(Q_start - 1) + 1:hom*Q_end), my_virt, &
641 42849 : 0.0_dp, product_matrix_2D, hom*P_size)
642 42849 : CALL timestop(handle2)
643 :
644 42849 : CALL timeset(routineN//"_energy", handle2)
645 : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(P_size, Q_size, hom, product_matrix_4D) &
646 42849 : !$OMP COLLAPSE(3) REDUCTION(+: e_exchange_corr) PRIVATE(P, Q, i)
647 : DO P = 1, P_size
648 : DO Q = 1, Q_size
649 : DO i = 1, hom
650 : e_exchange_corr = e_exchange_corr + DOT_PRODUCT(product_matrix_4D(i, P, :, Q), product_matrix_4D(:, P, i, Q))
651 : END DO
652 : END DO
653 : END DO
654 86700 : CALL timestop(handle2)
655 : END DO
656 : END DO
657 : END DO
658 :
659 12 : CALL release_group_dist(virt_dist)
660 12 : IF (ALLOCATED(mat_B_iaP)) DEALLOCATE (mat_B_iaP)
661 12 : IF (ALLOCATED(mat_Gamma_3_3D)) DEALLOCATE (mat_Gamma_3_3D)
662 12 : IF (ALLOCATED(product_matrix_1D)) DEALLOCATE (product_matrix_1D)
663 58 : IF (ALLOCATED(recv_buffer_1D)) DEALLOCATE (recv_buffer_1D)
664 : END DO
665 :
666 10 : CALL mp2_env%local_gemm_ctx%destroy()
667 :
668 10 : IF (nspins == 2) e_exchange_corr = e_exchange_corr*2.0_dp
669 10 : IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp
670 :
671 10 : CALL timestop(handle)
672 :
673 20 : END SUBROUTINE rpa_exchange_work_compute_fm
674 :
! **************************************************************************************************
!> \brief Contract RPA-exchange density matrix with HF exchange integrals and evaluate the correction
!> \param exchange_work RPA exchange work object (aux distribution, MO coefficients, HFX environment)
!> \param fm_mat_S per-spin matrices; scaled in place by calc_fm_mat_S_rpa for the multiplication and
!>        restored by remove_scaling_factor_rpa before the spin loop continues
!> \param eig orbital energies, accessed as eig(:, ispin)
!> \param omega frequency point passed to the scaling routines
!> \param e_exchange_corr exchange correction accumulated over the local auxiliary functions and spins
!> \author Vladimir Rybkin, 08/2016
! **************************************************************************************************
   SUBROUTINE rpa_exchange_work_compute_hfx(exchange_work, fm_mat_S, eig, omega, e_exchange_corr)
      CLASS(rpa_exchange_work_type), INTENT(INOUT)       :: exchange_work
      TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT)      :: fm_mat_S
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: eig
      REAL(KIND=dp), INTENT(IN)                          :: omega
      REAL(KIND=dp), INTENT(OUT)                         :: e_exchange_corr

      CHARACTER(LEN=*), PARAMETER :: routineN = 'rpa_exchange_work_compute_hfx'

      INTEGER :: handle, ispin, my_aux_start, my_aux_end, &
                 my_aux_size, nspins, L_counter, dimen_ia, hom, virt
      LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: spin_active
      REAL(KIND=dp)                                      :: e_exchange_P
      TYPE(dbcsr_matrix_p_set), DIMENSION(:), ALLOCATABLE :: dbcsr_Gamma_3
      TYPE(cp_fm_type)                                   :: fm_mat_Gamma_3
      TYPE(mp_para_env_type), POINTER                    :: para_env

      CALL timeset(routineN, handle)

      e_exchange_corr = 0.0_dp

      nspins = SIZE(fm_mat_S)

      CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)

      ALLOCATE (dbcsr_Gamma_3(nspins))
      ALLOCATE (spin_active(nspins))
      DO ispin = 1, nspins
         hom = exchange_work%homo(ispin)
         virt = exchange_work%virtual(ispin)
         dimen_ia = hom*virt
         ! A spin channel without occupied or virtual orbitals has no Gamma_3; remember this so the
         ! per-aux-function loop below does not touch its (never filled) matrix set.
         spin_active(ispin) = (hom >= 1 .AND. virt >= 1)
         IF (.NOT. spin_active(ispin)) CYCLE

         CALL cp_fm_get_info(fm_mat_S(ispin), para_env=para_env)

         CALL cp_fm_create(fm_mat_Gamma_3, exchange_work%exchange_env%struct_Gamma(ispin)%struct)
         CALL cp_fm_set_all(matrix=fm_mat_Gamma_3, alpha=0.0_dp)

         ! Update G with a new value of Omega: in practice, it is G*S

         ! Scale fm_work_iaP
         CALL calc_fm_mat_S_rpa(fm_mat_S(ispin), .TRUE., virt, eig(:, ispin), &
                                hom, omega, 0.0_dp)

         ! Calculate Gamma_3: Gamma_3 = G*S*R^(1/2)
         CALL parallel_gemm(transa="T", transb="N", m=dimen_ia, n=exchange_work%dimen_RI, &
                            k=exchange_work%dimen_RI, alpha=1.0_dp, &
                            matrix_a=fm_mat_S(ispin), matrix_b=exchange_work%fm_mat_R_half_gemm, beta=0.0_dp, &
                            matrix_c=fm_mat_Gamma_3)

         ! Remove extra factor from S after the multiplication (to return to the original matrix)
         CALL remove_scaling_factor_rpa(fm_mat_S(ispin), virt, eig(:, ispin), hom, omega)

         ! Copy Gamma_ia_P^3 to dbcsr matrix set
         CALL gamma_fm_to_dbcsr(fm_mat_Gamma_3, dbcsr_Gamma_3(ispin)%matrix_set, &
                                para_env, exchange_work%para_env_sub, hom, virt, &
                                exchange_work%exchange_env%mo_coeff_o(ispin), &
                                exchange_work%ngroup, my_aux_start, my_aux_end, my_aux_size)
      END DO

      DO L_counter = 1, my_aux_size
         DO ispin = 1, nspins
            IF (spin_active(ispin)) THEN
               ! Do dbcsr multiplication: transform the virtual index
               CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mo_coeff_v(ispin), &
                                   dbcsr_Gamma_3(ispin)%matrix_set(L_counter), &
                                   0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
                                   filter_eps=exchange_work%exchange_env%eps_filter)

               CALL dbcsr_release(dbcsr_Gamma_3(ispin)%matrix_set(L_counter))

               ! Do dbcsr multiplication: transform the occupied index (symmetrized, 0.5 each)
               CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
                                   exchange_work%exchange_env%mo_coeff_o(ispin), &
                                   0.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
                                   filter_eps=exchange_work%exchange_env%eps_filter)
               CALL dbcsr_multiply("N", "T", 0.5_dp, exchange_work%exchange_env%mo_coeff_o(ispin), &
                                   exchange_work%exchange_env%dbcsr_Gamma_inu_P(ispin), &
                                   1.0_dp, exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
                                   filter_eps=exchange_work%exchange_env%eps_filter)
            ELSE
               ! Inactive spin: tddft_hfx_matrix still receives all spins, so hand it a zero density.
               ! NOTE(review): assumes dbcsr_Gamma_munu_P(ispin)%matrix was created for every spin
               ! during setup -- confirm in the exchange_env initialisation.
               CALL dbcsr_set(exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, 0.0_dp)
            END IF

            CALL dbcsr_set(exchange_work%exchange_env%mat_hfx(ispin)%matrix, 0.0_dp)
         END DO

         CALL tddft_hfx_matrix(exchange_work%exchange_env%mat_hfx, exchange_work%exchange_env%dbcsr_Gamma_munu_P, &
                               exchange_work%exchange_env%qs_env, .FALSE., &
                               exchange_work%exchange_env%my_recalc_hfx_integrals, &
                               exchange_work%exchange_env%hfx_sections, exchange_work%exchange_env%x_data, &
                               exchange_work%exchange_env%para_env)

         ! Integrals only need to be (re)computed for the first auxiliary function
         exchange_work%exchange_env%my_recalc_hfx_integrals = .FALSE.
         DO ispin = 1, nspins
            ! The density of an inactive spin is zero, so its trace would vanish anyway
            IF (.NOT. spin_active(ispin)) CYCLE
            CALL dbcsr_multiply("N", "T", 1.0_dp, exchange_work%exchange_env%mat_hfx(ispin)%matrix, &
                                exchange_work%exchange_env%dbcsr_Gamma_munu_P(ispin)%matrix, &
                                0.0_dp, exchange_work%exchange_env%work_ao, filter_eps=exchange_work%exchange_env%eps_filter)
            CALL dbcsr_trace(exchange_work%exchange_env%work_ao, e_exchange_P)
            e_exchange_corr = e_exchange_corr - e_exchange_P
         END DO
      END DO

      ! Open shell (nspins == 2): the accumulated spin-resolved traces are used without rescaling.
      ! NOTE(review): rpa_exchange_work_compute_fm scales the open-shell result by 2 -- confirm that the
      ! HFX path intentionally differs.
      IF (nspins == 1) e_exchange_corr = e_exchange_corr*4.0_dp

      CALL timestop(handle)

   END SUBROUTINE rpa_exchange_work_compute_hfx
787 :
! **************************************************************************************************
!> \brief Redistribute a full matrix with auxiliary functions as rows and combined (i,a) indices as
!>        columns into the local subgroup block mat(virtual, occupied, auxiliary)
!> \param exchange_work RPA exchange work object (subgroup communicator, aux distribution, aux2send)
!> \param fm_mat source matrix (rows: auxiliary functions, columns: ia = (i-1)*virtual + a)
!> \param mat allocated here with shape (my_virt_size, homo, my_aux_size); element
!>        (a - my_virt_start + 1, i, P - my_aux_start + 1) holds fm_mat(P, ia)
!> \param ispin spin channel
!> \param virt_dist distribution of the virtual orbitals over the processes of the subgroup
! **************************************************************************************************
   SUBROUTINE redistribute_into_subgroups(exchange_work, fm_mat, mat, ispin, virt_dist)
      CLASS(rpa_exchange_work_type), INTENT(IN)          :: exchange_work
      TYPE(cp_fm_type), INTENT(IN)                       :: fm_mat
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: mat
      INTEGER, INTENT(IN)                                :: ispin
      TYPE(group_dist_d1_type), INTENT(IN)               :: virt_dist

      CHARACTER(LEN=*), PARAMETER :: routineN = 'redistribute_into_subgroups'

      INTEGER :: aux_counter, aux_global, aux_local, aux_proc, avirt, dimen_RI, handle, handle2, &
                 ia_global, ia_local, iocc, max_number_recv, max_number_send, my_aux_end, my_aux_size, &
                 my_aux_start, my_process_column, my_process_row, my_virt_end, my_virt_size, &
                 my_virt_start, proc, proc_shift, recv_proc, send_proc, virt_counter, virt_proc, group_size
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: data2send, recv_col_indices, &
                                                            recv_row_indices, send_aux_indices, send_virt_indices, virt2send
      INTEGER, DIMENSION(2)                              :: recv_shape
      INTEGER, DIMENSION(:), POINTER                     :: aux_distribution_fm, col_indices, &
                                                            ia_distribution_fm, row_indices
      INTEGER, DIMENSION(:, :), POINTER                  :: mpi2blacs
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: recv_buffer, send_buffer
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: recv_ptr, send_ptr
      TYPE(cp_blacs_env_type), POINTER                   :: context
      TYPE(mp_para_env_type), POINTER                    :: para_env

      CALL timeset(routineN, handle)

      CALL cp_fm_get_info(matrix=fm_mat, &
                          nrow_locals=aux_distribution_fm, &
                          col_indices=col_indices, &
                          row_indices=row_indices, &
                          ncol_locals=ia_distribution_fm, &
                          context=context, &
                          nrow_global=dimen_RI, &
                          para_env=para_env)

      ! Local ranges are needed both by the early exit and by the regular path
      CALL get_group_dist(exchange_work%aux_func_dist, exchange_work%color_sub, my_aux_start, my_aux_end, my_aux_size)
      CALL get_group_dist(virt_dist, exchange_work%para_env_sub%mepos, my_virt_start, my_virt_end, my_virt_size)

      IF (exchange_work%homo(ispin) <= 0 .OR. exchange_work%virtual(ispin) <= 0) THEN
         ! Nothing to redistribute: return an array with the same layout as the regular path
         ! (a non-positive extent yields a zero-size dimension)
         ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))
         CALL timestop(handle)
         RETURN
      END IF

      group_size = exchange_work%para_env_sub%num_pe

      CALL timeset(routineN//"_prep", handle2)
      CALL context%get(my_process_column=my_process_column, my_process_row=my_process_row, mpi2blacs=mpi2blacs)

      ! Determine the number of local (i,a) columns destined for each process of the subgroup
      ALLOCATE (send_aux_indices(MAXVAL(exchange_work%aux2send)))
      ALLOCATE (virt2send(0:group_size - 1))
      virt2send = 0
      DO ia_local = 1, ia_distribution_fm(my_process_column)
         ia_global = col_indices(ia_local)
         avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1
         proc = group_dist_proc(virt_dist, avirt)
         virt2send(proc) = virt2send(proc) + 1
      END DO

      ! Number of elements sent to each global process (aux group x virtual process)
      ALLOCATE (data2send(0:para_env%num_pe - 1))
      DO aux_proc = 0, exchange_work%ngroup - 1
         DO virt_proc = 0, group_size - 1
            data2send(aux_proc*group_size + virt_proc) = exchange_work%aux2send(aux_proc)*virt2send(virt_proc)
         END DO
      END DO

      ALLOCATE (send_virt_indices(MAXVAL(virt2send)))
      max_number_send = MAXVAL(data2send)

      ! A message holds virt2send(p)*aux2send(q) = data2send(...) elements; virt2send already counts
      ! (i,a) pairs, so no extra factor of homo is required
      ALLOCATE (send_buffer(max_number_send))
      max_number_recv = max_number_send
      CALL para_env%max(max_number_recv)
      ALLOCATE (recv_buffer(max_number_recv))

      ALLOCATE (mat(my_virt_size, exchange_work%homo(ispin), my_aux_size))

      CALL timestop(handle2)

      CALL timeset(routineN//"_own", handle2)
      ! Start with own data
      DO aux_local = 1, aux_distribution_fm(my_process_row)
         aux_global = row_indices(aux_local)
         IF (aux_global < my_aux_start .OR. aux_global > my_aux_end) CYCLE
         DO ia_local = 1, ia_distribution_fm(my_process_column)
            ia_global = col_indices(ia_local)

            iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
            avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1

            IF (my_virt_start > avirt .OR. my_virt_end < avirt) CYCLE

            mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = fm_mat%local_data(aux_local, ia_local)
         END DO
      END DO
      CALL timestop(handle2)

      DO proc_shift = 1, para_env%num_pe - 1
         send_proc = MODULO(para_env%mepos + proc_shift, para_env%num_pe)
         recv_proc = MODULO(para_env%mepos - proc_shift, para_env%num_pe)

         CALL timeset(routineN//"_pack_buffer", handle2)
         send_ptr(1:virt2send(MOD(send_proc, group_size)), &
                  1:exchange_work%aux2send(send_proc/group_size)) => &
            send_buffer(1:INT(virt2send(MOD(send_proc, group_size)), KIND=int_8)* &
                        exchange_work%aux2send(send_proc/group_size))
         ! Pack send buffer; both counters must be defined even if no local row targets send_proc,
         ! because they bound the index slices exchanged below
         aux_counter = 0
         virt_counter = 0
         DO aux_local = 1, aux_distribution_fm(my_process_row)
            aux_global = row_indices(aux_local)
            proc = group_dist_proc(exchange_work%aux_func_dist, aux_global)
            IF (proc /= send_proc/group_size) CYCLE
            aux_counter = aux_counter + 1
            virt_counter = 0
            DO ia_local = 1, ia_distribution_fm(my_process_column)
               ia_global = col_indices(ia_local)
               avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1

               proc = group_dist_proc(virt_dist, avirt)
               IF (proc /= MOD(send_proc, group_size)) CYCLE
               virt_counter = virt_counter + 1
               send_ptr(virt_counter, aux_counter) = fm_mat%local_data(aux_local, ia_local)
               send_virt_indices(virt_counter) = ia_global
            END DO
            send_aux_indices(aux_counter) = aux_global
         END DO
         CALL timestop(handle2)

         CALL timeset(routineN//"_ex_size", handle2)
         recv_shape = [1, 1]
         CALL para_env%sendrecv(SHAPE(send_ptr), send_proc, recv_shape, recv_proc)
         CALL timestop(handle2)

         ! Skip empty messages in either direction
         IF (SIZE(send_ptr) == 0) send_proc = mp_proc_null
         IF (PRODUCT(recv_shape) == 0) recv_proc = mp_proc_null

         CALL timeset(routineN//"_ex_idx", handle2)
         ALLOCATE (recv_row_indices(recv_shape(1)), recv_col_indices(recv_shape(2)))
         CALL para_env%sendrecv(send_virt_indices(1:virt_counter), send_proc, recv_row_indices, recv_proc)
         CALL para_env%sendrecv(send_aux_indices(1:aux_counter), send_proc, recv_col_indices, recv_proc)
         CALL timestop(handle2)

         ! Prepare pointer to recv buffer (consider transposition while packing the send buffer)
         recv_ptr(1:recv_shape(1), 1:MAX(1, recv_shape(2))) => recv_buffer(1:recv_shape(1)*MAX(1, recv_shape(2)))

         CALL timeset(routineN//"_sendrecv", handle2)
         ! Perform communication
         CALL para_env%sendrecv(send_ptr, send_proc, recv_ptr, recv_proc)
         CALL timestop(handle2)

         IF (recv_proc == mp_proc_null) THEN
            DEALLOCATE (recv_row_indices, recv_col_indices)
            CYCLE
         END IF

         CALL timeset(routineN//"_unpack", handle2)
         ! Unpack receive buffer
         DO aux_local = 1, SIZE(recv_col_indices)
            aux_global = recv_col_indices(aux_local)

            DO ia_local = 1, SIZE(recv_row_indices)
               ia_global = recv_row_indices(ia_local)

               iocc = (ia_global - 1)/exchange_work%virtual(ispin) + 1
               avirt = MOD(ia_global - 1, exchange_work%virtual(ispin)) + 1

               mat(avirt - my_virt_start + 1, iocc, aux_global - my_aux_start + 1) = recv_ptr(ia_local, aux_local)
            END DO
         END DO
         CALL timestop(handle2)

         IF (ALLOCATED(recv_row_indices)) DEALLOCATE (recv_row_indices)
         IF (ALLOCATED(recv_col_indices)) DEALLOCATE (recv_col_indices)
      END DO

      DEALLOCATE (send_aux_indices, send_virt_indices, send_buffer, recv_buffer, data2send, virt2send)

      CALL timestop(handle)

   END SUBROUTINE redistribute_into_subgroups
978 :
979 0 : END MODULE rpa_exchange
|