LCOV - code coverage report
Current view: top level - src - hfx_ri_kp.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b8e0b09) Lines: 2161 2191 98.6 %
Date: 2024-08-31 06:31:37 Functions: 38 40 95.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief RI-methods for HFX and K-points.
      10             : !> \author Augustin Bussy (01.2023)
      11             : ! **************************************************************************************************
      12             : 
      13             : MODULE hfx_ri_kp
      14             :    USE admm_types,                      ONLY: get_admm_env
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      16             :                                               get_atomic_kind_set
      17             :    USE basis_set_types,                 ONLY: get_gto_basis_set,&
      18             :                                               gto_basis_set_p_type
      19             :    USE bibliography,                    ONLY: Bussy2024,&
      20             :                                               cite_reference
      21             :    USE cell_types,                      ONLY: cell_type,&
      22             :                                               pbc,&
      23             :                                               real_to_scaled,&
      24             :                                               scaled_to_real
      25             :    USE cp_array_utils,                  ONLY: cp_1d_logical_p_type,&
      26             :                                               cp_2d_r_p_type,&
      27             :                                               cp_3d_r_p_type
      28             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_create,&
      29             :                                               cp_blacs_env_release,&
      30             :                                               cp_blacs_env_type
      31             :    USE cp_control_types,                ONLY: dft_control_type
      32             :    USE cp_dbcsr_api,                    ONLY: &
      33             :         dbcsr_add, dbcsr_clear, dbcsr_copy, dbcsr_create, dbcsr_distribution_get, &
      34             :         dbcsr_distribution_new, dbcsr_distribution_release, dbcsr_distribution_type, dbcsr_dot, &
      35             :         dbcsr_filter, dbcsr_finalize, dbcsr_get_block_p, dbcsr_get_info, &
      36             :         dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, dbcsr_iterator_start, &
      37             :         dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, dbcsr_put_block, dbcsr_release, &
      38             :         dbcsr_type, dbcsr_type_no_symmetry, dbcsr_type_symmetric
      39             :    USE cp_dbcsr_cholesky,               ONLY: cp_dbcsr_cholesky_decompose,&
      40             :                                               cp_dbcsr_cholesky_invert
      41             :    USE cp_dbcsr_cp2k_link,              ONLY: cp_dbcsr_alloc_block_from_nbl
      42             :    USE cp_dbcsr_diag,                   ONLY: cp_dbcsr_power
      43             :    USE cp_dbcsr_operations,             ONLY: cp_dbcsr_dist2d_to_dist
      44             :    USE dbt_api,                         ONLY: &
      45             :         dbt_batched_contract_finalize, dbt_batched_contract_init, dbt_clear, dbt_contract, &
      46             :         dbt_copy, dbt_copy_matrix_to_tensor, dbt_copy_tensor_to_matrix, dbt_create, dbt_destroy, &
      47             :         dbt_distribution_destroy, dbt_distribution_new, dbt_distribution_type, dbt_filter, &
      48             :         dbt_finalize, dbt_get_block, dbt_get_info, dbt_get_stored_coordinates, &
      49             :         dbt_iterator_blocks_left, dbt_iterator_next_block, dbt_iterator_start, dbt_iterator_stop, &
      50             :         dbt_iterator_type, dbt_mp_environ_pgrid, dbt_pgrid_create, dbt_pgrid_destroy, &
      51             :         dbt_pgrid_type, dbt_put_block, dbt_scale, dbt_type
      52             :    USE distribution_2d_types,           ONLY: distribution_2d_release,&
      53             :                                               distribution_2d_type
      54             :    USE hfx_ri,                          ONLY: get_idx_to_atom,&
      55             :                                               hfx_ri_pre_scf_calc_tensors
      56             :    USE hfx_types,                       ONLY: hfx_ri_type
      57             :    USE input_constants,                 ONLY: do_potential_short,&
      58             :                                               hfx_ri_do_2c_cholesky,&
      59             :                                               hfx_ri_do_2c_diag,&
      60             :                                               hfx_ri_do_2c_iter
      61             :    USE input_cp2k_hfx,                  ONLY: ri_pmat
      62             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      63             :                                               section_vals_type,&
      64             :                                               section_vals_val_get,&
      65             :                                               section_vals_val_set
      66             :    USE iterate_matrix,                  ONLY: invert_hotelling
      67             :    USE kinds,                           ONLY: dp,&
      68             :                                               int_8
      69             :    USE kpoint_types,                    ONLY: get_kpoint_info,&
      70             :                                               kpoint_type
      71             :    USE libint_2c_3c,                    ONLY: cutoff_screen_factor
      72             :    USE machine,                         ONLY: m_flush,&
      73             :                                               m_walltime
      74             :    USE mathlib,                         ONLY: erfc_cutoff
      75             :    USE message_passing,                 ONLY: mp_cart_type,&
      76             :                                               mp_para_env_type,&
      77             :                                               mp_request_type,&
      78             :                                               mp_waitall
      79             :    USE particle_methods,                ONLY: get_particle_set
      80             :    USE particle_types,                  ONLY: particle_type
      81             :    USE physcon,                         ONLY: angstrom
      82             :    USE qs_environment_types,            ONLY: get_qs_env,&
      83             :                                               qs_environment_type
      84             :    USE qs_force_types,                  ONLY: qs_force_type
      85             :    USE qs_integral_utils,               ONLY: basis_set_list_setup
      86             :    USE qs_interactions,                 ONLY: init_interaction_radii_orb_basis
      87             :    USE qs_kind_types,                   ONLY: qs_kind_type
      88             :    USE qs_neighbor_list_types,          ONLY: get_iterator_info,&
      89             :                                               neighbor_list_iterate,&
      90             :                                               neighbor_list_iterator_create,&
      91             :                                               neighbor_list_iterator_p_type,&
      92             :                                               neighbor_list_iterator_release,&
      93             :                                               neighbor_list_set_p_type,&
      94             :                                               release_neighbor_list_sets
      95             :    USE qs_scf_types,                    ONLY: qs_scf_env_type
      96             :    USE qs_tensors,                      ONLY: &
      97             :         build_2c_derivatives, build_2c_neighbor_lists, build_3c_derivatives, &
      98             :         build_3c_neighbor_lists, get_3c_iterator_info, get_tensor_occupancy, &
      99             :         neighbor_list_3c_destroy, neighbor_list_3c_iterate, neighbor_list_3c_iterator_create, &
     100             :         neighbor_list_3c_iterator_destroy
     101             :    USE qs_tensors_types,                ONLY: create_2c_tensor,&
     102             :                                               create_3c_tensor,&
     103             :                                               create_tensor_batches,&
     104             :                                               distribution_2d_create,&
     105             :                                               distribution_3d_create,&
     106             :                                               distribution_3d_type,&
     107             :                                               neighbor_list_3c_iterator_type,&
     108             :                                               neighbor_list_3c_type
     109             :    USE util,                            ONLY: get_limit
     110             :    USE virial_types,                    ONLY: virial_type
     111             : #include "./base/base_uses.f90"
     112             : 
     113             : !$ USE OMP_LIB, ONLY: omp_get_num_threads
     114             : 
     115             :    IMPLICIT NONE
     116             :    PRIVATE
     117             : 
     118             :    PUBLIC :: hfx_ri_update_ks_kp, hfx_ri_update_forces_kp
     119             : 
     120             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'hfx_ri_kp'
     121             : CONTAINS
     122             : 
     123             : ! NOTES: for a start, we do not seek performance, but accuracy. So in this first implementation,
     124             : !        we give little consideration to batching, load balance and such.
     125             : !        We also put everything here, even if there is some code replication with the original RI_HFX
     126             : !        We will only work in the RHO flavor
     127             : !        For now, we will also always assume that there is a single para_env, and that there is no
     128             : !        K-point subgroup. This might change in the future
     129             : 
     130             : ! **************************************************************************************************
     131             : !> \brief Initialize the ri_data for K-point. For now, we take the normal, usual existing ri_data
     132             : !>        and we adapt it to our needs
     133             : !> \param dbcsr_template ...
     134             : !> \param ri_data ...
     135             : !> \param qs_env ...
     136             : ! **************************************************************************************************
     137          70 :    SUBROUTINE adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     138             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     139             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     140             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     141             : 
     142             :       INTEGER                                            :: i_img, i_RI, i_spin, iatom, natom, &
     143             :                                                             nblks_RI, nimg, nkind, nspins
     144          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, dist1, dist2, dist3
     145             :       TYPE(dft_control_type), POINTER                    :: dft_control
     146             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     147             : 
     148          70 :       NULLIFY (dft_control, para_env)
     149             : 
     150             :       !The main thing that we need to do is to allocate more space for the integrals, such that there
     151             :       !is room for each periodic image. Note that we only go in 1D, i.e. we store (mu^0 sigma^a|P^0),
     152             :       !and (P^0|Q^a) => the RI basis is always in the main cell.
     153             : 
     154             :       !Get kpoint info
     155          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, para_env=para_env, nkind=nkind)
     156          70 :       nimg = ri_data%nimg
     157             : 
     158             :       !Along the RI direction we have basis elements spread across ncell_RI images.
     159          70 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
     160         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     161         506 :       DO i_RI = 1, ri_data%ncell_RI
     162        2344 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
     163             :       END DO
     164             : 
     165        4058 :       ALLOCATE (ri_data%t_3c_int_ctr_1(1, nimg))
     166             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_1(1, 1), dist1, dist2, dist3, &
     167             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     168          70 :                             ri_data%bsizes_AO_split, [1, 2], [3], name="(AO RI | AO)")
     169             : 
     170        1644 :       DO i_img = 2, nimg
     171        1644 :          CALL dbt_create(ri_data%t_3c_int_ctr_1(1, 1), ri_data%t_3c_int_ctr_1(1, i_img))
     172             :       END DO
     173          70 :       DEALLOCATE (dist1, dist2, dist3)
     174             : 
     175         770 :       ALLOCATE (ri_data%t_3c_int_ctr_2(1, 1))
     176             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_2(1, 1), dist1, dist2, dist3, &
     177             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     178          70 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
     179          70 :       DEALLOCATE (dist1, dist2, dist3)
     180             : 
     181             :       !We use full block sizes for the 2c quantities
     182          70 :       DEALLOCATE (bsizes_RI_ext)
     183          70 :       nblks_RI = SIZE(ri_data%bsizes_RI)
     184         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     185         506 :       DO i_RI = 1, ri_data%ncell_RI
     186        1378 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
     187             :       END DO
     188             : 
     189        3010 :       ALLOCATE (ri_data%t_2c_inv(1, natom), ri_data%t_2c_int(1, natom), ri_data%t_2c_pot(1, natom))
     190             :       CALL create_2c_tensor(ri_data%t_2c_inv(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     191             :                             bsizes_RI_ext, bsizes_RI_ext, &
     192          70 :                             name="(RI | RI)")
     193          70 :       DEALLOCATE (dist1, dist2)
     194          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, 1))
     195          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, 1))
     196         140 :       DO iatom = 2, natom
     197          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_inv(1, iatom))
     198          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, iatom))
     199         140 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, iatom))
     200             :       END DO
     201             : 
     202         350 :       ALLOCATE (ri_data%kp_cost(natom, natom, nimg))
     203       11578 :       ri_data%kp_cost = 0.0_dp
     204             : 
     205             :       !We store the density and KS matrix in tensor format
     206          70 :       nspins = dft_control%nspins
     207        8930 :       ALLOCATE (ri_data%rho_ao_t(nspins, nimg), ri_data%ks_t(nspins, nimg))
     208             :       CALL create_2c_tensor(ri_data%rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     209             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     210          70 :                             name="(AO | AO)")
     211          70 :       DEALLOCATE (dist1, dist2)
     212             : 
     213          70 :       CALL dbt_create(dbcsr_template, ri_data%ks_t(1, 1))
     214             : 
     215          70 :       IF (nspins == 2) THEN
     216          24 :          CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(2, 1))
     217          24 :          CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(2, 1))
     218             :       END IF
     219        1644 :       DO i_img = 2, nimg
     220        3566 :          DO i_spin = 1, nspins
     221        1922 :             CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(i_spin, i_img))
     222        3496 :             CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(i_spin, i_img))
     223             :          END DO
     224             :       END DO
     225             : 
     226         210 :    END SUBROUTINE adapt_ri_data_to_kp
     227             : 
     228             : ! **************************************************************************************************
     229             : !> \brief The pre-scf steps for RI-HFX k-points calculation. Namely the calculation of the integrals
     230             : !> \param dbcsr_template ...
     231             : !> \param ri_data ...
     232             : !> \param qs_env ...
     233             : ! **************************************************************************************************
     234          70 :    SUBROUTINE hfx_ri_pre_scf_kp(dbcsr_template, ri_data, qs_env)
     235             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     236             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     237             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     238             : 
     239             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'hfx_ri_pre_scf_kp'
     240             : 
     241             :       INTEGER                                            :: handle, i_img, iatom, natom, nimg, nkind
     242          70 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: t_2c_op_pot, t_2c_op_RI
     243          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: t_3c_int
     244             :       TYPE(dft_control_type), POINTER                    :: dft_control
     245             : 
     246          70 :       NULLIFY (dft_control)
     247             : 
     248          70 :       CALL timeset(routineN, handle)
     249             : 
     250          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, nkind=nkind)
     251             : 
     252          70 :       CALL cleanup_kp(ri_data)
     253             : 
     254             :       !We do all the checks on what we allow in this initial implementation
     255          70 :       IF (ri_data%flavor .NE. ri_pmat) CPABORT("K-points RI-HFX only with RHO flavor")
     256          70 :       IF (ri_data%same_op) ri_data%same_op = .FALSE. !force the full calculation with RI metric
     257          70 :       IF (ABS(ri_data%eps_pgf_orb - dft_control%qs_control%eps_pgf_orb) > 1.0E-16_dp) &
     258           0 :          CPABORT("RI%EPS_PGF_ORB and QS%EPS_PGF_ORB must be identical for RI-HFX k-points")
     259             : 
     260          70 :       CALL get_kp_and_ri_images(ri_data, qs_env)
     261          70 :       nimg = ri_data%nimg
     262             : 
     263             :       !Calculate the integrals
     264        3568 :       ALLOCATE (t_2c_op_pot(nimg), t_2c_op_RI(nimg))
     265        4058 :       ALLOCATE (t_3c_int(1, nimg))
     266          70 :       CALL hfx_ri_pre_scf_calc_tensors(qs_env, ri_data, t_2c_op_RI, t_2c_op_pot, t_3c_int, do_kpoints=.TRUE.)
     267             : 
     268             :       !Make sure the internals have the k-point format
     269          70 :       CALL adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     270             : 
     271             :       !For each atom i, we calculate the inverse RI metric (P^0 | Q^0)^-1 without external bumping yet
     272             :       !Also store the off-diagonal integrals of the RI metric in case of forces, bumped from the left
     273         210 :       DO iatom = 1, natom
     274             :          CALL get_ext_2c_int(ri_data%t_2c_inv(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     275         140 :                              do_inverse=.TRUE.)
     276             :          !for the forces:
     277             :          !off-diagonal RI metric bumped from the left
     278             :          CALL get_ext_2c_int(ri_data%t_2c_int(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, &
     279         140 :                              qs_env, off_diagonal=.TRUE.)
     280         140 :          CALL apply_bump(ri_data%t_2c_int(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     281             : 
     282             :          !RI metric with bumped off-diagonal blocks (but not inverted), debumped from left and right
     283             :          CALL get_ext_2c_int(ri_data%t_2c_pot(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     284         140 :                              do_inverse=.TRUE., skip_inverse=.TRUE.)
     285             :          CALL apply_bump(ri_data%t_2c_pot(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., &
     286         210 :                          from_right=.TRUE., debump=.TRUE.)
     287             : 
     288             :       END DO
     289             : 
     290        1714 :       DO i_img = 1, nimg
     291        1714 :          CALL dbcsr_release(t_2c_op_RI(i_img))
     292             :       END DO
     293             : 
     294        3428 :       ALLOCATE (ri_data%kp_mat_2c_pot(1, nimg))
     295        1714 :       DO i_img = 1, nimg
     296        1644 :          CALL dbcsr_create(ri_data%kp_mat_2c_pot(1, i_img), template=t_2c_op_pot(i_img))
     297        1644 :          CALL dbcsr_copy(ri_data%kp_mat_2c_pot(1, i_img), t_2c_op_pot(i_img))
     298        1714 :          CALL dbcsr_release(t_2c_op_pot(i_img))
     299             :       END DO
     300             : 
     301             :       !Pre-contract all 3c integrals with the bumped inverse RI metric (P^0|Q^0)^-1,
     302             :       !and store in ri_data%t_3c_int_ctr_1
     303          70 :       CALL precontract_3c_ints(t_3c_int, ri_data, qs_env)
     304             : 
     305             :       !reorder the 3c integrals such that empty images are bunched up together
     306          70 :       CALL reorder_3c_ints(ri_data%t_3c_int_ctr_1(1, :), ri_data)
     307             : 
     308          70 :       CALL timestop(handle)
     309             : 
     310        1784 :    END SUBROUTINE hfx_ri_pre_scf_kp
     311             : 
     312             : ! **************************************************************************************************
     313             : !> \brief clean-up the KP specific data from ri_data
     314             : !> \param ri_data ...
     315             : ! **************************************************************************************************
     316          70 :    SUBROUTINE cleanup_kp(ri_data)
     317             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     318             : 
     319             :       INTEGER                                            :: i, j
     320             : 
     321          70 :       IF (ALLOCATED(ri_data%kp_cost)) DEALLOCATE (ri_data%kp_cost)
     322          70 :       IF (ALLOCATED(ri_data%idx_to_img)) DEALLOCATE (ri_data%idx_to_img)
     323          70 :       IF (ALLOCATED(ri_data%img_to_idx)) DEALLOCATE (ri_data%img_to_idx)
     324          70 :       IF (ALLOCATED(ri_data%present_images)) DEALLOCATE (ri_data%present_images)
     325          70 :       IF (ALLOCATED(ri_data%img_to_RI_cell)) DEALLOCATE (ri_data%img_to_RI_cell)
     326          70 :       IF (ALLOCATED(ri_data%RI_cell_to_img)) DEALLOCATE (ri_data%RI_cell_to_img)
     327             : 
     328          70 :       IF (ALLOCATED(ri_data%kp_mat_2c_pot)) THEN
     329         540 :          DO j = 1, SIZE(ri_data%kp_mat_2c_pot, 2)
     330        1060 :             DO i = 1, SIZE(ri_data%kp_mat_2c_pot, 1)
     331        1040 :                CALL dbcsr_release(ri_data%kp_mat_2c_pot(i, j))
     332             :             END DO
     333             :          END DO
     334          20 :          DEALLOCATE (ri_data%kp_mat_2c_pot)
     335             :       END IF
     336             : 
     337          70 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN
     338         540 :          DO i = 1, SIZE(ri_data%kp_t_3c_int)
     339         540 :             CALL dbt_destroy(ri_data%kp_t_3c_int(i))
     340             :          END DO
     341         540 :          DEALLOCATE (ri_data%kp_t_3c_int)
     342             :       END IF
     343             : 
     344          70 :       IF (ALLOCATED(ri_data%t_2c_inv)) THEN
     345         160 :          DO j = 1, SIZE(ri_data%t_2c_inv, 2)
     346         250 :             DO i = 1, SIZE(ri_data%t_2c_inv, 1)
     347         180 :                CALL dbt_destroy(ri_data%t_2c_inv(i, j))
     348             :             END DO
     349             :          END DO
     350         160 :          DEALLOCATE (ri_data%t_2c_inv)
     351             :       END IF
     352             : 
     353          70 :       IF (ALLOCATED(ri_data%t_2c_int)) THEN
     354         160 :          DO j = 1, SIZE(ri_data%t_2c_int, 2)
     355         250 :             DO i = 1, SIZE(ri_data%t_2c_int, 1)
     356         180 :                CALL dbt_destroy(ri_data%t_2c_int(i, j))
     357             :             END DO
     358             :          END DO
     359         160 :          DEALLOCATE (ri_data%t_2c_int)
     360             :       END IF
     361             : 
     362          70 :       IF (ALLOCATED(ri_data%t_2c_pot)) THEN
     363         160 :          DO j = 1, SIZE(ri_data%t_2c_pot, 2)
     364         250 :             DO i = 1, SIZE(ri_data%t_2c_pot, 1)
     365         180 :                CALL dbt_destroy(ri_data%t_2c_pot(i, j))
     366             :             END DO
     367             :          END DO
     368         160 :          DEALLOCATE (ri_data%t_2c_pot)
     369             :       END IF
     370             : 
     371          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_1)) THEN
     372         640 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_1, 2)
     373        1210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_1, 1)
     374        1140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_1(i, j))
     375             :             END DO
     376             :          END DO
     377         640 :          DEALLOCATE (ri_data%t_3c_int_ctr_1)
     378             :       END IF
     379             : 
     380          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_2)) THEN
     381         140 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_2, 2)
     382         210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_2, 1)
     383         140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_2(i, j))
     384             :             END DO
     385             :          END DO
     386         140 :          DEALLOCATE (ri_data%t_3c_int_ctr_2)
     387             :       END IF
     388             : 
     389          70 :       IF (ALLOCATED(ri_data%rho_ao_t)) THEN
     390         640 :          DO j = 1, SIZE(ri_data%rho_ao_t, 2)
     391        1428 :             DO i = 1, SIZE(ri_data%rho_ao_t, 1)
     392        1358 :                CALL dbt_destroy(ri_data%rho_ao_t(i, j))
     393             :             END DO
     394             :          END DO
     395         858 :          DEALLOCATE (ri_data%rho_ao_t)
     396             :       END IF
     397             : 
     398          70 :       IF (ALLOCATED(ri_data%ks_t)) THEN
     399         640 :          DO j = 1, SIZE(ri_data%ks_t, 2)
     400        1428 :             DO i = 1, SIZE(ri_data%ks_t, 1)
     401        1358 :                CALL dbt_destroy(ri_data%ks_t(i, j))
     402             :             END DO
     403             :          END DO
     404         858 :          DEALLOCATE (ri_data%ks_t)
     405             :       END IF
     406             : 
     407          70 :    END SUBROUTINE cleanup_kp
     408             : 
     409             : ! **************************************************************************************************
     410             : !> \brief Update the KS matrices for each real-space image with the RI-HFX contribution and return the HFX energy
     411             : !> \param qs_env the QS environment (provides para_env, natom and the DFT%XC%HF%RI input section)
     412             : !> \param ri_data the RI-HFX data; its stored tensors, kp_cost, kp_stack_size, flop count and timing are updated
     413             : !> \param ks_matrix KS matrices indexed by (spin, image); the HFX contribution is added to each of them
     414             : !> \param ehfx on exit, the sum over images and spins of 0.5*dbcsr_dot(K, P)
     415             : !> \param rho_ao AO density matrices, passed to get_pmat_images to build the per-image density tensors
     416             : !> \param geometry_did_change if .TRUE., hfx_ri_pre_scf_kp is called before the KS update
     417             : !> \param nspins number of spin channels; the prefactor is 0.5*hf_fraction for nspins == 1, hf_fraction otherwise
     418             : !> \param hf_fraction fraction of HFX that scales the contribution added to the KS matrices
     419             : ! **************************************************************************************************
     420         214 :    SUBROUTINE hfx_ri_update_ks_kp(qs_env, ri_data, ks_matrix, ehfx, rho_ao, &
     421             :                                   geometry_did_change, nspins, hf_fraction)
     422             : 
     423             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     424             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     425             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     426             :       REAL(KIND=dp), INTENT(OUT)                         :: ehfx
     427             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     428             :       LOGICAL, INTENT(IN)                                :: geometry_did_change
     429             :       INTEGER, INTENT(IN)                                :: nspins
     430             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     431             : 
     432             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_ks_kp'
     433             : 
     434             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_spin, iatom, &
     435             :          iblk, igroup, jatom, mb_img, n_batch_nze, natom, ngroups, nimg, nimg_nze
     436             :       INTEGER(int_8)                                     :: nflop, nze
     437         214 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_at, batch_ranges_nze, &
     438         214 :                                                             idx_to_at_AO
     439         214 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     440         214 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: sparsity_pattern
     441             :       LOGICAL                                            :: use_delta_p
     442             :       REAL(dp)                                           :: etmp, fac, occ, pfac, pref, t1, t2, t3, &
     443             :                                                             t4
     444             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     445             :       TYPE(dbcsr_type)                                   :: ks_desymm, rho_desymm, tmp
     446         214 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     447             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     448         214 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ks_t_split, t_2c_ao_tmp, t_2c_work, &
     449         214 :                                                             t_3c_int, t_3c_work_2, t_3c_work_3
     450         214 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: ks_t, ks_t_sub, t_3c_apc, t_3c_apc_sub
     451             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     452             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     453             : 
     454         214 :       NULLIFY (para_env, para_env_sub, blacs_env_sub, hfx_section, dbcsr_template)
     455             : 
     456         214 :       CALL cite_reference(Bussy2024)
     457             : 
     458         214 :       CALL timeset(routineN, handle)
     459             : 
     460         214 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natom)
     461             : 
     462         214 :       IF (nspins == 1) THEN
     463         130 :          fac = 0.5_dp*hf_fraction
     464             :       ELSE
     465          84 :          fac = 1.0_dp*hf_fraction
     466             :       END IF
     467             : 
     468         214 :       IF (geometry_did_change) THEN
     469          70 :          CALL hfx_ri_pre_scf_kp(ks_matrix(1, 1)%matrix, ri_data, qs_env)
     470             :       END IF
     471         214 :       nimg = ri_data%nimg
     472         214 :       nimg_nze = ri_data%nimg_nze
     473             : 
     474             :       !We need to calculate the KS matrix for each periodic cell with index b: F_mu^0,nu^b
     475             :       !F_mu^0,nu^b = -0.5 sum_a,c P_sigma^0,lambda^c (mu^0, sigma^a| P^0) V_P^0,Q^b (Q^b| nu^b lambda^a+c)
     476             :       !with V_P^0,Q^b = (P^0|R^0)^-1 * (R^0|S^b) * (S^b|Q^b)^-1
     477             : 
     478             :       !We use a local RI basis set for each atom in the system, which includes RI basis elements for
     479             :       !each neighboring atom standing within the KIND radius (decay of Gaussian with smallest exponent)
     480             : 
     481             :       !We also limit the number of periodic images we consider according to the HFX potential in the
     482             :       !RI basis, because if V_P^0,Q^b is zero everywhere, then image b can be ignored (RI basis less diffuse)
     483             : 
     484             :       !We manage to calculate each KS matrix doing a double loop on images, and a double loop on atoms
     485             :       !First, we pre-contract and store P_sigma^0,lambda^c (mu^0, sigma^a| P^0) (P^0|R^0)^-1 into T_mu^0,lambda^a+c,P^0
     486             :       !Then, we loop over b_img, iatom, jatom to get (R^0|S^b)
     487             :       !Finally, we do an additional loop over a+c images where we do (R^0|S^b) (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c)
     488             :       !and the final contraction with T_mu^0,lambda^a+c,P^0
     489             : 
     490             :       !Note that the 3-center integrals are pre-contracted with the RI metric, and that the same tensor can be used
     491             :       !(mu^0, sigma^a| P^0) (P^0|R^0)  <===> (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c) by relabelling the images
     492             : 
     493         214 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     494         214 :       CALL section_vals_val_get(hfx_section, "KP_USE_DELTA_P", l_val=use_delta_p)
     495             : 
     496             :       !By default, build the density tensor based on the difference of this SCF P and that of the prev. SCF step
     497         214 :       pfac = -1.0_dp
     498         214 :       IF (.NOT. use_delta_p) pfac = 0.0_dp
     499         214 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, pfac, ri_data, qs_env)
     500             : 
     501       14740 :       ALLOCATE (ks_t(nspins, nimg))
     502        5514 :       DO i_img = 1, nimg
     503       12386 :          DO i_spin = 1, nspins
     504       12172 :             CALL dbt_create(ri_data%ks_t(1, 1), ks_t(i_spin, i_img))
     505             :          END DO
     506             :       END DO
     507             : 
     508         642 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     509         214 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     510             : 
     511             :       !First we calculate and store T^1_mu^0,lambda^a+c,P = P_mu^0,lambda^c * (mu_0 sigma^a | P^0) (P^0|R^0)^-1
     512             :       !To avoid doing nimg**2 tiny contractions that do not scale well with a large number of CPUs,
     513             :       !we instead do a single loop over the a+c image index. For each a+c, we get a list of allowed
     514             :       !combination of a,c indices. Then we build TAS tensors P_mu^0,lambda^c with all concerned c's
     515             :       !and (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 with all a's. Then we perform a single contraction with larger tensors,
     516             :       !where the sum over a,c is automatically taken care of
     517       14526 :       ALLOCATE (t_3c_apc(nspins, nimg))
     518        5514 :       DO i_img = 1, nimg
     519       12386 :          DO i_spin = 1, nspins
     520       12172 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     521             :          END DO
     522             :       END DO
     523         214 :       CALL contract_pmat_3c(t_3c_apc, ri_data%rho_ao_t, ri_data, qs_env)
     524             : 
     525         214 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     526         214 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     527         214 :       CALL section_vals_val_get(hfx_section, "KP_STACK_SIZE", i_val=batch_size)
     528         214 :       ri_data%kp_stack_size = batch_size
     529             : 
     530         214 :       IF (MOD(para_env%num_pe, ngroups) .NE. 0) THEN
     531           0 :          CPWARN("KP_NGROUPS must be an integer divisor of the total number of MPI ranks. It was set to 1.")
     532           0 :          ngroups = 1
     533           0 :          CALL section_vals_val_set(hfx_section, "KP_NGROUPS", i_val=ngroups)
     534             :       END IF
     535         214 :       IF ((MOD(ngroups, natom) .NE. 0) .AND. (MOD(natom, ngroups) .NE. 0) .AND. geometry_did_change) THEN
     536           0 :          IF (ngroups > 1) &
     537           0 :             CPWARN("Better load balancing is reached if NGROUPS is a multiple/divisor of the number of atoms")
     538             :       END IF
     539         214 :       group_size = para_env%num_pe/ngroups
     540         214 :       igroup = para_env%mepos/group_size
     541             : 
     542         214 :       ALLOCATE (para_env_sub)
     543         214 :       CALL para_env_sub%from_split(para_env, igroup)
     544         214 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     545             : 
     546             :       ! The sparsity pattern of each iatom, jatom pair, on each b_img, and the subgroup index that handles it
     547        1070 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     548         214 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     549         214 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     550             : 
     551             :       !Get all the required tensors in the subgroups
     552       26674 :       ALLOCATE (mat_2c_pot(nimg), ks_t_sub(nspins, nimg), t_2c_ao_tmp(1), ks_t_split(2), t_2c_work(3))
     553             :       CALL get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
     554         214 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     555             : 
     556       26888 :       ALLOCATE (t_3c_int(nimg), t_3c_apc_sub(nspins, nimg), t_3c_work_2(3), t_3c_work_3(3))
     557             :       CALL get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
     558         214 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     559             : 
     560             :       !We go atom by atom, therefore there is an automatic batching along that direction
     561             :       !Also, because we stack the 3c tensors nimg times, we naturally do some batching there too
     562         642 :       ALLOCATE (batch_ranges_at(natom + 1))
     563         214 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     564         214 :       iatom = 0
     565        1042 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     566        1042 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     567         428 :             iatom = iatom + 1
     568         428 :             batch_ranges_at(iatom) = iblk
     569             :          END IF
     570             :       END DO
     571             : 
     572         214 :       n_batch_nze = nimg_nze/batch_size
     573         214 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
     574         642 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
     575         452 :       DO i_batch = 1, n_batch_nze
     576         452 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     577             :       END DO
     578         214 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
     579             : 
     580         214 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     581         214 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     582         214 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     583         214 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     584             : 
     585         214 :       t1 = m_walltime()
     586       37314 :       ri_data%kp_cost(:, :, :) = 0.0_dp
     587         642 :       ALLOCATE (iapc_pairs(nimg, 2))
     588        5514 :       DO b_img = 1, nimg
     589        5300 :          CALL dbt_batched_contract_init(ks_t_split(1))
     590        5300 :          CALL dbt_batched_contract_init(ks_t_split(2))
     591       15900 :          DO jatom = 1, natom
     592       37100 :             DO iatom = 1, natom
     593       21200 :                IF (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup) CYCLE
     594        3606 :                pref = 1.0_dp
     595        3606 :                IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp
     596             : 
     597             :                !measure the cost (wall time) of the given i, j, b configuration, stored in ri_data%kp_cost
     598        3606 :                t3 = m_walltime()
     599             : 
     600             :                !Get the proper HFX potential 2c integrals (R_i^0|S_j^b)
     601        3606 :                CALL timeset(routineN//"_2c", handle2)
     602             :                CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     603             :                                    blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     604        3606 :                                    dbcsr_template=dbcsr_template)
     605        3606 :                CALL dbt_copy(t_2c_work(1), t_2c_work(2), move_data=.TRUE.) !move to split blocks
     606        3606 :                CALL dbt_filter(t_2c_work(2), ri_data%filter_eps)
     607        3606 :                CALL timestop(handle2)
     608             : 
     609        3606 :                CALL dbt_batched_contract_init(t_2c_work(2))
     610        3606 :                CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env)
     611        3606 :                CALL timeset(routineN//"_3c", handle2)
     612             : 
     613             :                !Stack the (S^b|Q^b)^-1 * (Q^b| nu^b lambda^a+c) integrals over a+c and multiply by (R_i^0|S_j^b)
     614        8079 :                DO i_batch = 1, n_batch_nze
     615             :                   CALL fill_3c_stack(t_3c_work_3(3), t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
     616             :                                      filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
     617       13419 :                                      img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     618        4473 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1), move_data=.TRUE.)
     619             : 
     620             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_3c_work_3(1), &
     621             :                                     0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
     622             :                                     contract_1=[2], notcontract_1=[1], &
     623             :                                     contract_2=[1], notcontract_2=[2, 3], &
     624        4473 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     625        4473 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     626        4473 :                   CALL dbt_copy(t_3c_work_3(2), t_3c_work_2(2), order=[2, 1, 3], move_data=.TRUE.)
     627        4473 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1)) !source was moved above; presumably clears (1) - TODO confirm
     628             : 
     629             :                   !Stack the P_sigma^a,lambda^a+c * (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 integrals over a+c and contract
     630             :                   !to get the final block of the KS matrix
     631       13750 :                   DO i_spin = 1, nspins
     632             :                      CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
     633             :                                         ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
     634       17013 :                                         img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     635        5671 :                      CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
     636        5671 :                      IF (nze == 0) CYCLE
     637        5651 :                      CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
     638             :                      CALL dbt_contract(-pref*fac, t_3c_work_2(1), t_3c_work_2(2), &
     639             :                                        1.0_dp, ks_t_split(i_spin), map_1=[1], map_2=[2], &
     640             :                                        contract_1=[2, 3], notcontract_1=[1], &
     641             :                                        contract_2=[2, 3], notcontract_2=[1], &
     642             :                                        filter_eps=ri_data%filter_eps, &
     643        5651 :                                        move_data=i_spin == nspins, flop=nflop)
     644       15795 :                      ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     645             :                   END DO
     646             :                END DO !i_batch
     647        3606 :                CALL timestop(handle2)
     648        3606 :                CALL dbt_batched_contract_finalize(t_2c_work(2))
     649             : 
     650        3606 :                t4 = m_walltime()
     651       39012 :                ri_data%kp_cost(iatom, jatom, b_img) = t4 - t3
     652             :             END DO !iatom
     653             :          END DO !jatom
     654        5300 :          CALL dbt_batched_contract_finalize(ks_t_split(1))
     655        5300 :          CALL dbt_batched_contract_finalize(ks_t_split(2))
     656             : 
     657       12386 :          DO i_spin = 1, nspins
     658        6872 :             CALL dbt_copy(ks_t_split(i_spin), t_2c_ao_tmp(1), move_data=.TRUE.)
     659       12172 :             CALL dbt_copy(t_2c_ao_tmp(1), ks_t_sub(i_spin, b_img), summation=.TRUE.)
     660             :          END DO
     661             :       END DO !b_img
     662         214 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
     663         214 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
     664         214 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
     665         214 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
     666         214 :       CALL para_env%sync()
     667         214 :       CALL para_env%sum(ri_data%dbcsr_nflop)
     668         214 :       CALL para_env%sum(ri_data%kp_cost)
     669         214 :       t2 = m_walltime()
     670         214 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
     671             : 
     672             :       !transfer KS tensor from subgroup to main group
     673         214 :       CALL gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
     674             : 
     675             :       !Keep the 3c integrals on the subgroups to avoid communication at next SCF step
     676        5514 :       DO i_img = 1, nimg
     677        5514 :          CALL dbt_copy(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img), move_data=.TRUE.)
     678             :       END DO
     679             : 
     680             :       !clean-up subgroup tensors
     681         214 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     682         214 :       CALL dbt_destroy(ks_t_split(1))
     683         214 :       CALL dbt_destroy(ks_t_split(2))
     684         214 :       CALL dbt_destroy(t_2c_work(1))
     685         214 :       CALL dbt_destroy(t_2c_work(2))
     686         214 :       CALL dbt_destroy(t_3c_work_2(1))
     687         214 :       CALL dbt_destroy(t_3c_work_2(2))
     688         214 :       CALL dbt_destroy(t_3c_work_2(3))
     689         214 :       CALL dbt_destroy(t_3c_work_3(1))
     690         214 :       CALL dbt_destroy(t_3c_work_3(2))
     691         214 :       CALL dbt_destroy(t_3c_work_3(3))
     692        5514 :       DO i_img = 1, nimg
     693        5300 :          CALL dbt_destroy(t_3c_int(i_img))
     694        5300 :          CALL dbcsr_release(mat_2c_pot(i_img))
     695       12386 :          DO i_spin = 1, nspins
     696        6872 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
     697       12172 :             CALL dbt_destroy(ks_t_sub(i_spin, i_img))
     698             :          END DO
     699             :       END DO
     700         214 :       IF (ASSOCIATED(dbcsr_template)) THEN
     701         214 :          CALL dbcsr_release(dbcsr_template)
     702         214 :          DEALLOCATE (dbcsr_template)
     703             :       END IF
     704             : 
     705             :       !End of subgroup parallelization
     706         214 :       CALL cp_blacs_env_release(blacs_env_sub)
     707         214 :       CALL para_env_sub%free()
     708         214 :       DEALLOCATE (para_env_sub)
     709             : 
     710             :       !Currently, rho_ao_t holds the density difference (wrt. the previous SCF step).
     711             :       !ks_t also holds that diff, while only having half the blocks => need to add to prev ks_t and symmetrize
     712             :       !We need the full thing for the energy, on the next SCF step
     713         214 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     714         512 :       DO i_spin = 1, nspins
     715        7384 :          DO b_img = 1, nimg
     716        6872 :             CALL dbt_copy(ks_t(i_spin, b_img), ri_data%ks_t(i_spin, b_img), summation=.TRUE.)
     717             : 
     718             :             !desymmetrize: add the index-swapped (order=[2, 1]) KS tensor of the opposite image, if it lies within nimg
     719        6872 :             mb_img = get_opp_index(b_img, qs_env)
     720        7170 :             IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
     721        6194 :                CALL dbt_copy(ks_t(i_spin, mb_img), ri_data%ks_t(i_spin, b_img), order=[2, 1], summation=.TRUE.)
     722             :             END IF
     723             :          END DO
     724             :       END DO
     725        5514 :       DO b_img = 1, nimg
     726       12386 :          DO i_spin = 1, nspins
     727       12172 :             CALL dbt_destroy(ks_t(i_spin, b_img))
     728             :          END DO
     729             :       END DO
     730             : 
     731             :       !calculate the energy and add the HFX contribution to the KS matrix of each spin and image
     732         214 :       CALL dbt_create(ri_data%ks_t(1, 1), t_2c_ao_tmp(1))
     733         214 :       CALL dbcsr_create(tmp, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_symmetric)
     734         214 :       CALL dbcsr_create(ks_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     735         214 :       CALL dbcsr_create(rho_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     736         214 :       ehfx = 0.0_dp
     737        5514 :       DO i_img = 1, nimg
     738       12386 :          DO i_spin = 1, nspins
     739        6872 :             CALL dbt_filter(ri_data%ks_t(i_spin, i_img), ri_data%filter_eps)
     740        6872 :             CALL dbt_copy(ri_data%ks_t(i_spin, i_img), t_2c_ao_tmp(1))
     741        6872 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), ks_desymm)
     742        6872 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), tmp)
     743        6872 :             CALL dbcsr_add(ks_matrix(i_spin, i_img)%matrix, tmp, 1.0_dp, 1.0_dp)
     744             : 
     745        6872 :             CALL dbt_copy(ri_data%rho_ao_t(i_spin, i_img), t_2c_ao_tmp(1))
     746        6872 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), rho_desymm)
     747             : 
     748        6872 :             CALL dbcsr_dot(ks_desymm, rho_desymm, etmp)
     749        6872 :             ehfx = ehfx + 0.5_dp*etmp
     750             : 
     751       12172 :             IF (.NOT. use_delta_p) CALL dbt_clear(ri_data%ks_t(i_spin, i_img)) !no delta-P: stored ks_t restarts from zero
     752             :          END DO
     753             :       END DO
     754         214 :       CALL dbcsr_release(rho_desymm)
     755         214 :       CALL dbcsr_release(ks_desymm)
     756         214 :       CALL dbcsr_release(tmp)
     757         214 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     758             : 
     759         214 :       CALL timestop(handle)
     760             : 
     761       36426 :    END SUBROUTINE hfx_ri_update_ks_kp
     762             : 
     763             : ! **************************************************************************************************
     764             : !> \brief Update the K-points RI-HFX forces
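     765             : !> \param qs_env the QS environment (provides para_env, force, cell, virial, atomic kinds and particles)
     766             : !> \param ri_data the RI-HFX data (number of images, block sizes, process grid, stored tensors)
     767             : !> \param nspins number of spin channels; the prefactor is 0.5*hf_fraction for nspins == 1, hf_fraction otherwise
     768             : !> \param hf_fraction fraction of HFX entering the prefactor
     769             : !> \param rho_ao AO density matrices from which the per-image density tensors are built
     770             : !> \param use_virial optional flag, treated as .FALSE. when absent
     771             : !> \note Because this routine uses stored quantities calculated in the energy calculation, the energy
     772             : !>       and force routines should always be called in pairs, and with the same input densities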
     773             : ! **************************************************************************************************
     774          42 :    SUBROUTINE hfx_ri_update_forces_kp(qs_env, ri_data, nspins, hf_fraction, rho_ao, use_virial)
     775             : 
     776             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     777             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     778             :       INTEGER, INTENT(IN)                                :: nspins
     779             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     780             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     781             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_virial
     782             : 
     783             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_forces_kp'
     784             : 
     785             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_loop, i_spin, &
     786             :          i_xyz, iatom, iblk, igroup, j_xyz, jatom, k_xyz, n_batch, natom, ngroups, nimg, nimg_nze
     787             :       INTEGER(int_8)                                     :: nflop, nze
     788          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atom_of_kind, batch_ranges_at, &
     789          42 :                                                             batch_ranges_nze, dist1, dist2, &
     790          42 :                                                             i_images, idx_to_at_AO, idx_to_at_RI, &
     791          42 :                                                             kind_of
     792          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     793          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: force_pattern, sparsity_pattern
     794             :       INTEGER, DIMENSION(2, 1)                           :: bounds_iat, bounds_jat
     795             :       LOGICAL                                            :: use_virial_prv
     796             :       REAL(dp)                                           :: fac, occ, pref, t1, t2
     797             :       REAL(dp), DIMENSION(3, 3)                          :: work_virial
     798          42 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     799             :       TYPE(cell_type), POINTER                           :: cell
     800             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     801          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     802          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_pot, mat_der_pot_sub
     803             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     804         714 :       TYPE(dbt_type)                                     :: t_2c_R, t_2c_R_split
     805          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_2c_bint, t_2c_binv, t_2c_der_pot, &
     806          84 :                                                             t_2c_inv, t_2c_metric, t_2c_work, &
     807          42 :                                                             t_3c_der_stack, t_3c_work_2, &
     808          42 :                                                             t_3c_work_3
     809          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :) :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
     810          84 :          t_2c_der_metric_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_AO, t_3c_der_AO_sub, t_3c_der_RI, &
     811          42 :          t_3c_der_RI_sub
     812             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     813          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     814          42 :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
     815             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     816             :       TYPE(virial_type), POINTER                         :: virial
     817             : 
     818          42 :       NULLIFY (para_env, para_env_sub, hfx_section, blacs_env_sub, dbcsr_template, force, atomic_kind_set, &
     819          42 :                virial, particle_set, cell)
     820             : 
     821          42 :       CALL timeset(routineN, handle)
     822             : 
     823          42 :       use_virial_prv = .FALSE.
     824          42 :       IF (PRESENT(use_virial)) use_virial_prv = use_virial
     825             : 
     826          42 :       IF (nspins == 1) THEN
     827          28 :          fac = 0.5_dp*hf_fraction
     828             :       ELSE
     829          14 :          fac = 1.0_dp*hf_fraction
     830             :       END IF
     831             : 
     832             :       CALL get_qs_env(qs_env, natom=natom, para_env=para_env, force=force, cell=cell, virial=virial, &
     833          42 :                       atomic_kind_set=atomic_kind_set, particle_set=particle_set)
     834          42 :       CALL get_atomic_kind_set(atomic_kind_set, kind_of=kind_of, atom_of_kind=atom_of_kind)
     835             : 
     836         126 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     837          42 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     838             : 
     839         126 :       ALLOCATE (idx_to_at_RI(SIZE(ri_data%bsizes_RI_split)))
     840          42 :       CALL get_idx_to_atom(idx_to_at_RI, ri_data%bsizes_RI_split, ri_data%bsizes_RI)
     841             : 
     842          42 :       nimg = ri_data%nimg
     843       10542 :       ALLOCATE (t_3c_der_RI(nimg, 3), t_3c_der_AO(nimg, 3), mat_der_pot(nimg, 3), t_2c_der_metric(natom, 3))
     844             : 
     845             :       !We assume that the integrals are available from the SCF
     846             :       !pre-calculate the derivs. 3c tensors as (P^0| sigma^a mu^0), with t_3c_der_AO holding deriv wrt mu^0
     847          42 :       CALL precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
     848             : 
     849             :       !Calculate the density matrix at each image
     850        2546 :       ALLOCATE (rho_ao_t(nspins, nimg))
     851             :       CALL create_2c_tensor(rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     852             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     853          42 :                             name="(AO | AO)")
     854          42 :       DEALLOCATE (dist1, dist2)
     855          42 :       IF (nspins == 2) CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(2, 1))
     856         938 :       DO i_img = 2, nimg
     857        1986 :          DO i_spin = 1, nspins
     858        1944 :             CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(i_spin, i_img))
     859             :          END DO
     860             :       END DO
     861          42 :       CALL get_pmat_images(rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     862             : 
     863             :       !Contract integrals with the density matrix
     864        2546 :       ALLOCATE (t_3c_apc(nspins, nimg))
     865         980 :       DO i_img = 1, nimg
     866        2084 :          DO i_spin = 1, nspins
     867        2042 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     868             :          END DO
     869             :       END DO
     870          42 :       CALL contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
     871             : 
     872             :       !Setup the subgroups
     873          42 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     874          42 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     875          42 :       group_size = para_env%num_pe/ngroups
     876          42 :       igroup = para_env%mepos/group_size
     877             : 
     878          42 :       ALLOCATE (para_env_sub)
     879          42 :       CALL para_env_sub%from_split(para_env, igroup)
     880          42 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     881             : 
     882             :       !Get the usual sparsity pattern
     883         210 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     884          42 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     885          42 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     886             : 
     887             :       !Get the 2-center quantities in the subgroups (note: main group derivs are deleted within)
     888           0 :       ALLOCATE (t_2c_inv(natom), mat_2c_pot(nimg), rho_ao_t_sub(nspins, nimg), t_2c_work(5), &
     889           0 :                 t_2c_der_metric_sub(natom, 3), mat_der_pot_sub(nimg, 3), t_2c_bint(natom), &
     890        9868 :                 t_2c_metric(natom), t_2c_binv(natom))
     891             :       CALL get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
     892             :                                   rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
     893          42 :                                   mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
     894          42 :       CALL dbt_create(t_2c_work(1), t_2c_R) !nRI x nRI
     895          42 :       CALL dbt_create(t_2c_work(5), t_2c_R_split) !nRI x nRI with split blocks
     896             : 
     897         504 :       ALLOCATE (t_2c_der_pot(3))
     898         168 :       DO i_xyz = 1, 3
     899         168 :          CALL dbt_create(t_2c_R, t_2c_der_pot(i_xyz))
     900             :       END DO
     901             : 
     902             :       !Get the 3-center quantities in the subgroups. The integrals and t_3c_apc already there
     903           0 :       ALLOCATE (t_3c_work_2(3), t_3c_work_3(4), t_3c_der_stack(6), t_3c_der_AO_sub(nimg, 3), &
     904       10736 :                 t_3c_der_RI_sub(nimg, 3), t_3c_apc_sub(nspins, nimg))
     905             :       CALL get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
     906             :                                   t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_stack, &
     907          42 :                                   group_size, ngroups, para_env, para_env_sub, ri_data)
     908             : 
     909             :       !Set up batched contraction (go atom by atom)
     910         126 :       ALLOCATE (batch_ranges_at(natom + 1))
     911          42 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     912          42 :       iatom = 0
     913         212 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     914         212 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     915          84 :             iatom = iatom + 1
     916          84 :             batch_ranges_at(iatom) = iblk
     917             :          END IF
     918             :       END DO
     919             : 
     920          42 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     921          42 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     922          42 :       CALL dbt_batched_contract_init(t_3c_work_3(3), batch_range_2=batch_ranges_at)
     923          42 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     924          42 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     925             : 
     926             :       !Preparing for the stacking of 3c tensors
     927          42 :       nimg_nze = ri_data%nimg_nze
     928          42 :       batch_size = ri_data%kp_stack_size
     929          42 :       n_batch = nimg_nze/batch_size
     930          42 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch = n_batch + 1
     931         126 :       ALLOCATE (batch_ranges_nze(n_batch + 1))
     932          90 :       DO i_batch = 1, n_batch
     933          90 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     934             :       END DO
     935          42 :       batch_ranges_nze(n_batch + 1) = nimg_nze + 1
     936             : 
     937             :       !Applying the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 from left and right
     938             :       !And keep the bump on LHS only version as well, with B*M^-1 = (M^-1*B)^T
     939         126 :       DO iatom = 1, natom
     940          84 :          CALL dbt_create(t_2c_inv(iatom), t_2c_binv(iatom))
     941          84 :          CALL dbt_copy(t_2c_inv(iatom), t_2c_binv(iatom))
     942          84 :          CALL apply_bump(t_2c_binv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     943         126 :          CALL apply_bump(t_2c_inv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
     944             :       END DO
     945             : 
     946          42 :       t1 = m_walltime()
     947          42 :       work_virial = 0.0_dp
     948         210 :       ALLOCATE (iapc_pairs(nimg, 2), i_images(nimg))
     949         210 :       ALLOCATE (force_pattern(natom, natom, nimg))
     950        6608 :       force_pattern(:, :, :) = -1
     951             :       !We proceed with 2 loops: one over the sparsity pattern from the SCF, one over the rest
     952             :       !We use the SCF cost model for the first loop, while we calculate the cost of the upcoming loop
     953         126 :       DO i_loop = 1, 2
     954        1960 :          DO b_img = 1, nimg
     955        5712 :             DO jatom = 1, natom
     956       13132 :                DO iatom = 1, natom
     957             : 
     958        7504 :                   pref = -0.5_dp*fac
     959        7504 :                   IF (i_loop == 1 .AND. (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     960        4333 :                   IF (i_loop == 2 .AND. (.NOT. force_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     961             : 
     962             :                   !Get the proper HFX potential 2c integrals (R_i^0|S_j^b), times (S_j^b|Q_j^b)^-1
     963        1102 :                   CALL timeset(routineN//"_2c_1", handle2)
     964             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     965             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     966        1102 :                                       dbcsr_template=dbcsr_template)
     967             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_inv(jatom), &
     968             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
     969             :                                     contract_1=[2], notcontract_1=[1], &
     970             :                                     contract_2=[1], notcontract_2=[2], &
     971        1102 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     972        1102 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(5), move_data=.TRUE.) !move to split blocks
     973        1102 :                   CALL dbt_filter(t_2c_work(5), ri_data%filter_eps)
     974        1102 :                   CALL timestop(handle2)
     975             : 
     976        1102 :                   CALL timeset(routineN//"_3c", handle2)
     977        5504 :                   bounds_iat(:, 1) = [SUM(ri_data%bsizes_AO(1:iatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:iatom))]
     978        5468 :                   bounds_jat(:, 1) = [SUM(ri_data%bsizes_AO(1:jatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:jatom))]
     979        1102 :                   CALL dbt_clear(t_2c_R_split)
     980             : 
     981        2443 :                   DO i_spin = 1, nspins
     982        2443 :                      CALL dbt_batched_contract_init(rho_ao_t_sub(i_spin, b_img))
     983             :                   END DO
     984             : 
     985        1102 :                   CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env, i_images) !i = a+c-b
     986        2499 :                   DO i_batch = 1, n_batch
     987             : 
     988             :                      !Stack the 3c derivatives to take the trace later on
     989        5588 :                      DO i_xyz = 1, 3
     990        4191 :                         CALL dbt_clear(t_3c_der_stack(i_xyz))
     991             :                         CALL fill_3c_stack(t_3c_der_stack(i_xyz), t_3c_der_RI_sub(:, i_xyz), &
     992             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
     993             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
     994       12573 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     995             : 
     996        4191 :                         CALL dbt_clear(t_3c_der_stack(3 + i_xyz))
     997             :                         CALL fill_3c_stack(t_3c_der_stack(3 + i_xyz), t_3c_der_AO_sub(:, i_xyz), &
     998             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
     999             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
    1000       13970 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1001             :                      END DO
    1002             : 
    1003        4135 :                      DO i_spin = 1, nspins
    1004             :                         !stack the t_3c_apc tensors
    1005        1636 :                         CALL dbt_clear(t_3c_work_2(3))
    1006             :                         CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
    1007             :                                            ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
    1008        4908 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1009        1636 :                         CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
    1010        1636 :                         IF (nze == 0) CYCLE
    1011        1636 :                         CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
    1012             : 
    1013             :                         !Contract with the second density matrix: P_mu^0,nu^b * t_3c_apc,
    1014             :                         !where t_3c_apc = P_sigma^a,lambda^a+c (mu^0 P^0 sigma^a) *(P^0|R^0)^-1 (stacked along a+c)
    1015             :                         CALL dbt_contract(1.0_dp, rho_ao_t_sub(i_spin, b_img), t_3c_work_2(1), &
    1016             :                                           0.0_dp, t_3c_work_2(2), map_1=[1], map_2=[2, 3], &
    1017             :                                           contract_1=[1], notcontract_1=[2], &
    1018             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1019             :                                           bounds_1=bounds_iat, bounds_2=bounds_jat, &
    1020        1636 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1021        1636 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1022             : 
    1023        1636 :                         CALL get_tensor_occupancy(t_3c_work_2(2), nze, occ)
    1024        1636 :                         IF (nze == 0) CYCLE
    1025             : 
    1026             :                         !Contract with V_PQ so that we can take the trace with (Q^b|nu^b lambda^a+c)^(x)
    1027        1525 :                         CALL dbt_copy(t_3c_work_2(2), t_3c_work_3(1), order=[2, 1, 3], move_data=.TRUE.)
    1028        1525 :                         CALL dbt_batched_contract_init(t_2c_work(5))
    1029             :                         CALL dbt_contract(1.0_dp, t_2c_work(5), t_3c_work_3(1), &
    1030             :                                           0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
    1031             :                                           contract_1=[1], notcontract_1=[2], &
    1032             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1033        1525 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1034        1525 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1035        1525 :                         CALL dbt_batched_contract_finalize(t_2c_work(5))
    1036             : 
    1037             :                         !Contract with the 3c derivatives to get the force/virial
    1038        1525 :                         CALL dbt_copy(t_3c_work_3(2), t_3c_work_3(4), move_data=.TRUE.)
    1039        1525 :                         IF (use_virial_prv) THEN
    1040             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1041             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1042             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1043             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1044         257 :                                                         ri_data, qs_env, work_virial, cell, particle_set)
    1045             :                         ELSE
    1046             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1047             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1048             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1049             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1050        1268 :                                                         ri_data, qs_env)
    1051             :                         END IF
    1052        1525 :                         CALL dbt_clear(t_3c_work_3(4))
    1053             : 
    1054             :                         !Contract with the 3-center integrals in order to have a matrix R_PQ such that
    1055             :                         !we can take the trace sum_PQ R_PQ (P^0|Q^b)^(x)
    1056        1525 :                         IF (i_loop == 2) CYCLE
    1057             : 
    1058             :                         !Stack the 3c integrals
    1059             :                         CALL fill_3c_stack(t_3c_work_3(4), ri_data%kp_t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
    1060             :                                            filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
    1061        2415 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1062         805 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(3), move_data=.TRUE.)
    1063             : 
    1064         805 :                         CALL dbt_batched_contract_init(t_2c_R_split)
    1065             :                         CALL dbt_contract(1.0_dp, t_3c_work_3(1), t_3c_work_3(3), &
    1066             :                                           1.0_dp, t_2c_R_split, map_1=[1], map_2=[2], &
    1067             :                                           contract_1=[2, 3], notcontract_1=[1], &
    1068             :                                           contract_2=[2, 3], notcontract_2=[1], &
    1069         805 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1070         805 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1071         805 :                         CALL dbt_batched_contract_finalize(t_2c_R_split)
    1072        5474 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(1))
    1073             :                      END DO
    1074             :                   END DO
    1075        2443 :                   DO i_spin = 1, nspins
    1076        2443 :                      CALL dbt_batched_contract_finalize(rho_ao_t_sub(i_spin, b_img))
    1077             :                   END DO
    1078        1102 :                   CALL timestop(handle2)
    1079             : 
    1080        1102 :                   IF (i_loop == 2) CYCLE
    1081         581 :                   pref = 2.0_dp*pref
    1082         581 :                   IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp*pref
    1083             : 
    1084         581 :                   CALL timeset(routineN//"_2c_2", handle2)
    1085             :                   !Note that the derivatives are in atomic block format (not split)
    1086         581 :                   CALL dbt_copy(t_2c_R_split, t_2c_R, move_data=.TRUE.)
    1087             : 
    1088             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
    1089             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
    1090         581 :                                       dbcsr_template=dbcsr_template)
    1091             : 
    1092             :                   !We have to calculate: S^-1(iat) * R_PQ * S^-1(jat)    to trace with HFX pot der
    1093             :                   !                      + R_PQ * S^-1(jat) * pot^T      to trace with S^(x) (iat)
    1094             :                   !                      + pot^T * S^-1(iat) *R_PQ       to trace with S^(x) (jat)
    1095             : 
    1096             :                   !Because 3c tensors are all precontracted with the inverse RI metric,
    1097             :                   !t_2c_R is currently implicitly multiplied by S^-1(iat) from the left
    1098             :                   !and S^-1(jat) from the right, directly in the proper format for the trace
    1099             :                   !with the HFX potential derivative
    1100             : 
    1101             :                   !Trace with HFX pot deriv, that we need to build first
    1102        2324 :                   DO i_xyz = 1, 3
    1103             :                      CALL get_ext_2c_int(t_2c_der_pot(i_xyz), mat_der_pot_sub(:, i_xyz), iatom, jatom, &
    1104             :                                          b_img, ri_data, qs_env, blacs_env_ext=blacs_env_sub, &
    1105        2324 :                                          para_env_ext=para_env_sub, dbcsr_template=dbcsr_template)
    1106             :                   END DO
    1107             : 
    1108         581 :                   IF (use_virial_prv) THEN
    1109             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1110         113 :                                            b_img, pref, ri_data, qs_env, work_virial, cell, particle_set)
    1111             :                   ELSE
    1112             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1113         468 :                                            b_img, pref, ri_data, qs_env)
    1114             :                   END IF
    1115             : 
    1116        2324 :                   DO i_xyz = 1, 3
    1117        2324 :                      CALL dbt_clear(t_2c_der_pot(i_xyz))
    1118             :                   END DO
    1119             : 
    1120             :                   !R_PQ * S^-1(jat) * pot^T  (=A)
    1121             :                   CALL dbt_contract(1.0_dp, t_2c_metric(iatom), t_2c_R, & !get rid of implicit S^-1(iat)
    1122             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1123             :                                     contract_1=[2], notcontract_1=[1], &
    1124             :                                     contract_2=[1], notcontract_2=[2], &
    1125         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1126         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1127             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_work(1), &
    1128             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1129             :                                     contract_1=[2], notcontract_1=[1], &
    1130             :                                     contract_2=[2], notcontract_2=[1], &
    1131         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1132         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1133             : 
    1134             :                   !With the RI bump function, things get more complex. M = (S|P)_D + B*(S|P)_OD*B
    1135             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1136             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(iatom), &
    1137             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1138             :                                     contract_1=[2], notcontract_1=[1], &
    1139             :                                     contract_2=[1], notcontract_2=[2], &
    1140         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1141         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1142             : 
    1143             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1144             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1145             :                                     contract_1=[1], notcontract_1=[2], &
    1146             :                                     contract_2=[1], notcontract_2=[2], &
    1147         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1148         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1149             : 
    1150         581 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1151             :                   CALL get_2c_bump_forces(force, t_2c_work(4), iatom, atom_of_kind, kind_of, pref, &
    1152         581 :                                           ri_data, qs_env, work_virial)
    1153             : 
    1154             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1155             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(2), &
    1156             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1157             :                                     contract_1=[1], notcontract_1=[2], &
    1158             :                                     contract_2=[1], notcontract_2=[2], &
    1159         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1160         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1161             : 
    1162         581 :                   IF (use_virial_prv) THEN
    1163             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1164             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1165         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1166             :                   ELSE
    1167             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1168         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1169             :                   END IF
    1170             : 
    1171             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1172         581 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1173         581 :                   CALL apply_bump(t_2c_work(2), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1174             : 
    1175         581 :                   IF (use_virial_prv) THEN
    1176             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1177             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1178         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1179             :                   ELSE
    1180             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1181         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1182             :                   END IF
    1183             : 
    1184             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1185             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1186             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(iatom), &
    1187             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1188             :                                     contract_1=[2], notcontract_1=[1], &
    1189             :                                     contract_2=[1], notcontract_2=[2], &
    1190         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1191         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1192             : 
    1193             :                   CALL dbt_contract(1.0_dp, t_2c_bint(iatom), t_2c_work(4), &
    1194             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1195             :                                     contract_1=[1], notcontract_1=[2], &
    1196             :                                     contract_2=[1], notcontract_2=[2], &
    1197         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1198         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1199             : 
    1200             :                   CALL get_2c_bump_forces(force, t_2c_work(2), iatom, atom_of_kind, kind_of, -pref, &
    1201         581 :                                           ri_data, qs_env, work_virial)
    1202             : 
    1203             :                   ! pot^T * S^-1(iat) * R_PQ (=A)
    1204             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_R, &
    1205             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1206             :                                     contract_1=[1], notcontract_1=[2], &
    1207             :                                     contract_2=[1], notcontract_2=[2], &
    1208         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1209         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1210             : 
    1211             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_metric(jatom), & !get rid of implicit S^-1(jat)
    1212             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1213             :                                     contract_1=[2], notcontract_1=[1], &
    1214             :                                     contract_2=[1], notcontract_2=[2], &
    1215         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1216         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1217             : 
    1218             :                   !Do the same shenanigans with the S^(x) (jatom)
    1219             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1220             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(jatom), &
    1221             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1222             :                                     contract_1=[2], notcontract_1=[1], &
    1223             :                                     contract_2=[1], notcontract_2=[2], &
    1224         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1225         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1226             : 
    1227             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1228             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1229             :                                     contract_1=[1], notcontract_1=[2], &
    1230             :                                     contract_2=[1], notcontract_2=[2], &
    1231         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1232         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1233             : 
    1234         581 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1235             :                   CALL get_2c_bump_forces(force, t_2c_work(4), jatom, atom_of_kind, kind_of, pref, &
    1236         581 :                                           ri_data, qs_env, work_virial)
    1237             : 
    1238             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1239             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(2), &
    1240             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1241             :                                     contract_1=[1], notcontract_1=[2], &
    1242             :                                     contract_2=[1], notcontract_2=[2], &
    1243         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1244         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1245             : 
    1246         581 :                   IF (use_virial_prv) THEN
    1247             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1248             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1249         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1250             :                   ELSE
    1251             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1252         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1253             :                   END IF
    1254             : 
    1255             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1256         581 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1257         581 :                   CALL apply_bump(t_2c_work(2), jatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1258             : 
    1259         581 :                   IF (use_virial_prv) THEN
    1260             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1261             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1262         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1263             :                   ELSE
    1264             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1265         468 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1266             :                   END IF
    1267             : 
    1268             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1269             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1270             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(jatom), &
    1271             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1272             :                                     contract_1=[2], notcontract_1=[1], &
    1273             :                                     contract_2=[1], notcontract_2=[2], &
    1274         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1275         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1276             : 
    1277             :                   CALL dbt_contract(1.0_dp, t_2c_bint(jatom), t_2c_work(4), &
    1278             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1279             :                                     contract_1=[1], notcontract_1=[2], &
    1280             :                                     contract_2=[1], notcontract_2=[2], &
    1281         581 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1282         581 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1283             : 
    1284             :                   CALL get_2c_bump_forces(force, t_2c_work(2), jatom, atom_of_kind, kind_of, -pref, &
    1285         581 :                                           ri_data, qs_env, work_virial)
    1286             : 
    1287       13520 :                   CALL timestop(handle2)
    1288             :                END DO !iatom
    1289             :             END DO !jatom
    1290             :          END DO !b_img
    1291             : 
    1292         126 :          IF (i_loop == 1) THEN
    1293          42 :             CALL update_pattern_to_forces(force_pattern, sparsity_pattern, ngroups, ri_data, qs_env)
    1294             :          END IF
    1295             :       END DO !i_loop
    1296             : 
    1297          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
    1298          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
    1299          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(3))
    1300          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
    1301          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
    1302             : 
    1303          42 :       IF (use_virial_prv) THEN
    1304          32 :          DO k_xyz = 1, 3
    1305         104 :             DO j_xyz = 1, 3
    1306         312 :                DO i_xyz = 1, 3
    1307             :                   virial%pv_fock_4c(i_xyz, j_xyz) = virial%pv_fock_4c(i_xyz, j_xyz) &
    1308         288 :                                                     + work_virial(i_xyz, k_xyz)*cell%hmat(j_xyz, k_xyz)
    1309             :                END DO
    1310             :             END DO
    1311             :          END DO
    1312             :       END IF
    1313             : 
    1314             :       !End of subgroup parallelization
    1315          42 :       CALL cp_blacs_env_release(blacs_env_sub)
    1316          42 :       CALL para_env_sub%free()
    1317          42 :       DEALLOCATE (para_env_sub)
    1318             : 
    1319          42 :       CALL para_env%sync()
    1320          42 :       t2 = m_walltime()
    1321          42 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    1322             : 
    1323             :       !clean-up
    1324          42 :       IF (ASSOCIATED(dbcsr_template)) THEN
    1325          42 :          CALL dbcsr_release(dbcsr_template)
    1326          42 :          DEALLOCATE (dbcsr_template)
    1327             :       END IF
    1328          42 :       CALL dbt_destroy(t_2c_R)
    1329          42 :       CALL dbt_destroy(t_2c_R_split)
    1330          42 :       CALL dbt_destroy(t_2c_work(1))
    1331          42 :       CALL dbt_destroy(t_2c_work(2))
    1332          42 :       CALL dbt_destroy(t_2c_work(3))
    1333          42 :       CALL dbt_destroy(t_2c_work(4))
    1334          42 :       CALL dbt_destroy(t_2c_work(5))
    1335          42 :       CALL dbt_destroy(t_3c_work_2(1))
    1336          42 :       CALL dbt_destroy(t_3c_work_2(2))
    1337          42 :       CALL dbt_destroy(t_3c_work_2(3))
    1338          42 :       CALL dbt_destroy(t_3c_work_3(1))
    1339          42 :       CALL dbt_destroy(t_3c_work_3(2))
    1340          42 :       CALL dbt_destroy(t_3c_work_3(3))
    1341          42 :       CALL dbt_destroy(t_3c_work_3(4))
    1342          42 :       CALL dbt_destroy(t_3c_der_stack(1))
    1343          42 :       CALL dbt_destroy(t_3c_der_stack(2))
    1344          42 :       CALL dbt_destroy(t_3c_der_stack(3))
    1345          42 :       CALL dbt_destroy(t_3c_der_stack(4))
    1346          42 :       CALL dbt_destroy(t_3c_der_stack(5))
    1347          42 :       CALL dbt_destroy(t_3c_der_stack(6))
    1348         168 :       DO i_xyz = 1, 3
    1349         168 :          CALL dbt_destroy(t_2c_der_pot(i_xyz))
    1350             :       END DO
    1351         126 :       DO iatom = 1, natom
    1352          84 :          CALL dbt_destroy(t_2c_inv(iatom))
    1353          84 :          CALL dbt_destroy(t_2c_binv(iatom))
    1354          84 :          CALL dbt_destroy(t_2c_bint(iatom))
    1355          84 :          CALL dbt_destroy(t_2c_metric(iatom))
    1356         378 :          DO i_xyz = 1, 3
    1357         336 :             CALL dbt_destroy(t_2c_der_metric_sub(iatom, i_xyz))
    1358             :          END DO
    1359             :       END DO
    1360         980 :       DO i_img = 1, nimg
    1361         938 :          CALL dbcsr_release(mat_2c_pot(i_img))
    1362        2084 :          DO i_spin = 1, nspins
    1363        1104 :             CALL dbt_destroy(rho_ao_t_sub(i_spin, i_img))
    1364        2042 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
    1365             :          END DO
    1366             :       END DO
    1367         168 :       DO i_xyz = 1, 3
    1368        2982 :          DO i_img = 1, nimg
    1369        2814 :             CALL dbt_destroy(t_3c_der_RI_sub(i_img, i_xyz))
    1370        2814 :             CALL dbt_destroy(t_3c_der_AO_sub(i_img, i_xyz))
    1371        2940 :             CALL dbcsr_release(mat_der_pot_sub(i_img, i_xyz))
    1372             :          END DO
    1373             :       END DO
    1374             : 
    1375          42 :       CALL timestop(handle)
    1376             : 
    1377       17604 :    END SUBROUTINE hfx_ri_update_forces_kp
    1378             : 
    1379             : ! **************************************************************************************************
    1380             : !> \brief A routine that applies the RI bump matrix from the left and/or the right, given an input
    1381             : !>        matrix and the central RI atom. We assume atomic block sizes
    1382             : !> \param t_2c_inout 2-index tensor whose blocks are scaled in place; must hold ncell_RI*natom blocks per dimension
    1383             : !> \param atom_i index of the central RI atom; bump distances are measured from its pbc-wrapped position
    1384             : !> \param ri_data provides kp_bump_rad (r0), kp_RI_range (r1), RI_cell_to_img, ncell_RI and filter_eps
    1385             : !> \param qs_env environment from which natom, cell, kpoints and particle_set are retrieved
    1386             : !> \param from_left if .TRUE., scale by the bump value of the row atom (default .FALSE.); one of from_left/from_right must be set
    1387             : !> \param from_right if .TRUE., scale by the bump value of the column atom (default .FALSE.)
    1388             : !> \param debump if .TRUE., divide by the bump value instead of multiplying by it (default .FALSE.)
    1389             : ! **************************************************************************************************
    1390        1750 :    SUBROUTINE apply_bump(t_2c_inout, atom_i, ri_data, qs_env, from_left, from_right, debump)
    1391             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_inout
    1392             :       INTEGER, INTENT(IN)                                :: atom_i
    1393             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1394             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1395             :       LOGICAL, INTENT(IN), OPTIONAL                      :: from_left, from_right, debump
    1396             : 
    1397             :       INTEGER                                            :: i_img, i_RI, iatom, ind(2), j_img, j_RI, &
    1398             :                                                             jatom, natom, nblks(2), nimg, nkind
    1399        1750 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1400        1750 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1401             :       LOGICAL                                            :: found, my_debump, my_left, my_right
    1402             :       REAL(dp)                                           :: bval, r0, r1, ri(3), rj(3), rref(3), &
    1403             :                                                             scoord(3)
    1404        1750 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1405             :       TYPE(cell_type), POINTER                           :: cell
    1406             :       TYPE(dbt_iterator_type)                            :: iter
    1407             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1408        1750 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1409        1750 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1410             : 
    1411        1750 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1412             : 
    1413             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1414        1750 :                       kpoints=kpoints, particle_set=particle_set)
    1415        1750 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1416             : 
    1417        1750 :       my_debump = .FALSE.
    1418        1750 :       IF (PRESENT(debump)) my_debump = debump
    1419             : 
    1420        1750 :       my_left = .FALSE.
    1421        1750 :       IF (PRESENT(from_left)) my_left = from_left
    1422             : 
    1423        1750 :       my_right = .FALSE.
    1424        1750 :       IF (PRESENT(from_right)) my_right = from_right
    1425        1750 :       CPASSERT(my_left .OR. my_right)
    1426             : 
    1427        1750 :       CALL dbt_get_info(t_2c_inout, nblks_total=nblks)
    1428        1750 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1429        1750 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1430             : 
    1431        1750 :       nimg = ri_data%nimg
    1432             : 
    1433             :       !Loop over the RI cells and atoms, and apply bump accordingly
    1434        1750 :       r1 = ri_data%kp_RI_range
    1435        1750 :       r0 = ri_data%kp_bump_rad
    1436        1750 :       rref = pbc(particle_set(atom_i)%r, cell)
    1437             : 
    1438             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_inout,natom,ri_data,cell,particle_set,index_to_cell,my_left, &
    1439             : !$OMP                               my_right,r0,r1,rref,my_debump) &
    1440        1750 : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,bval)
    1441             :       CALL dbt_iterator_start(iter, t_2c_inout)
    1442             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1443             :          CALL dbt_iterator_next_block(iter, ind)
    1444             :          CALL dbt_get_block(t_2c_inout, ind, blk, found)
    1445             :          IF (.NOT. found) CYCLE
    1446             :          !Row block index -> RI cell (i_RI), its image index (i_img) and the atom within that cell (iatom)
    1447             :          i_RI = (ind(1) - 1)/natom + 1
    1448             :          i_img = ri_data%RI_cell_to_img(i_RI)
    1449             :          iatom = ind(1) - (i_RI - 1)*natom
    1450             : 
    1451             :          CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    1452             :          CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    1453             :          !Column block index -> RI cell (j_RI), its image index (j_img) and the atom within that cell (jatom)
    1454             :          j_RI = (ind(2) - 1)/natom + 1
    1455             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1456             :          jatom = ind(2) - (j_RI - 1)*natom
    1457             : 
    1458             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1459             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1460             : 
    1461             :          IF (.NOT. my_debump) THEN
    1462             :             IF (my_left) blk(:, :) = blk(:, :)*bump(NORM2(ri - rref), r0, r1)
    1463             :             IF (my_right) blk(:, :) = blk(:, :)*bump(NORM2(rj - rref), r0, r1)
    1464             :          ELSE
    1465             :             !Note: by construction, the bump function is never quite zero, as its range is the same
    1466             :             !      as that of the extended RI basis (but we are safe)
    1467             :             bval = bump(NORM2(ri - rref), r0, r1)
    1468             :             IF (my_left .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1469             :             bval = bump(NORM2(rj - rref), r0, r1)
    1470             :             IF (my_right .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1471             :          END IF
    1472             : 
    1473             :          CALL dbt_put_block(t_2c_inout, ind, SHAPE(blk), blk)
    1474             : 
    1475             :          DEALLOCATE (blk)
    1476             :       END DO
    1477             :       CALL dbt_iterator_stop(iter)
    1478             : !$OMP END PARALLEL
    1479        1750 :       CALL dbt_filter(t_2c_inout, ri_data%filter_eps)
    1480             : 
    1481        3500 :    END SUBROUTINE apply_bump
    1482             : 
    1483             : ! **************************************************************************************************
    1484             : !> \brief A routine that calculates the forces due to the derivative of the bump function
    1485             : !> \param force force array; contributions are accumulated into its fock_4c component
    1486             : !> \param t_2c_in 2-index tensor; only the trace of its diagonal blocks (equal row and column index) is used
    1487             : !> \param atom_i index of the reference atom, which defines the RI basis and the center of the bump
    1488             : !> \param atom_of_kind for each atom, its index within its kind (second index of fock_4c)
    1489             : !> \param kind_of for each atom, its kind index (index into force)
    1490             : !> \param pref scalar prefactor applied to every force contribution
    1491             : !> \param ri_data provides kp_bump_rad (r0), kp_RI_range (r1), RI_cell_to_img and ncell_RI
    1492             : !> \param qs_env environment from which natom, cell, kpoints and particle_set are retrieved
    1493             : !> \param work_virial 3x3 array accumulating the force contributions times the scaled coordinates of the atoms
    1494             : ! **************************************************************************************************
    1495        2324 :    SUBROUTINE get_2c_bump_forces(force, t_2c_in, atom_i, atom_of_kind, kind_of, pref, ri_data, &
    1496             :                                  qs_env, work_virial)
    1497             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    1498             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_in
    1499             :       INTEGER, INTENT(IN)                                :: atom_i
    1500             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    1501             :       REAL(dp), INTENT(IN)                               :: pref
    1502             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1503             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1504             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT)           :: work_virial
    1505             : 
    1506             :       INTEGER :: i, i_img, i_RI, i_xyz, iat_of_kind, iatom, ikind, ind(2), j_img, j_RI, j_xyz, &
    1507             :          jat_of_kind, jatom, jkind, natom, nblks(2), nimg, nkind
    1508        2324 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1509        2324 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1510             :       LOGICAL                                            :: found
    1511             :       REAL(dp)                                           :: new_force, r0, r1, ri(3), rj(3), &
    1512             :                                                             rref(3), scoord(3), x
    1513        2324 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1514             :       TYPE(cell_type), POINTER                           :: cell
    1515             :       TYPE(dbt_iterator_type)                            :: iter
    1516             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1517        2324 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1518        2324 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1519             : 
    1520        2324 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1521             : 
    1522             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1523        2324 :                       kpoints=kpoints, particle_set=particle_set)
    1524        2324 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1525             : 
    1526        2324 :       CALL dbt_get_info(t_2c_in, nblks_total=nblks)
    1527        2324 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1528        2324 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1529             : 
    1530        2324 :       nimg = ri_data%nimg
    1531             : 
    1532             :       !Loop over the diagonal blocks (RI cells and atoms), and accumulate the bump derivative forces
    1533        2324 :       r1 = ri_data%kp_RI_range
    1534        2324 :       r0 = ri_data%kp_bump_rad
    1535        2324 :       rref = pbc(particle_set(atom_i)%r, cell)
    1536             : 
    1537        2324 :       iat_of_kind = atom_of_kind(atom_i)
    1538        2324 :       ikind = kind_of(atom_i)
    1539             : 
    1540             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_in,natom,ri_data,cell,particle_set,index_to_cell,pref, &
    1541             : !$OMP force,r0,r1,rref,atom_of_kind,kind_of,iat_of_kind,ikind,work_virial) &
    1542             : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,jkind,jat_of_kind, &
    1543        2324 : !$OMP         new_force,i_xyz,i,x,j_xyz)
    1544             :       CALL dbt_iterator_start(iter, t_2c_in)
    1545             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1546             :          CALL dbt_iterator_next_block(iter, ind)
    1547             :          IF (ind(1) .NE. ind(2)) CYCLE !bump matrix is diagonal
    1548             : 
    1549             :          CALL dbt_get_block(t_2c_in, ind, blk, found)
    1550             :          IF (.NOT. found) CYCLE
    1551             : 
    1552             :          !bump is a function of x = SQRT((R - Rref)^2). We refer to R as jatom, and Rref as atom_i
    1553             :          j_RI = (ind(2) - 1)/natom + 1
    1554             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1555             :          jatom = ind(2) - (j_RI - 1)*natom
    1556             :          jat_of_kind = atom_of_kind(jatom)
    1557             :          jkind = kind_of(jatom)
    1558             : 
    1559             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1560             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1561             :          x = NORM2(rj - rref)
    1562             :          IF (x < r0 .OR. x > r1) CYCLE !the bump derivative is zero outside of [r0, r1]
    1563             :          !Trace of the diagonal block, scaled by the prefactor and the bump derivative at x
    1564             :          new_force = 0.0_dp
    1565             :          DO i = 1, SIZE(blk, 1)
    1566             :             new_force = new_force + blk(i, i)
    1567             :          END DO
    1568             :          new_force = pref*new_force*dbump(x, r0, r1)
    1569             : 
    1570             :          !x = SQRT((R - Rref)^2), so we multiply by dx/dR and dx/dRref
    1571             :          DO i_xyz = 1, 3
    1572             :             !Force acting on second atom
    1573             : !$OMP ATOMIC
    1574             :             force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) + &
    1575             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1576             : 
    1577             :             !virial acting on second atom
    1578             :             CALL real_to_scaled(scoord, rj, cell)
    1579             :             DO j_xyz = 1, 3
    1580             : !$OMP ATOMIC
    1581             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1582             :                                            + new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1583             :             END DO
    1584             : 
    1585             :             !Force acting on reference atom, defining the RI basis
    1586             : !$OMP ATOMIC
    1587             :             force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) - &
    1588             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1589             : 
    1590             :             !virial of ref atom
    1591             :             CALL real_to_scaled(scoord, rref, cell)
    1592             :             DO j_xyz = 1, 3
    1593             : !$OMP ATOMIC
    1594             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1595             :                                            - new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1596             :             END DO
    1597             :          END DO !i_xyz
    1598             : 
    1599             :          DEALLOCATE (blk)
    1600             :       END DO
    1601             :       CALL dbt_iterator_stop(iter)
    1602             : !$OMP END PARALLEL
    1603             : 
    1604        4648 :    END SUBROUTINE get_2c_bump_forces
    1605             : 
    1606             : ! **************************************************************************************************
    1607             : !> \brief The bump function as defined by Juerg: 1 for x <= r0, 0 for x >= r1, quintic polynomial in between
    1608             : !> \param x distance at which the bump function is evaluated
    1609             : !> \param r0 inner radius; the function is 1 for x <= r0
    1610             : !> \param r1 outer radius; the function is 0 for x >= r1
    1611             : !> \return value of the bump function at x
    1612             : ! **************************************************************************************************
    1613       23891 :    FUNCTION bump(x, r0, r1) RESULT(b)
    1614             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1615             :       REAL(dp)                                           :: b
    1616             : 
    1617             :       REAL(dp)                                           :: r
    1618             : 
    1619             :       !Head-Gordon
    1620             :       !b = 1.0_dp/(1.0_dp+EXP((r1-r0)/(r1-x)-(r1-r0)/(x-r0)))
    1621             :       !Juerg
    1622       23891 :       r = (x - r0)/(r1 - r0)
    1623       23891 :       b = -6.0_dp*r**5 + 15.0_dp*r**4 - 10.0_dp*r**3 + 1.0_dp
    1624       23891 :       IF (x .GE. r1) b = 0.0_dp
    1625       23891 :       IF (x .LE. r0) b = 1.0_dp
    1626             : 
    1627       23891 :    END FUNCTION bump
    1628             : 
    1629             : ! **************************************************************************************************
    1630             : !> \brief The derivative of the bump function
    1631             : !> \param x distance at which the derivative is evaluated
    1632             : !> \param r0 inner radius; the derivative is 0 for x <= r0
    1633             : !> \param r1 outer radius; the derivative is 0 for x >= r1
    1634             : !> \return derivative of the bump function with respect to x
    1635             : ! **************************************************************************************************
    1636         509 :    FUNCTION dbump(x, r0, r1) RESULT(b)
    1637             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1638             :       REAL(dp)                                           :: b
    1639             : 
    1640             :       REAL(dp)                                           :: r
    1641             : 
    1642         509 :       r = (x - r0)/(r1 - r0)
    1643         509 :       b = (-30.0_dp*r**4 + 60.0_dp*r**3 - 30.0_dp*r**2)/(r1 - r0)
    1644         509 :       IF (x .GE. r1) b = 0.0_dp
    1645         509 :       IF (x .LE. r0) b = 0.0_dp
    1646             : 
    1647         509 :    END FUNCTION dbump
    1648             : 
    1649             : ! **************************************************************************************************
    1650             : !> \brief return the cell index a+c corresponding to given cell index i and b, with i = a+c-b
    1651             : !> \param i_index index of cell i (column of index_to_cell)
    1652             : !> \param b_index index of cell b (column of index_to_cell)
    1653             : !> \param qs_env environment providing the kpoints, from which cell_to_index and index_to_cell are taken
    1654             : !> \return index of cell a+c = i+b, or 0 if that cell lies outside the bounds of cell_to_index
    1655             : ! **************************************************************************************************
    1656      158738 :    FUNCTION get_apc_index_from_ib(i_index, b_index, qs_env) RESULT(apc_index)
    1657             :       INTEGER, INTENT(IN)                                :: i_index, b_index
    1658             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1659             :       INTEGER                                            :: apc_index
    1660             : 
    1661             :       INTEGER, DIMENSION(3)                              :: cell_apc
    1662      158738 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1663      158738 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1664             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1665             : 
    1666      158738 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1667      158738 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1668             : 
    1669             :       !i = a+c-b => a+c = i+b
    1670      634952 :       cell_apc(:) = index_to_cell(:, i_index) + index_to_cell(:, b_index)
    1671             : 
    1672     1087392 :       IF (ANY([cell_apc(1), cell_apc(2), cell_apc(3)] < LBOUND(cell_to_index)) .OR. &
    1673             :           ANY([cell_apc(1), cell_apc(2), cell_apc(3)] > UBOUND(cell_to_index))) THEN
    1674             : 
    1675             :          apc_index = 0
    1676             :       ELSE
    1677      138666 :          apc_index = cell_to_index(cell_apc(1), cell_apc(2), cell_apc(3))
    1678             :       END IF
    1679             : 
    1680      158738 :    END FUNCTION get_apc_index_from_ib
    1681             : 
    1682             : ! **************************************************************************************************
    1683             : !> \brief return the cell index i corresponding to the sum of cell_a and cell_c
    1684             : !> \param a_index index of cell a (column of index_to_cell)
    1685             : !> \param c_index index of cell c (column of index_to_cell)
    1686             : !> \param qs_env environment providing the kpoints, from which cell_to_index and index_to_cell are taken
    1687             : !> \return index of cell a+c, or 0 if that cell lies outside the bounds of cell_to_index
    1688             : ! **************************************************************************************************
    1689           0 :    FUNCTION get_apc_index(a_index, c_index, qs_env) RESULT(i_index)
    1690             :       INTEGER, INTENT(IN)                                :: a_index, c_index
    1691             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1692             :       INTEGER                                            :: i_index
    1693             : 
    1694             :       INTEGER, DIMENSION(3)                              :: cell_i
    1695           0 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1696           0 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1697             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1698             : 
    1699           0 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1700           0 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1701             : 
    1702           0 :       cell_i(:) = index_to_cell(:, a_index) + index_to_cell(:, c_index)
    1703             : 
    1704           0 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1705             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1706             : 
    1707             :          i_index = 0
    1708             :       ELSE
    1709           0 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1710             :       END IF
    1711             : 
    1712           0 :    END FUNCTION get_apc_index
    1713             : 
    1714             : ! **************************************************************************************************
    1715             : !> \brief return the cell index i corresponding to the sum cell_a + cell_c - cell_b
    1716             : !> \param apc_index index of the cell a+c (column of index_to_cell)
    1717             : !> \param b_index index of cell b (column of index_to_cell)
    1718             : !> \param qs_env environment providing the kpoint info (cell_to_index and index_to_cell)
    1719             : !> \return the cell_to_index entry of cell_apc - cell_b, or 0 if that cell is outside its bounds
    1720             : ! **************************************************************************************************
    1721      526636 :    FUNCTION get_i_index(apc_index, b_index, qs_env) RESULT(i_index)
    1722             :       INTEGER, INTENT(IN)                                :: apc_index, b_index
    1723             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1724             :       INTEGER                                            :: i_index
    1725             : 
    1726             :       INTEGER, DIMENSION(3)                              :: cell_i
    1727      526636 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1728      526636 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1729             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1730             : 
    1731      526636 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1732      526636 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1733             : 
    1734     2106544 :       cell_i(:) = index_to_cell(:, apc_index) - index_to_cell(:, b_index) !cell vector of (a+c) - b
    1735             : 
    1736     3597572 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1737             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1738             : 
    1739             :          i_index = 0 !resulting cell is outside the bounds of cell_to_index
    1740             :       ELSE
    1741      450040 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1742             :       END IF
    1743             : 
    1744      526636 :    END FUNCTION get_i_index
    1745             : 
    1746             : ! **************************************************************************************************
    1747             : !> \brief A routine that returns all allowed a,c pairs such that the a+c image corresponds to the
    1748             : !>        value of the apc_index input. Takes into account that image a corresponds to 3c integrals,
    1749             : !>        which are ordered in their own way (mapped to actual images through ri_data%idx_to_img)
    1750             : !> \param ac_pairs on exit, ac_pairs(a, 1) = a and ac_pairs(a, 2) = index of image c (may be 0)
    1751             : !> \param apc_index index of the target a+c cell
    1752             : !> \param ri_data provides idx_to_img, the map from the 3c integral ordering to actual image indices
    1753             : !> \param qs_env environment passed on to get_i_index
    1754             : ! **************************************************************************************************
    1755       16412 :    SUBROUTINE get_ac_pairs(ac_pairs, apc_index, ri_data, qs_env)
    1756             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: ac_pairs
    1757             :       INTEGER, INTENT(IN)                                :: apc_index
    1758             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1759             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1760             : 
    1761             :       INTEGER                                            :: a_index, actual_img, c_index, nimg
    1762             : 
    1763       16412 :       nimg = SIZE(ac_pairs, 1) !number of images is taken from the first extent of ac_pairs
    1764             : 
    1765     1102508 :       ac_pairs(:, :) = 0
    1766             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(ac_pairs,nimg,ri_data,qs_env,apc_index) &
    1767       16412 : !$OMP PRIVATE(a_index,actual_img,c_index)
    1768             :       DO a_index = 1, nimg
    1769             :          actual_img = ri_data%idx_to_img(a_index)
    1770             :          !c = (a+c) - a, with a given as its actual image index; c is 0 if out of bounds
    1771             :          c_index = get_i_index(apc_index, actual_img, qs_env)
    1772             :          ac_pairs(a_index, 1) = a_index
    1773             :          ac_pairs(a_index, 2) = c_index
    1774             :       END DO
    1775             : !$OMP END PARALLEL DO
    1776             : 
    1777       16412 :    END SUBROUTINE get_ac_pairs
    1778             : 
    1779             : ! **************************************************************************************************
    1780             : !> \brief A routine that returns all allowed i,a+c pairs such that, for the given value of b, we have
    1781             : !>        i = a+c-b. Takes into account that image i corresponds to the 3c ints, which are ordered in
    1782             : !>        their own way (mapped to actual images through ri_data%idx_to_img)
    1783             : !> \param iapc_pairs on exit, row i holds (i, index of a+c); the row stays (0, 0) if a+c is invalid
    1784             : !> \param b_index index of cell b
    1785             : !> \param ri_data provides idx_to_img, the map from the 3c integral ordering to actual image indices
    1786             : !> \param qs_env environment passed on to get_apc_index_from_ib
    1787             : !> \param actual_i_img if present: actual image index of each i with a valid pair, 0 otherwise
    1788             : ! **************************************************************************************************
    1789        4708 :    SUBROUTINE get_iapc_pairs(iapc_pairs, b_index, ri_data, qs_env, actual_i_img)
    1790             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: iapc_pairs
    1791             :       INTEGER, INTENT(IN)                                :: b_index
    1792             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1793             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1794             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: actual_i_img
    1795             : 
    1796             :       INTEGER                                            :: actual_img, apc_index, i_index, nimg
    1797             : 
    1798        4708 :       nimg = SIZE(iapc_pairs, 1) !number of images is taken from the first extent of iapc_pairs
    1799       40802 :       IF (PRESENT(actual_i_img)) actual_i_img(:) = 0
    1800             : 
    1801      331600 :       iapc_pairs(:, :) = 0
    1802             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(iapc_pairs,nimg,ri_data,qs_env,b_index,actual_i_img) &
    1803        4708 : !$OMP PRIVATE(i_index,actual_img,apc_index)
    1804             :       DO i_index = 1, nimg
    1805             :          actual_img = ri_data%idx_to_img(i_index)
    1806             :          apc_index = get_apc_index_from_ib(actual_img, b_index, qs_env)
    1807             :          IF (apc_index == 0) CYCLE !no valid a+c image for this i: leave the zero-initialized entries
    1808             :          iapc_pairs(i_index, 1) = i_index
    1809             :          iapc_pairs(i_index, 2) = apc_index
    1810             :          IF (PRESENT(actual_i_img)) actual_i_img(i_index) = actual_img
    1811             :       END DO
    1812             : 
    1813        4708 :    END SUBROUTINE get_iapc_pairs
    1814             : 
    1815             : ! **************************************************************************************************
    1816             : !> \brief A function that, given a cell index a, returns the index corresponding to -a, and zero
    1817             : !>        if out of bounds
    1818             : !> \param a_index index of cell a (column of index_to_cell)
    1819             : !> \param qs_env environment providing the kpoint info (cell_to_index and index_to_cell)
    1820             : !> \return the cell_to_index entry of the cell -a, or 0 if -a is outside the bounds of cell_to_index
    1821             : ! **************************************************************************************************
    1822       66222 :    FUNCTION get_opp_index(a_index, qs_env) RESULT(opp_index)
    1823             :       INTEGER, INTENT(IN)                                :: a_index
    1824             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1825             :       INTEGER                                            :: opp_index
    1826             : 
    1827             :       INTEGER, DIMENSION(3)                              :: opp_cell
    1828       66222 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1829       66222 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1830             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1831             : 
    1832       66222 :       NULLIFY (kpoints, cell_to_index, index_to_cell)
    1833             : 
    1834       66222 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1835       66222 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1836             : 
    1837      264888 :       opp_cell(:) = -index_to_cell(:, a_index) !cell vector of -a
    1838             : 
    1839      463554 :       IF (ANY([opp_cell(1), opp_cell(2), opp_cell(3)] < LBOUND(cell_to_index)) .OR. &
    1840             :           ANY([opp_cell(1), opp_cell(2), opp_cell(3)] > UBOUND(cell_to_index))) THEN
    1841             : 
    1842             :          opp_index = 0 !opposite cell is outside the bounds of cell_to_index
    1843             :       ELSE
    1844       66222 :          opp_index = cell_to_index(opp_cell(1), opp_cell(2), opp_cell(3))
    1845             :       END IF
    1846             : 
    1847       66222 :    END FUNCTION get_opp_index
    1848             : 
    1849             : ! **************************************************************************************************
    1850             : !> \brief A routine that returns the actual non-symmetric density matrix for each image, built from
    1851             : !>        the real-space rho_ao (the k-point Fourier transform is presumably done by the caller)
    1852             : !> \param rho_ao_t per (spin, image) tensors: scaled by scale_prev_p, then the new matrices are added
    1853             : !> \param rho_ao per (spin, image) real-space density matrices in symmetric-typed storage
    1854             : !> \param scale_prev_p factor applied to the previous content of rho_ao_t before the summation
    1855             : !> \param ri_data provides the number of images (nimg) and the filtering threshold (filter_eps)
    1856             : !> \param qs_env environment providing kpoints, dft_control and the KS matrices used as templates
    1857             : ! **************************************************************************************************
    1858         470 :    SUBROUTINE get_pmat_images(rho_ao_t, rho_ao, scale_prev_p, ri_data, qs_env)
    1859             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t
    1860             :       TYPE(dbcsr_p_type), DIMENSION(:, :), INTENT(INOUT) :: rho_ao
    1861             :       REAL(dp), INTENT(IN)                               :: scale_prev_p
    1862             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1863             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1864             : 
    1865             :       INTEGER                                            :: cell_j(3), i_img, i_spin, iatom, icol, &
    1866             :                                                             irow, j_img, jatom, mi_img, mj_img, &
    1867             :                                                             nimg, nspins
    1868         470 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1869             :       LOGICAL                                            :: found
    1870             :       REAL(dp)                                           :: fac
    1871         470 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock, pblock_desymm
    1872         470 :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_ks, rho_desymm
    1873        4230 :       TYPE(dbt_type)                                     :: tmp
    1874             :       TYPE(dft_control_type), POINTER                    :: dft_control
    1875             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1876             :       TYPE(neighbor_list_iterator_p_type), &
    1877         470 :          DIMENSION(:), POINTER                           :: nl_iterator
    1878             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    1879         470 :          POINTER                                         :: sab_nl, sab_nl_nosym
    1880             :       TYPE(qs_scf_env_type), POINTER                     :: scf_env
    1881             : 
    1882         470 :       NULLIFY (rho_desymm, kpoints, sab_nl_nosym, scf_env, matrix_ks, dft_control, &
    1883         470 :                sab_nl, nl_iterator, cell_to_index, pblock, pblock_desymm)
    1884             : 
    1885         470 :       CALL get_qs_env(qs_env, kpoints=kpoints, scf_env=scf_env, matrix_ks_kp=matrix_ks, dft_control=dft_control)
    1886         470 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=sab_nl_nosym, cell_to_index=cell_to_index, sab_nl=sab_nl)
    1887             : 
    1888         470 :       IF (dft_control%do_admm) THEN
    1889         252 :          CALL get_admm_env(qs_env%admm_env, matrix_ks_aux_fit_kp=matrix_ks) !ADMM: use aux-fit KS matrices
    1890             :       END IF
    1891             : 
    1892         470 :       nspins = SIZE(matrix_ks, 1)
    1893         470 :       nimg = ri_data%nimg
    1894             : 
    1895       28266 :       ALLOCATE (rho_desymm(nspins, nimg))
    1896       12008 :       DO i_img = 1, nimg
    1897       26856 :          DO i_spin = 1, nspins
    1898       14848 :             ALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1899             :             CALL dbcsr_create(rho_desymm(i_spin, i_img)%matrix, template=matrix_ks(i_spin, i_img)%matrix, &
    1900       14848 :                               matrix_type=dbcsr_type_no_symmetry)
    1901       26386 :             CALL cp_dbcsr_alloc_block_from_nbl(rho_desymm(i_spin, i_img)%matrix, sab_nl_nosym)
    1902             :          END DO
    1903             :       END DO
    1904         470 :       CALL dbt_create(rho_desymm(1, 1)%matrix, tmp) !work tensor reused for every matrix -> tensor copy
    1905             : 
    1906             :       !We transform the symmetric-typed (but not actually symmetric: P_ab^i = P_ba^-i) real-space density
    1907             :       !matrices into proper non-symmetric ones (using the same nl for consistency)
    1908         470 :       CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
    1909       20878 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    1910       20408 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    1911       20408 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    1912       20408 :          IF (j_img > nimg .OR. j_img < 1) CYCLE !skip cells that are not among the nimg images
    1913             : 
    1914       13930 :          fac = 1.0_dp
    1915       13930 :          IF (iatom == jatom) fac = 0.5_dp !halved: the transposed opposite image is added further below
    1916       13930 :          mj_img = get_opp_index(j_img, qs_env)
    1917             :          !if no opposite image, then no sum of P^j + P^-j => need full diag
    1918       13930 :          IF (mj_img == 0) fac = 1.0_dp
    1919             : 
    1920       13930 :          irow = iatom
    1921       13930 :          icol = jatom
    1922       13930 :          IF (iatom > jatom) THEN
    1923             :             !because symmetric nl. Value for atom pair i,j is actually stored in j,i if i > j
    1924        4374 :             irow = jatom
    1925        4374 :             icol = iatom
    1926             :          END IF
    1927             : 
    1928       32618 :          DO i_spin = 1, nspins
    1929       18218 :             CALL dbcsr_get_block_p(rho_ao(i_spin, j_img)%matrix, irow, icol, pblock, found)
    1930       18218 :             IF (.NOT. found) CYCLE
    1931             : 
    1932             :             !distribution of symm and non-symm matrix match in that way
    1933       18218 :             CALL dbcsr_get_block_p(rho_desymm(i_spin, j_img)%matrix, iatom, jatom, pblock_desymm, found)
    1934       18218 :             IF (.NOT. found) CYCLE
    1935             : 
    1936       75062 :             IF (iatom > jatom) THEN
    1937      720396 :                pblock_desymm(:, :) = fac*TRANSPOSE(pblock(:, :)) !block was read from (jatom, iatom)
    1938             :             ELSE
    1939     1709076 :                pblock_desymm(:, :) = fac*pblock(:, :)
    1940             :             END IF
    1941             :          END DO
    1942             :       END DO
    1943         470 :       CALL neighbor_list_iterator_release(nl_iterator)
    1944             : 
    1945       12008 :       DO i_img = 1, nimg
    1946       26856 :          DO i_spin = 1, nspins
    1947       14848 :             CALL dbt_scale(rho_ao_t(i_spin, i_img), scale_prev_p)
    1948             : 
    1949       14848 :             CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, i_img)%matrix, tmp)
    1950       14848 :             CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), summation=.TRUE., move_data=.TRUE.)
    1951             : 
    1952             :             !symmetrize by adding the transpose of the opposite image
    1953       14848 :             mi_img = get_opp_index(i_img, qs_env)
    1954       14848 :             IF (mi_img > 0 .AND. mi_img .LE. nimg) THEN
    1955       13368 :                CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, mi_img)%matrix, tmp)
    1956       13368 :                CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), order=[2, 1], summation=.TRUE., move_data=.TRUE.)
    1957             :             END IF
    1958       26386 :             CALL dbt_filter(rho_ao_t(i_spin, i_img), ri_data%filter_eps)
    1959             :          END DO
    1960             :       END DO
    1961             : 
    1962       12008 :       DO i_img = 1, nimg
    1963       26856 :          DO i_spin = 1, nspins
    1964       14848 :             CALL dbcsr_release(rho_desymm(i_spin, i_img)%matrix)
    1965       26386 :             DEALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1966             :          END DO
    1967             :       END DO
    1968             : 
    1969         470 :       CALL dbt_destroy(tmp)
    1970         470 :       DEALLOCATE (rho_desymm)
    1971             : 
    1972         940 :    END SUBROUTINE get_pmat_images
    1973             : 
    1974             : ! **************************************************************************************************
    1975             : !> \brief A routine that, given a cell index b and atom indices ij, returns a 2c tensor with the HFX
    1976             : !>        potential (P_i^0|Q_j^b), within the extended RI basis
    1977             : !> \param t_2c_pot ...
    1978             : !> \param mat_orig ...
    1979             : !> \param atom_i ...
    1980             : !> \param atom_j ...
    1981             : !> \param img_b ...
    1982             : !> \param ri_data ...
    1983             : !> \param qs_env ...
    1984             : !> \param do_inverse ...
    1985             : !> \param para_env_ext ...
    1986             : !> \param blacs_env_ext ...
    1987             : !> \param dbcsr_template ...
    1988             : !> \param off_diagonal ...
    1989             : !> \param skip_inverse ...
    1990             : ! **************************************************************************************************
    1991        7704 :    SUBROUTINE get_ext_2c_int(t_2c_pot, mat_orig, atom_i, atom_j, img_b, ri_data, qs_env, do_inverse, &
    1992             :                              para_env_ext, blacs_env_ext, dbcsr_template, off_diagonal, skip_inverse)
    1993             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_pot
    1994             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_orig
    1995             :       INTEGER, INTENT(IN)                                :: atom_i, atom_j, img_b
    1996             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1997             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1998             :       LOGICAL, INTENT(IN), OPTIONAL                      :: do_inverse
    1999             :       TYPE(mp_para_env_type), OPTIONAL, POINTER          :: para_env_ext
    2000             :       TYPE(cp_blacs_env_type), OPTIONAL, POINTER         :: blacs_env_ext
    2001             :       TYPE(dbcsr_type), OPTIONAL, POINTER                :: dbcsr_template
    2002             :       LOGICAL, INTENT(IN), OPTIONAL                      :: off_diagonal, skip_inverse
    2003             : 
    2004             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_ext_2c_int'
    2005             : 
    2006             :       INTEGER :: blk, group, handle, handle2, i_img, i_RI, iatom, iblk, ikind, img_tot, j_img, &
    2007             :          j_RI, jatom, jblk, jkind, n_dependent, natom, nblks_RI, nimg, nkind
    2008        7704 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist1, dist2
    2009        7704 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: present_atoms_i, present_atoms_j
    2010             :       INTEGER, DIMENSION(3)                              :: cell_b, cell_i, cell_j, cell_tot
    2011        7704 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, col_dist_ext, ri_blk_size_ext, &
    2012        7704 :                                                             row_dist, row_dist_ext
    2013        7704 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell, pgrid
    2014        7704 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    2015             :       LOGICAL                                            :: do_inverse_prv, found, my_offd, &
    2016             :                                                             skip_inverse_prv, use_template
    2017             :       REAL(dp)                                           :: bfac, dij, r0, r1, threshold
    2018             :       REAL(dp), DIMENSION(3)                             :: ri, rij, rj, rref, scoord
    2019        7704 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock
    2020             :       TYPE(cell_type), POINTER                           :: cell
    2021             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
    2022             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist, dbcsr_dist_ext
    2023             :       TYPE(dbcsr_iterator_type)                          :: dbcsr_iter
    2024             :       TYPE(dbcsr_type)                                   :: work, work_tight, work_tight_inv
    2025       53928 :       TYPE(dbt_type)                                     :: t_2c_tmp
    2026             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    2027             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    2028        7704 :          DIMENSION(:), TARGET                            :: basis_set_RI
    2029             :       TYPE(kpoint_type), POINTER                         :: kpoints
    2030             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2031             :       TYPE(neighbor_list_iterator_p_type), &
    2032        7704 :          DIMENSION(:), POINTER                           :: nl_iterator
    2033             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    2034        7704 :          POINTER                                         :: nl_2c
    2035        7704 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    2036        7704 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    2037             : 
    2038        7704 :       NULLIFY (qs_kind_set, nl_2c, nl_iterator, cell, kpoints, cell_to_index, index_to_cell, dist_2d, &
    2039        7704 :                para_env, pblock, blacs_env, particle_set, col_dist, row_dist, pgrid, &
    2040        7704 :                col_dist_ext, row_dist_ext)
    2041             : 
    2042        7704 :       CALL timeset(routineN, handle)
    2043             : 
    2044             :       !Idea: run over the neighbor list once for i and once for j, and record in which cell the MIC
    2045             :       !      atoms are. Then loop over the atoms and only take the pairs the we need
    2046             : 
    2047             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    2048        7704 :                       kpoints=kpoints, para_env=para_env, blacs_env=blacs_env, particle_set=particle_set)
    2049        7704 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    2050             : 
    2051        7704 :       do_inverse_prv = .FALSE.
    2052        7704 :       IF (PRESENT(do_inverse)) do_inverse_prv = do_inverse
    2053         280 :       IF (do_inverse_prv) THEN
    2054         280 :          CPASSERT(atom_i == atom_j)
    2055             :       END IF
    2056             : 
    2057        7704 :       skip_inverse_prv = .FALSE.
    2058        7704 :       IF (PRESENT(skip_inverse)) skip_inverse_prv = skip_inverse
    2059             : 
    2060        7704 :       my_offd = .FALSE.
    2061        7704 :       IF (PRESENT(off_diagonal)) my_offd = off_diagonal
    2062             : 
    2063        7704 :       IF (PRESENT(para_env_ext)) para_env => para_env_ext
    2064        7704 :       IF (PRESENT(blacs_env_ext)) blacs_env => blacs_env_ext
    2065             : 
    2066        7704 :       nimg = SIZE(mat_orig)
    2067             : 
    2068        7704 :       CALL timeset(routineN//"_nl_iter", handle2)
    2069             : 
    2070             :       !create our own dist_2d in the subgroup
    2071       30816 :       ALLOCATE (dist1(natom), dist2(natom))
    2072       23112 :       DO iatom = 1, natom
    2073       15408 :          dist1(iatom) = MOD(iatom, blacs_env%num_pe(1))
    2074       23112 :          dist2(iatom) = MOD(iatom, blacs_env%num_pe(2))
    2075             :       END DO
    2076        7704 :       CALL distribution_2d_create(dist_2d, dist1, dist2, nkind, particle_set, blacs_env_ext=blacs_env)
    2077             : 
    2078       34860 :       ALLOCATE (basis_set_RI(nkind))
    2079        7704 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    2080             : 
    2081             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    2082        7704 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    2083             : 
    2084       46224 :       ALLOCATE (present_atoms_i(natom, nimg), present_atoms_j(natom, nimg))
    2085      748254 :       present_atoms_i = 0
    2086      748254 :       present_atoms_j = 0
    2087             : 
    2088        7704 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    2089      308400 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    2090             :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, r=rij, cell=cell_j, &
    2091      300696 :                                 ikind=ikind, jkind=jkind)
    2092             : 
    2093     1202784 :          dij = NORM2(rij)
    2094             : 
    2095      300696 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    2096      300696 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    2097             : 
    2098      298349 :          IF (iatom == atom_i .AND. dij .LE. ri_data%kp_RI_range) present_atoms_i(jatom, j_img) = 1
    2099      306053 :          IF (iatom == atom_j .AND. dij .LE. ri_data%kp_RI_range) present_atoms_j(jatom, j_img) = 1
    2100             :       END DO
    2101        7704 :       CALL neighbor_list_iterator_release(nl_iterator)
    2102        7704 :       CALL release_neighbor_list_sets(nl_2c)
    2103        7704 :       CALL distribution_2d_release(dist_2d)
    2104        7704 :       CALL timestop(handle2)
    2105             : 
    2106        7704 :       CALL para_env%sum(present_atoms_i)
    2107        7704 :       CALL para_env%sum(present_atoms_j)
    2108             : 
    2109             :       !Need to build a work matrix with matching distribution to mat_orig
    2110             :       !If template is provided, use it. If not, we create it.
    2111        7704 :       use_template = .FALSE.
    2112        7704 :       IF (PRESENT(dbcsr_template)) THEN
    2113        7032 :          IF (ASSOCIATED(dbcsr_template)) use_template = .TRUE.
    2114             :       END IF
    2115             : 
    2116             :       IF (use_template) THEN
    2117        6776 :          CALL dbcsr_create(work, template=dbcsr_template)
    2118             :       ELSE
    2119         928 :          CALL dbcsr_get_info(mat_orig(1), distribution=dbcsr_dist)
    2120         928 :          CALL dbcsr_distribution_get(dbcsr_dist, row_dist=row_dist, col_dist=col_dist, group=group, pgrid=pgrid)
    2121        3712 :          ALLOCATE (row_dist_ext(ri_data%ncell_RI*natom), col_dist_ext(ri_data%ncell_RI*natom))
    2122        1856 :          ALLOCATE (ri_blk_size_ext(ri_data%ncell_RI*natom))
    2123        6602 :          DO i_RI = 1, ri_data%ncell_RI
    2124       28370 :             row_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = row_dist(:)
    2125       28370 :             col_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = col_dist(:)
    2126       17950 :             RI_blk_size_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2127             :          END DO
    2128             : 
    2129             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2130         928 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2131             :          CALL dbcsr_create(work, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2132         928 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2133         928 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2134         928 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2135             : 
    2136        2784 :          IF (PRESENT(dbcsr_template)) THEN
    2137         256 :             ALLOCATE (dbcsr_template)
    2138         256 :             CALL dbcsr_create(dbcsr_template, template=work)
    2139             :          END IF
    2140             :       END IF !use_template
    2141             : 
    2142       30816 :       cell_b(:) = index_to_cell(:, img_b)
    2143      254554 :       DO i_img = 1, nimg
    2144      246850 :          i_RI = ri_data%img_to_RI_cell(i_img)
    2145      246850 :          IF (i_RI == 0) CYCLE
    2146      207140 :          cell_i(:) = index_to_cell(:, i_img)
    2147     2043969 :          DO j_img = 1, nimg
    2148     1984480 :             j_RI = ri_data%img_to_RI_cell(j_img)
    2149     1984480 :             IF (j_RI == 0) CYCLE
    2150     1799436 :             cell_j(:) = index_to_cell(:, j_img)
    2151     1799436 :             cell_tot = cell_j - cell_i + cell_b
    2152             : 
    2153     3107547 :             IF (ANY([cell_tot(1), cell_tot(2), cell_tot(3)] < LBOUND(cell_to_index)) .OR. &
    2154             :                 ANY([cell_tot(1), cell_tot(2), cell_tot(3)] > UBOUND(cell_to_index))) CYCLE
    2155      412907 :             img_tot = cell_to_index(cell_tot(1), cell_tot(2), cell_tot(3))
    2156      412907 :             IF (img_tot > nimg .OR. img_tot < 1) CYCLE
    2157             : 
    2158      276021 :             CALL dbcsr_iterator_start(dbcsr_iter, mat_orig(img_tot))
    2159      788793 :             DO WHILE (dbcsr_iterator_blocks_left(dbcsr_iter))
    2160      512772 :                CALL dbcsr_iterator_next_block(dbcsr_iter, row=iatom, column=jatom, blk=blk)
    2161      512772 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2162      191706 :                IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2163       83092 :                IF (my_offd .AND. (i_RI - 1)*natom + iatom == (j_RI - 1)*natom + jatom) CYCLE
    2164             : 
    2165       82805 :                CALL dbcsr_get_block_p(mat_orig(img_tot), iatom, jatom, pblock, found)
    2166       82805 :                IF (.NOT. found) CYCLE
    2167             : 
    2168      788793 :                CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2169             : 
    2170             :             END DO
    2171     2470399 :             CALL dbcsr_iterator_stop(dbcsr_iter)
    2172             : 
    2173             :          END DO !j_img
    2174             :       END DO !i_img
    2175        7704 :       CALL dbcsr_finalize(work)
    2176             : 
    2177        7704 :       IF (do_inverse_prv) THEN
    2178             : 
    2179         280 :          r1 = ri_data%kp_RI_range
    2180         280 :          r0 = ri_data%kp_bump_rad
    2181             : 
    2182             :          !Because there are a lot of empty rows/cols in work, we need to get rid of them for inversion
    2183       20008 :          nblks_RI = SUM(present_atoms_i)
    2184        1400 :          ALLOCATE (col_dist_ext(nblks_RI), row_dist_ext(nblks_RI), RI_blk_size_ext(nblks_RI))
    2185         280 :          iblk = 0
    2186        6856 :          DO i_img = 1, nimg
    2187        6576 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2188        6576 :             IF (i_RI == 0) CYCLE
    2189        5512 :             DO iatom = 1, natom
    2190        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2191        1148 :                iblk = iblk + 1
    2192        1148 :                col_dist_ext(iblk) = col_dist(iatom)
    2193        1148 :                row_dist_ext(iblk) = row_dist(iatom)
    2194       10064 :                RI_blk_size_ext(iblk) = ri_data%bsizes_RI(iatom)
    2195             :             END DO
    2196             :          END DO
    2197             : 
    2198             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2199         280 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2200             :          CALL dbcsr_create(work_tight, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2201         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2202             :          CALL dbcsr_create(work_tight_inv, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2203         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2204         280 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2205         280 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2206             : 
    2207             :          !We apply a bump function to the RI metric inverse for smooth RI basis extension:
    2208             :          ! S^-1 = B * ((P|Q)_D + B*(P|Q)_OD*B)^-1 * B, with D block-diagonal blocks and OD off-diagonal
    2209         280 :          rref = pbc(particle_set(atom_i)%r, cell)
    2210             : 
    2211         280 :          iblk = 0
    2212        6856 :          DO i_img = 1, nimg
    2213        6576 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2214        6576 :             IF (i_RI == 0) CYCLE
    2215        5512 :             DO iatom = 1, natom
    2216        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2217        1148 :                iblk = iblk + 1
    2218             : 
    2219        1148 :                CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    2220        4592 :                CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    2221             : 
    2222        1148 :                jblk = 0
    2223       40068 :                DO j_img = 1, nimg
    2224       32344 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2225       32344 :                   IF (j_RI == 0) CYCLE
    2226       29240 :                   DO jatom = 1, natom
    2227       17168 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2228        5476 :                      jblk = jblk + 1
    2229             : 
    2230        5476 :                      CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    2231       21904 :                      CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    2232             : 
    2233        5476 :                      CALL dbcsr_get_block_p(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock, found)
    2234        5476 :                      IF (.NOT. found) CYCLE
    2235             : 
    2236        2460 :                      bfac = 1.0_dp
    2237       13776 :                      IF (iblk .NE. jblk) bfac = bump(NORM2(ri - rref), r0, r1)*bump(NORM2(rj - rref), r0, r1)
    2238     5054592 :                      CALL dbcsr_put_block(work_tight, iblk, jblk, bfac*pblock(:, :))
    2239             :                   END DO
    2240             :                END DO
    2241             :             END DO
    2242             :          END DO
    2243         280 :          CALL dbcsr_finalize(work_tight)
    2244         280 :          CALL dbcsr_clear(work)
    2245             : 
    2246         280 :          IF (.NOT. skip_inverse_prv) THEN
    2247         140 :             SELECT CASE (ri_data%t2c_method)
    2248             :             CASE (hfx_ri_do_2c_iter)
    2249           0 :                threshold = MAX(ri_data%filter_eps, 1.0e-12_dp)
    2250           0 :                CALL invert_hotelling(work_tight_inv, work_tight, threshold=threshold, silent=.FALSE.)
    2251             :             CASE (hfx_ri_do_2c_cholesky)
    2252         140 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2253         140 :                CALL cp_dbcsr_cholesky_decompose(work_tight_inv, para_env=para_env, blacs_env=blacs_env)
    2254             :                CALL cp_dbcsr_cholesky_invert(work_tight_inv, para_env=para_env, blacs_env=blacs_env, &
    2255         140 :                                              upper_to_full=.TRUE.)
    2256             :             CASE (hfx_ri_do_2c_diag)
    2257           0 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2258             :                CALL cp_dbcsr_power(work_tight_inv, -1.0_dp, ri_data%eps_eigval, n_dependent, &
    2259         140 :                                    para_env, blacs_env, verbose=ri_data%unit_nr_dbcsr > 0)
    2260             :             END SELECT
    2261             :          ELSE
    2262         140 :             CALL dbcsr_copy(work_tight_inv, work_tight)
    2263             :          END IF
    2264             : 
    2265             :          !move back data to standard extended RI pattern
    2266             :          !Note: we apply the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 later, because this matrix
    2267             :          !      is required for forces
    2268         280 :          iblk = 0
    2269        6856 :          DO i_img = 1, nimg
    2270        6576 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2271        6576 :             IF (i_RI == 0) CYCLE
    2272        5512 :             DO iatom = 1, natom
    2273        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2274        1148 :                iblk = iblk + 1
    2275             : 
    2276        1148 :                jblk = 0
    2277       40068 :                DO j_img = 1, nimg
    2278       32344 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2279       32344 :                   IF (j_RI == 0) CYCLE
    2280       29240 :                   DO jatom = 1, natom
    2281       17168 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2282        5476 :                      jblk = jblk + 1
    2283             : 
    2284        5476 :                      CALL dbcsr_get_block_p(work_tight_inv, iblk, jblk, pblock, found)
    2285        5476 :                      IF (.NOT. found) CYCLE
    2286             : 
    2287       52111 :                      CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2288             :                   END DO
    2289             :                END DO
    2290             :             END DO
    2291             :          END DO
    2292         280 :          CALL dbcsr_finalize(work)
    2293             : 
    2294         280 :          CALL dbcsr_release(work_tight)
    2295         560 :          CALL dbcsr_release(work_tight_inv)
    2296             :       END IF
    2297             : 
    2298        7704 :       CALL dbt_create(work, t_2c_tmp)
    2299        7704 :       CALL dbt_copy_matrix_to_tensor(work, t_2c_tmp)
    2300        7704 :       CALL dbt_copy(t_2c_tmp, t_2c_pot, move_data=.TRUE.)
    2301        7704 :       CALL dbt_filter(t_2c_pot, ri_data%filter_eps)
    2302             : 
    2303        7704 :       CALL dbt_destroy(t_2c_tmp)
    2304        7704 :       CALL dbcsr_release(work)
    2305             : 
    2306        7704 :       CALL timestop(handle)
    2307             : 
    2308       30816 :    END SUBROUTINE get_ext_2c_int
    2309             : 
    2310             : ! **************************************************************************************************
    2311             : !> \brief Pre-contract the density matrices with the 3-center integrals:
    2312             : !>        P_sigma^a,lambda^a+c (mu^0 sigma^a| P^0)
    2313             : !> \param t_3c_apc (spin, image) tensors into which the contracted result is summed
    2314             : !> \param rho_ao_t (spin, image) density matrix tensors
    2315             : !> \param ri_data image counts, n_mem and t_3c_int_ctr_1 are read; flop/time counters are updated
    2316             : !> \param qs_env used to get dft_control (nspins), also passed on to helper routines
    2317             : ! **************************************************************************************************
    2318         256 :    SUBROUTINE contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
    2319             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, rho_ao_t
    2320             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2321             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2322             : 
    2323             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'contract_pmat_3c'
    2324             : 
    2325             :       INTEGER                                            :: apc_img, batch_size, handle, i_batch, &
    2326             :                                                             i_img, i_spin, j_batch, n_batch_img, &
    2327             :                                                             n_batch_nze, nimg, nimg_nze, nspins
    2328             :       INTEGER(int_8)                                     :: nflop, nze
    2329         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_img, batch_ranges_nze, &
    2330         256 :                                                             int_indices
    2331         256 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: ac_pairs
    2332             :       REAL(dp)                                           :: occ, t1, t2
    2333        2304 :       TYPE(dbt_type)                                     :: t_3c_tmp
    2334         256 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ints_stack, res_stack, rho_stack
    2335             :       TYPE(dft_control_type), POINTER                    :: dft_control
    2336             : 
    2337         256 :       CALL timeset(routineN, handle)
    2338             : 
    2339         256 :       CALL get_qs_env(qs_env, dft_control=dft_control)
    2340             : 
    2341         256 :       nimg = ri_data%nimg
    2342         256 :       nimg_nze = ri_data%nimg_nze
    2343         256 :       nspins = dft_control%nspins
    2344             : 
    2345         256 :       CALL dbt_create(t_3c_apc(1, 1), t_3c_tmp)
    2346             :       !NOTE(review): batch_size = 0 if n_mem > nimg => division by zero below; confirm n_mem <= nimg
    2347         256 :       batch_size = nimg/ri_data%n_mem
    2348             : 
    2349             :       !batching over all images
    2350         256 :       n_batch_img = nimg/batch_size
    2351         256 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch_img = n_batch_img + 1
    2352         768 :       ALLOCATE (batch_ranges_img(n_batch_img + 1))
    2353         890 :       DO i_batch = 1, n_batch_img
    2354         890 :          batch_ranges_img(i_batch) = (i_batch - 1)*batch_size + 1
    2355             :       END DO
    2356         256 :       batch_ranges_img(n_batch_img + 1) = nimg + 1
    2357             : 
    2358             :       !batching over images with non-zero 3c integrals
    2359         256 :       n_batch_nze = nimg_nze/batch_size
    2360         256 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
    2361         768 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
    2362         830 :       DO i_batch = 1, n_batch_nze
    2363         830 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
    2364             :       END DO
    2365         256 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
    2366             : 
    2367             :       !Create the stack tensors in the appropriate distribution
    2368        7936 :       ALLOCATE (rho_stack(2), ints_stack(2), res_stack(2))
    2369             :       CALL get_stack_tensors(res_stack, rho_stack, ints_stack, rho_ao_t(1, 1), &
    2370         256 :                              ri_data%t_3c_int_ctr_1(1, 1), batch_size, ri_data, qs_env)
    2371             :       !int_indices is the identity map; ac_pairs is refilled by get_ac_pairs for each apc_img
    2372        1280 :       ALLOCATE (ac_pairs(nimg, 2), int_indices(nimg_nze))
    2373        4888 :       DO i_img = 1, nimg_nze
    2374        4888 :          int_indices(i_img) = i_img
    2375             :       END DO
    2376             : 
    2377         256 :       t1 = m_walltime()
    2378         830 :       DO j_batch = 1, n_batch_nze
    2379             :          !First batch is over the integrals. They are always in the same order, consistent with get_ac_pairs
    2380             :          CALL fill_3c_stack(ints_stack(1), ri_data%t_3c_int_ctr_1(1, :), int_indices, 3, ri_data, &
    2381        1722 :                             img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)])
    2382         574 :          CALL dbt_copy(ints_stack(1), ints_stack(2), move_data=.TRUE.)
    2383             : 
    2384        1640 :          DO i_spin = 1, nspins
    2385        3448 :             DO i_batch = 1, n_batch_img
    2386             :                !Second batch is over the P matrix. Here we fill the stacked rho tensors col by col
    2387       18476 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2388       16412 :                   CALL get_ac_pairs(ac_pairs, apc_img, ri_data, qs_env)
    2389             :                   CALL fill_2c_stack(rho_stack(1), rho_ao_t(i_spin, :), ac_pairs(:, 2), 1, ri_data, &
    2390             :                                      img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)], &
    2391       51300 :                                      shift=apc_img - batch_ranges_img(i_batch) + 1)
    2392             : 
    2393             :                END DO !apc_img
    2394        2064 :                CALL get_tensor_occupancy(rho_stack(1), nze, occ)
    2395        2064 :                IF (nze == 0) CYCLE
    2396        1806 :                CALL dbt_copy(rho_stack(1), rho_stack(2), move_data=.TRUE.)
    2397             : 
    2398             :                !The actual contraction
    2399        1806 :                CALL dbt_batched_contract_init(rho_stack(2))
    2400             :                CALL dbt_contract(1.0_dp, ints_stack(2), rho_stack(2), &
    2401             :                                  0.0_dp, res_stack(2), map_1=[1, 2], map_2=[3], &
    2402             :                                  contract_1=[3], notcontract_1=[1, 2], &
    2403             :                                  contract_2=[1], notcontract_2=[2], &
    2404        1806 :                                  filter_eps=ri_data%filter_eps, flop=nflop)
    2405        1806 :                ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    2406        1806 :                CALL dbt_batched_contract_finalize(rho_stack(2))
    2407        1806 :                CALL dbt_copy(res_stack(2), res_stack(1), move_data=.TRUE.)
    2408             : 
    2409       20230 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2410             :                   !Destack the resulting tensor and put it in t_3c_apc with correct apc_img
    2411       15550 :                   CALL unstack_t_3c_apc(t_3c_tmp, res_stack(1), apc_img - batch_ranges_img(i_batch) + 1)
    2412       17614 :                   CALL dbt_copy(t_3c_tmp, t_3c_apc(i_spin, apc_img), summation=.TRUE., move_data=.TRUE.)
    2413             :                END DO
    2414             : 
    2415             :             END DO !i_batch
    2416             :          END DO !i_spin
    2417             :       END DO !j_batch
    2418         256 :       DEALLOCATE (batch_ranges_img)
    2419         256 :       DEALLOCATE (batch_ranges_nze)
    2420         256 :       t2 = m_walltime()
    2421         256 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    2422             : 
    2423         256 :       CALL dbt_destroy(rho_stack(1))
    2424         256 :       CALL dbt_destroy(rho_stack(2))
    2425         256 :       CALL dbt_destroy(ints_stack(1))
    2426         256 :       CALL dbt_destroy(ints_stack(2))
    2427         256 :       CALL dbt_destroy(res_stack(1))
    2428         256 :       CALL dbt_destroy(res_stack(2))
    2429         256 :       CALL dbt_destroy(t_3c_tmp)
    2430             : 
    2431         256 :       CALL timestop(handle)
    2432             : 
    2433        2560 :    END SUBROUTINE contract_pmat_3c
    2434             : 
    2435             : ! **************************************************************************************************
    2436             : !> \brief Pre-contract 3-center integrals with the bumped inverse RI metric, for each atom
    2437             : !> \param t_3c_int (1, image) 3-center integral tensors; all of them are destroyed on exit
    2438             : !> \param ri_data t_2c_inv and block sizes are read; the result is summed into t_3c_int_ctr_1
    2439             : !> \param qs_env used to get natom, also passed on to apply_bump
    2440             : ! **************************************************************************************************
    2441          70 :    SUBROUTINE precontract_3c_ints(t_3c_int, ri_data, qs_env)
    2442             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_int
    2443             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2444             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2445             : 
    2446             :       CHARACTER(len=*), PARAMETER :: routineN = 'precontract_3c_ints'
    2447             : 
    2448             :       INTEGER                                            :: batch_size, handle, i_batch, i_img, &
    2449             :                                                             i_RI, iatom, is, n_batch, natom, &
    2450             :                                                             nblks, nblks_3c(3), nimg
    2451             :       INTEGER(int_8)                                     :: nflop
    2452          70 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: batch_ranges, bsizes_RI_ext, bsizes_RI_ext_split, &
    2453          70 :          bsizes_stack, dist1, dist2, dist3, dist_stack3, idx_to_at_AO, int_indices
    2454         630 :       TYPE(dbt_distribution_type)                        :: t_dist
    2455       14630 :       TYPE(dbt_type)                                     :: t_2c_RI_tmp(2), t_3c_tmp(3)
    2456             : 
    2457          70 :       CALL timeset(routineN, handle)
    2458             : 
    2459          70 :       CALL get_qs_env(qs_env, natom=natom)
    2460             : 
    2461          70 :       nimg = ri_data%nimg
    2462         210 :       ALLOCATE (int_indices(nimg))
    2463        1714 :       DO i_img = 1, nimg
    2464        1714 :          int_indices(i_img) = i_img
    2465             :       END DO
    2466             : 
    2467         210 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
    2468          70 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
    2469             :       !RI block sizes (atomic and split) are replicated ncell_RI times for the 2-center work tensors
    2470          70 :       nblks = SIZE(ri_data%bsizes_RI_split)
    2471         210 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    2472         210 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    2473         506 :       DO i_RI = 1, ri_data%ncell_RI
    2474        1308 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2475        2344 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    2476             :       END DO
    2477             :       CALL create_2c_tensor(t_2c_RI_tmp(1), dist1, dist2, ri_data%pgrid_2d, &
    2478             :                             bsizes_RI_ext, bsizes_RI_ext, &
    2479             :                             name="(RI | RI)")
    2480          70 :       DEALLOCATE (dist1, dist2)
    2481             :       CALL create_2c_tensor(t_2c_RI_tmp(2), dist1, dist2, ri_data%pgrid_2d, &
    2482             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    2483             :                             name="(RI | RI)")
    2484          70 :       DEALLOCATE (dist1, dist2)
    2485             :       !NOTE(review): batch_size = 0 if n_mem > nimg => division by zero below; confirm n_mem <= nimg
    2486             :       !For more efficiency, we stack multiple images of the 3-center integrals into a single tensor
    2487          70 :       batch_size = nimg/ri_data%n_mem
    2488          70 :       n_batch = nimg/batch_size
    2489          70 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch = n_batch + 1
    2490         210 :       ALLOCATE (batch_ranges(n_batch + 1))
    2491         246 :       DO i_batch = 1, n_batch
    2492         246 :          batch_ranges(i_batch) = (i_batch - 1)*batch_size + 1
    2493             :       END DO
    2494          70 :       batch_ranges(n_batch + 1) = nimg + 1
    2495             :       !AO split block sizes and the 3rd-dim distribution are replicated batch_size times for the stack
    2496          70 :       nblks = SIZE(ri_data%bsizes_AO_split)
    2497         210 :       ALLOCATE (bsizes_stack(batch_size*nblks))
    2498         874 :       DO is = 1, batch_size
    2499        3394 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    2500             :       END DO
    2501             : 
    2502          70 :       CALL dbt_get_info(t_3c_int(1, 1), nblks_total=nblks_3c)
    2503         630 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)), dist_stack3(batch_size*nblks_3c(3)))
    2504          70 :       CALL dbt_get_info(t_3c_int(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, proc_dist_3=dist3)
    2505         874 :       DO is = 1, batch_size
    2506        3394 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    2507             :       END DO
    2508             : 
    2509          70 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid, dist1, dist2, dist_stack3)
    2510             :       CALL dbt_create(t_3c_tmp(1), "ints_stack", t_dist, [1], [2, 3], bsizes_RI_ext_split, &
    2511          70 :                       ri_data%bsizes_AO_split, bsizes_stack)
    2512          70 :       CALL dbt_distribution_destroy(t_dist)
    2513          70 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    2514             : 
    2515          70 :       CALL dbt_create(t_3c_tmp(1), t_3c_tmp(2))
    2516          70 :       CALL dbt_create(t_3c_int(1, 1), t_3c_tmp(3))
    2517             :       !For each atom: bump t_2c_inv from both sides, then contract it with the stacked 3c integrals
    2518         210 :       DO iatom = 1, natom
    2519         140 :          CALL dbt_copy(ri_data%t_2c_inv(1, iatom), t_2c_RI_tmp(1))
    2520         140 :          CALL apply_bump(t_2c_RI_tmp(1), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    2521         140 :          CALL dbt_copy(t_2c_RI_tmp(1), t_2c_RI_tmp(2), move_data=.TRUE.)
    2522             : 
    2523         140 :          CALL dbt_batched_contract_init(t_2c_RI_tmp(2))
    2524         492 :          DO i_batch = 1, n_batch
    2525             : 
    2526             :             CALL fill_3c_stack(t_3c_tmp(1), t_3c_int(1, :), int_indices, 3, ri_data, &
    2527             :                                img_bounds=[batch_ranges(i_batch), batch_ranges(i_batch + 1)], &
    2528        1056 :                                filter_at=iatom, filter_dim=2, idx_to_at=idx_to_at_AO)
    2529             : 
    2530             :             CALL dbt_contract(1.0_dp, t_2c_RI_tmp(2), t_3c_tmp(1), &
    2531             :                               0.0_dp, t_3c_tmp(2), map_1=[1], map_2=[2, 3], &
    2532             :                               contract_1=[2], notcontract_1=[1], &
    2533             :                               contract_2=[1], notcontract_2=[2, 3], &
    2534         352 :                               filter_eps=ri_data%filter_eps, flop=nflop)
    2535         352 :             ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    2536             :             !Unstack the result image by image and sum it into t_3c_int_ctr_1 with reordered indices
    2537        3640 :             DO i_img = batch_ranges(i_batch), batch_ranges(i_batch + 1) - 1
    2538        3288 :                CALL unstack_t_3c_apc(t_3c_tmp(3), t_3c_tmp(2), i_img - batch_ranges(i_batch) + 1)
    2539             :                CALL dbt_copy(t_3c_tmp(3), ri_data%t_3c_int_ctr_1(1, i_img), summation=.TRUE., &
    2540        3640 :                              order=[2, 1, 3], move_data=.TRUE.)
    2541             :             END DO
    2542         492 :             CALL dbt_clear(t_3c_tmp(1))
    2543             :          END DO
    2544         210 :          CALL dbt_batched_contract_finalize(t_2c_RI_tmp(2))
    2545             : 
    2546             :       END DO
    2547          70 :       CALL dbt_destroy(t_2c_RI_tmp(1))
    2548          70 :       CALL dbt_destroy(t_2c_RI_tmp(2))
    2549          70 :       CALL dbt_destroy(t_3c_tmp(1))
    2550          70 :       CALL dbt_destroy(t_3c_tmp(2))
    2551          70 :       CALL dbt_destroy(t_3c_tmp(3))
    2552             :       !Release the input 3-center integral tensors of every image
    2553        1714 :       DO i_img = 1, nimg
    2554        1714 :          CALL dbt_destroy(t_3c_int(1, i_img))
    2555             :       END DO
    2556             : 
    2557          70 :       CALL timestop(handle)
    2558             : 
    2559         350 :    END SUBROUTINE precontract_3c_ints
    2560             : 
    2561             : ! **************************************************************************************************
    2562             : !> \brief Copy the data of a 2D tensor living in the main MPI group to a sub-group, given the proc
    2563             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2564             : !> \param t2c_sub ...
    2565             : !> \param t2c_main ...
    2566             : !> \param group_size ...
    2567             : !> \param ngroups ...
    2568             : !> \param para_env ...
    2569             : ! **************************************************************************************************
    2570        8388 :    SUBROUTINE copy_2c_to_subgroup(t2c_sub, t2c_main, group_size, ngroups, para_env)
    2571             :       TYPE(dbt_type), INTENT(INOUT)                      :: t2c_sub, t2c_main
    2572             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2573             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2574             : 
    2575             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iblk, &
    2576             :                                                             igroup, iproc, ir, is, jblk, n_batch, &
    2577             :                                                             nocc, tag
    2578        8388 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2
    2579        8388 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: block_dest, block_source
    2580        8388 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: current_dest
    2581             :       INTEGER, DIMENSION(2)                              :: ind, nblks
    2582             :       LOGICAL                                            :: found
    2583        8388 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2584        8388 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2585             :       TYPE(dbt_iterator_type)                            :: iter
    2586        8388 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2587             : 
    2588             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2589             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2590             :       !         to have blocks pre-reserved though
    2591             : 
    2592        8388 :       CALL dbt_get_info(t2c_main, nblks_total=nblks)
    2593             : 
    2594             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2595       33552 :       ALLOCATE (block_source(nblks(1), nblks(2)))
    2596      169184 :       block_source = -1
    2597        8388 :       nocc = 0
    2598        8388 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t2c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2599             :       CALL dbt_iterator_start(iter, t2c_main)
    2600             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2601             :          CALL dbt_iterator_next_block(iter, ind)
    2602             :          CALL dbt_get_block(t2c_main, ind, blk, found)
    2603             :          IF (.NOT. found) CYCLE
    2604             : 
    2605             :          block_source(ind(1), ind(2)) = para_env%mepos
    2606             : !$OMP ATOMIC
    2607             :          nocc = nocc + 1
    2608             :          DEALLOCATE (blk)
    2609             :       END DO
    2610             :       CALL dbt_iterator_stop(iter)
    2611             : !$OMP END PARALLEL
    2612             : 
    2613        8388 :       CALL para_env%sum(nocc)
    2614        8388 :       CALL para_env%sum(block_source)
    2615      169184 :       block_source = block_source + para_env%num_pe - 1
    2616        8388 :       IF (nocc == 0) RETURN
    2617             : 
    2618             :       !Loop over the sub tensor, get the block destination
    2619        8268 :       igroup = para_env%mepos/group_size
    2620       24804 :       ALLOCATE (block_dest(nblks(1), nblks(2)))
    2621      168344 :       block_dest = -1
    2622       31184 :       DO jblk = 1, nblks(2)
    2623      168344 :          DO iblk = 1, nblks(1)
    2624      137160 :             IF (block_source(iblk, jblk) == -1) CYCLE
    2625             : 
    2626      101028 :             CALL dbt_get_stored_coordinates(t2c_sub, [iblk, jblk], iproc)
    2627      160076 :             block_dest(iblk, jblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2628             :          END DO
    2629             :       END DO
    2630             : 
    2631       41340 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)))
    2632        8268 :       CALL dbt_get_info(t2c_main, blk_size_1=bsizes1, blk_size_2=bsizes2)
    2633             : 
    2634       41340 :       ALLOCATE (current_dest(nblks(1), nblks(2), 0:ngroups - 1))
    2635       24804 :       DO igroup = 0, ngroups - 1
    2636             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2637      336688 :          current_dest(:, :, igroup) = block_dest(:, :)
    2638       24804 :          CALL para_env%bcast(current_dest(:, :, igroup), source=igroup*group_size) !bcast from first proc in sub-group
    2639             :       END DO
    2640             : 
    2641             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2642        8268 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2643        8268 :       n_batch = (nocc*ngroups)/batch_size
    2644        8268 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2645             : 
    2646       16536 :       DO i_batch = 1, n_batch
    2647             :          !Loop over groups, blocks and send/receive
    2648      167776 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2649      167776 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2650             :          ir = 0
    2651             :          is = 0
    2652             :          i_msg = 0
    2653       31184 :          DO jblk = 1, nblks(2)
    2654      168344 :             DO iblk = 1, nblks(1)
    2655      434396 :                DO igroup = 0, ngroups - 1
    2656      274320 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2657             : 
    2658       67352 :                   i_msg = i_msg + 1
    2659       67352 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2660             : 
    2661             :                   !a unique tag per block, within this batch
    2662       67352 :                   tag = i_msg - (i_batch - 1)*batch_size
    2663             : 
    2664       67352 :                   found = .FALSE.
    2665       67352 :                   IF (para_env%mepos == block_source(iblk, jblk)) THEN
    2666      101028 :                      CALL dbt_get_block(t2c_main, [iblk, jblk], blk, found)
    2667             :                   END IF
    2668             : 
    2669             :                   !If blocks live on same proc, simply copy. Else MPI send/recv
    2670       67352 :                   IF (block_source(iblk, jblk) == current_dest(iblk, jblk, igroup)) THEN
    2671      101028 :                      IF (found) CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2672             :                   ELSE
    2673       33676 :                      IF (para_env%mepos == block_source(iblk, jblk) .AND. found) THEN
    2674       67352 :                         ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2675    21791618 :                         send_buff(tag)%array(:, :) = blk(:, :)
    2676       16838 :                         is = is + 1
    2677             :                         CALL para_env%isend(msgin=send_buff(tag)%array, dest=current_dest(iblk, jblk, igroup), &
    2678       16838 :                                             request=send_req(is), tag=tag)
    2679             :                      END IF
    2680             : 
    2681       33676 :                      IF (para_env%mepos == current_dest(iblk, jblk, igroup)) THEN
    2682       67352 :                         ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2683       16838 :                         ir = ir + 1
    2684             :                         CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk), &
    2685       16838 :                                             request=recv_req(ir), tag=tag)
    2686             :                      END IF
    2687             :                   END IF
    2688             : 
    2689      204512 :                   IF (found) DEALLOCATE (blk)
    2690             :                END DO
    2691             :             END DO
    2692             :          END DO
    2693             : 
    2694        8268 :          CALL mp_waitall(send_req(1:is))
    2695        8268 :          CALL mp_waitall(recv_req(1:ir))
    2696             :          !clean-up
    2697       75620 :          DO i = 1, batch_size
    2698       75620 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2699             :          END DO
    2700             : 
    2701             :          !Finally copy the data from the buffer to the sub-tensor
    2702             :          i_msg = 0
    2703       31184 :          DO jblk = 1, nblks(2)
    2704      168344 :             DO iblk = 1, nblks(1)
    2705      434396 :                DO igroup = 0, ngroups - 1
    2706      274320 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2707             : 
    2708       67352 :                   i_msg = i_msg + 1
    2709       67352 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2710             : 
    2711             :                   !a unique tag per block, within this batch
    2712       67352 :                   tag = i_msg - (i_batch - 1)*batch_size
    2713             : 
    2714       67352 :                   IF (para_env%mepos == current_dest(iblk, jblk, igroup) .AND. &
    2715      137160 :                       block_source(iblk, jblk) .NE. current_dest(iblk, jblk, igroup)) THEN
    2716             : 
    2717       67352 :                      ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk)))
    2718    21791618 :                      blk(:, :) = recv_buff(tag)%array(:, :)
    2719       84190 :                      CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2720       16838 :                      DEALLOCATE (blk)
    2721             :                   END IF
    2722             :                END DO
    2723             :             END DO
    2724             :          END DO
    2725             : 
    2726             :          !clean-up
    2727       75620 :          DO i = 1, batch_size
    2728       75620 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2729             :          END DO
    2730       16536 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2731             :       END DO !i_batch
    2732        8268 :       CALL dbt_finalize(t2c_sub)
    2733             : 
    2734       16776 :    END SUBROUTINE copy_2c_to_subgroup
    2735             : 
    2736             : ! **************************************************************************************************
    2737             : !> \brief Copy the data of a 3D tensor living in the main MPI group to a sub-group, given the proc
    2738             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2739             : !> \param t3c_sub the tensor in the subgroup which receives the blocks
    2740             : !> \param t3c_main the tensor in the main group whose blocks are copied
    2741             : !> \param group_size number of procs in a subgroup, to map a subgroup proc idx to the main group idx
    2742             : !> \param ngroups number of subgroups
    2743             : !> \param para_env the parallel environment of the main group
    2744             : !> \param iatom_to_subgroup optional: for each atom, whether its blocks go to a given subgroup
    2745             : !> \param dim_at optional: the tensor dimension whose block index is mapped to an atom
    2746             : !> \param idx_to_at optional: maps a block index along dim_at to an atom index
    2747             : ! **************************************************************************************************
    2748       12220 :    SUBROUTINE copy_3c_to_subgroup(t3c_sub, t3c_main, group_size, ngroups, para_env, iatom_to_subgroup, &
    2749       12220 :                                   dim_at, idx_to_at)
    2750             :       TYPE(dbt_type), INTENT(INOUT)                      :: t3c_sub, t3c_main
    2751             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2752             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2753             :       TYPE(cp_1d_logical_p_type), DIMENSION(:), &
    2754             :          INTENT(INOUT), OPTIONAL                         :: iatom_to_subgroup
    2755             :       INTEGER, INTENT(IN), OPTIONAL                      :: dim_at
    2756             :       INTEGER, DIMENSION(:), OPTIONAL                    :: idx_to_at
    2757             : 
    2758             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iatom, &
    2759             :                                                             iblk, igroup, iproc, ir, is, jblk, &
    2760             :                                                             kblk, n_batch, nocc, tag
    2761       12220 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2, bsizes3
    2762       12220 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: block_dest, block_source
    2763       12220 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :)        :: current_dest
    2764             :       INTEGER, DIMENSION(3)                              :: ind, nblks
    2765             :       LOGICAL                                            :: filter_at, found
    2766       12220 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    2767       12220 :       TYPE(cp_3d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2768             :       TYPE(dbt_iterator_type)                            :: iter
    2769       12220 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2770             : 
    2771             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2772             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2773             :       !         to have blocks pre-reserved though
    2774             : 
    2775       12220 :       CALL dbt_get_info(t3c_main, nblks_total=nblks)
    2776             : 
    2777             :       !in some cases, only copy a fraction of the 3c tensor to a given subgroup (corresponding to some atoms)
    2778       12220 :       filter_at = .FALSE.
    2779       12220 :       IF (PRESENT(iatom_to_subgroup) .AND. PRESENT(dim_at) .AND. PRESENT(idx_to_at)) THEN
    2780        7130 :          filter_at = .TRUE. !filtering is only active if all 3 optional arguments are passed
    2781        7130 :          CPASSERT(nblks(dim_at) == SIZE(idx_to_at))
    2782             :       END IF
    2783             : 
    2784             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2785       61100 :       ALLOCATE (block_source(nblks(1), nblks(2), nblks(3)))
    2786      661340 :       block_source = -1
    2787       12220 :       nocc = 0
    2788       12220 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t3c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2789             :       CALL dbt_iterator_start(iter, t3c_main)
    2790             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2791             :          CALL dbt_iterator_next_block(iter, ind)
    2792             :          CALL dbt_get_block(t3c_main, ind, blk, found)
    2793             :          IF (.NOT. found) CYCLE
    2794             : 
    2795             :          block_source(ind(1), ind(2), ind(3)) = para_env%mepos
    2796             : !$OMP ATOMIC
    2797             :          nocc = nocc + 1
    2798             :          DEALLOCATE (blk)
    2799             :       END DO
    2800             :       CALL dbt_iterator_stop(iter)
    2801             : !$OMP END PARALLEL
    2802             : 
    2803       12220 :       CALL para_env%sum(nocc)
    2804       12220 :       CALL para_env%sum(block_source)
    2805      661340 :       block_source = block_source + para_env%num_pe - 1 !owner proc idx (if a single proc holds the block), else -1
    2806       12220 :       IF (nocc == 0) RETURN !no occupied block on any proc, nothing to copy
    2807             : 
    2808             :       !Loop over the sub tensor, get the block destination
    2809       12220 :       igroup = para_env%mepos/group_size
    2810       48880 :       ALLOCATE (block_dest(nblks(1), nblks(2), nblks(3)))
    2811      661340 :       block_dest = -1
    2812       36660 :       DO kblk = 1, nblks(3)
    2813      170956 :          DO jblk = 1, nblks(2)
    2814      649120 :             DO iblk = 1, nblks(1)
    2815      490384 :                IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2816             : 
    2817      516168 :                CALL dbt_get_stored_coordinates(t3c_sub, [iblk, jblk, kblk], iproc)
    2818      624680 :                block_dest(iblk, jblk, kblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2819             :             END DO
    2820             :          END DO
    2821             :       END DO
    2822             : 
    2823       85540 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)), bsizes3(nblks(3)))
    2824       12220 :       CALL dbt_get_info(t3c_main, blk_size_1=bsizes1, blk_size_2=bsizes2, blk_size_3=bsizes3)
    2825             : 
    2826       73320 :       ALLOCATE (current_dest(nblks(1), nblks(2), nblks(3), 0:ngroups - 1))
    2827       36660 :       DO igroup = 0, ngroups - 1
    2828             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2829     1322680 :          current_dest(:, :, :, igroup) = block_dest(:, :, :)
    2830       36660 :          CALL para_env%bcast(current_dest(:, :, :, igroup), source=igroup*group_size) !bcast from first proc in subgroup
    2831             :       END DO
    2832             : 
    2833             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2834       12220 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2835       12220 :       n_batch = (nocc*ngroups)/batch_size
    2836       12220 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2837             : 
    2838       24440 :       DO i_batch = 1, n_batch
    2839             :          !Loop over groups, blocks and send/receive
    2840      565048 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2841      565048 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2842             :          ir = 0
    2843             :          is = 0
    2844             :          i_msg = 0
    2845       36660 :          DO kblk = 1, nblks(3)
    2846      170956 :             DO jblk = 1, nblks(2)
    2847      649120 :                DO iblk = 1, nblks(1)
    2848     1605448 :                   DO igroup = 0, ngroups - 1
    2849      980768 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2850             : 
    2851      258084 :                      i_msg = i_msg + 1
    2852      258084 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2853             : 
    2854             :                      !a unique tag per block, within this batch
    2855      258084 :                      tag = i_msg - (i_batch - 1)*batch_size
    2856             : 
    2857      258084 :                      IF (filter_at) THEN
    2858      744464 :                         ind(:) = [iblk, jblk, kblk]
    2859      186116 :                         iatom = idx_to_at(ind(dim_at))
    2860      186116 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE !block not needed by this subgroup
    2861             :                      END IF
    2862             : 
    2863      165026 :                      found = .FALSE.
    2864      165026 :                      IF (para_env%mepos == block_source(iblk, jblk, kblk)) THEN
    2865      330052 :                         CALL dbt_get_block(t3c_main, [iblk, jblk, kblk], blk, found)
    2866             :                      END IF
    2867             : 
    2868             :                      !If blocks live on same proc, simply copy. Else MPI send/recv
    2869      165026 :                      IF (block_source(iblk, jblk, kblk) == current_dest(iblk, jblk, kblk, igroup)) THEN
    2870      341664 :                         IF (found) CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2871             :                      ELSE
    2872       79610 :                         IF (para_env%mepos == block_source(iblk, jblk, kblk) .AND. found) THEN
    2873      199025 :                            ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2874   197767057 :                            send_buff(tag)%array(:, :, :) = blk(:, :, :)
    2875       39805 :                            is = is + 1
    2876             :                            CALL para_env%isend(msgin=send_buff(tag)%array, &
    2877             :                                                dest=current_dest(iblk, jblk, kblk, igroup), &
    2878       39805 :                                                request=send_req(is), tag=tag)
    2879             :                         END IF
    2880             : 
    2881       79610 :                         IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup)) THEN
    2882      199025 :                            ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2883       39805 :                            ir = ir + 1
    2884             :                            CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk, kblk), &
    2885       39805 :                                                request=recv_req(ir), tag=tag)
    2886             :                         END IF
    2887             :                      END IF
    2888             : 
    2889      655410 :                      IF (found) DEALLOCATE (blk)
    2890             :                   END DO
    2891             :                END DO
    2892             :             END DO
    2893             :          END DO
    2894             : 
    2895       12220 :          CALL mp_waitall(send_req(1:is))
    2896       12220 :          CALL mp_waitall(recv_req(1:ir))
    2897             :          !clean-up
    2898      270304 :          DO i = 1, batch_size
    2899      270304 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2900             :          END DO
    2901             : 
    2902             :          !Finally copy the data from the buffer to the sub-tensor
    2903             :          i_msg = 0 !same message counting as in the send/receive loop, such that the tags match
    2904       36660 :          DO kblk = 1, nblks(3)
    2905      170956 :             DO jblk = 1, nblks(2)
    2906      649120 :                DO iblk = 1, nblks(1)
    2907     1605448 :                   DO igroup = 0, ngroups - 1
    2908      980768 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2909             : 
    2910      258084 :                      i_msg = i_msg + 1
    2911      258084 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2912             : 
    2913             :                      !a unique tag per block, within this batch
    2914      258084 :                      tag = i_msg - (i_batch - 1)*batch_size
    2915             : 
    2916      258084 :                      IF (filter_at) THEN
    2917      744464 :                         ind(:) = [iblk, jblk, kblk]
    2918      186116 :                         iatom = idx_to_at(ind(dim_at))
    2919      186116 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE !block not needed by this subgroup
    2920             :                      END IF
    2921             : 
    2922      165026 :                      IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup) .AND. &
    2923      490384 :                          block_source(iblk, jblk, kblk) .NE. current_dest(iblk, jblk, kblk, igroup)) THEN
    2924             : 
    2925      199025 :                         ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2926   197767057 :                         blk(:, :, :) = recv_buff(tag)%array(:, :, :)
    2927      278635 :                         CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2928       39805 :                         DEALLOCATE (blk)
    2929             :                      END IF
    2930             :                   END DO
    2931             :                END DO
    2932             :             END DO
    2933             :          END DO
    2934             : 
    2935             :          !clean-up
    2936      270304 :          DO i = 1, batch_size
    2937      270304 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2938             :          END DO
    2939       24440 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2940             :       END DO !i_batch
    2941       12220 :       CALL dbt_finalize(t3c_sub)
    2942             : 
    2943       24440 :    END SUBROUTINE copy_3c_to_subgroup
    2944             : 
    2945             : ! **************************************************************************************************
    2946             : !> \brief A routine that gathers the pieces of the KS matrix across the subgroups and puts them in the
    2947             : !>        main group. Each b_img, iatom, jatom tuple is on a single CPU
    2948             : !> \param ks_t the KS matrix tensors in the main group, indexed by spin and image, which get the blocks
    2949             : !> \param ks_t_sub the KS matrix tensors in the subgroups, indexed by spin and image, which hold the blocks
    2950             : !> \param group_size number of procs in a subgroup, to map a subgroup proc idx to the main group idx
    2951             : !> \param sparsity_pattern for each iatom, jatom, b_img: the subgroup holding the block, negative if none
    2952             : !> \param para_env the parallel environment of the main group
    2953             : !> \param ri_data provides the AO block sizes (bsizes_AO) used to allocate the buffers
    2954             : ! **************************************************************************************************
    2955         214 :    SUBROUTINE gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
    2956             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t, ks_t_sub
    2957             :       INTEGER, INTENT(IN)                                :: group_size
    2958             :       INTEGER, DIMENSION(:, :, :), INTENT(IN)            :: sparsity_pattern
    2959             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2960             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2961             : 
    2962             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'gather_ks_matrix'
    2963             : 
    2964             :       INTEGER                                            :: b_img, dest, handle, i, i_spin, iatom, &
    2965             :                                                             igroup, ir, is, jatom, n_mess, natom, &
    2966             :                                                             nimg, nspins, source, tag
    2967             :       LOGICAL                                            :: found
    2968         214 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2969         214 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2970         214 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2971             : 
    2972         214 :       CALL timeset(routineN, handle)
    2973             : 
    2974         214 :       nimg = SIZE(sparsity_pattern, 3)
    2975         214 :       natom = SIZE(sparsity_pattern, 2)
    2976         214 :       nspins = SIZE(ks_t, 1)
    2977             : 
    2978        5514 :       DO b_img = 1, nimg
    2979             :          n_mess = 0 !count the blocks to communicate for this image (all spins)
    2980       12172 :          DO i_spin = 1, nspins
    2981       25916 :             DO jatom = 1, natom
    2982       48104 :                DO iatom = 1, natom
    2983       41232 :                   IF (sparsity_pattern(iatom, jatom, b_img) > -1) n_mess = n_mess + 1
    2984             :                END DO
    2985             :             END DO
    2986             :          END DO
    2987             : 
    2988       40880 :          ALLOCATE (send_buff(n_mess), recv_buff(n_mess))
    2989       46180 :          ALLOCATE (send_req(n_mess), recv_req(n_mess))
    2990        5300 :          ir = 0
    2991        5300 :          is = 0
    2992        5300 :          n_mess = 0
    2993        5300 :          tag = 0
    2994             : 
    2995       12172 :          DO i_spin = 1, nspins
    2996       25916 :             DO jatom = 1, natom
    2997       48104 :                DO iatom = 1, natom
    2998       27488 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    2999        9608 :                   n_mess = n_mess + 1
    3000        9608 :                   tag = tag + 1
    3001             : 
    3002             :                   !sending the message
    3003       28824 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3004       28824 :                   CALL dbt_get_stored_coordinates(ks_t_sub(i_spin, b_img), [iatom, jatom], source) !source within sub
    3005        9608 :                   igroup = sparsity_pattern(iatom, jatom, b_img) !subgroup holding this block
    3006        9608 :                   source = source + igroup*group_size !source within main group
    3007        9608 :                   IF (para_env%mepos == source) THEN
    3008       14412 :                      CALL dbt_get_block(ks_t_sub(i_spin, b_img), [iatom, jatom], blk, found)
    3009        4804 :                      IF (source == dest) THEN
    3010        4105 :                         IF (found) CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3011             :                      ELSE
    3012       14620 :                         ALLOCATE (send_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3013      309987 :                         send_buff(n_mess)%array(:, :) = 0.0_dp !zeros are sent if the block is not found
    3014        3655 :                         IF (found) THEN
    3015      208918 :                            send_buff(n_mess)%array(:, :) = blk(:, :)
    3016             :                         END IF
    3017        3655 :                         is = is + 1
    3018             :                         CALL para_env%isend(msgin=send_buff(n_mess)%array, dest=dest, &
    3019        3655 :                                             request=send_req(is), tag=tag)
    3020             :                      END IF
    3021        4804 :                      DEALLOCATE (blk) !NOTE(review): not guarded by found, confirm blk is always allocated here
    3022             :                   END IF
    3023             : 
    3024             :                   !receiving the message
    3025       23352 :                   IF (para_env%mepos == dest .AND. source .NE. dest) THEN
    3026       14620 :                      ALLOCATE (recv_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3027        3655 :                      ir = ir + 1
    3028             :                      CALL para_env%irecv(msgout=recv_buff(n_mess)%array, source=source, &
    3029        3655 :                                          request=recv_req(ir), tag=tag)
    3030             :                   END IF
    3031             :                END DO !iatom
    3032             :             END DO !jatom
    3033             :          END DO !ispin
    3034             : 
    3035        5300 :          CALL mp_waitall(send_req(1:is))
    3036        5300 :          CALL mp_waitall(recv_req(1:ir))
    3037             : 
    3038             :          !Copy the messages received into the KS matrix
    3039        5300 :          n_mess = 0 !same counting as in the send/receive loop, such that buffer indices match
    3040       12172 :          DO i_spin = 1, nspins
    3041       25916 :             DO jatom = 1, natom
    3042       48104 :                DO iatom = 1, natom
    3043       27488 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    3044        9608 :                   n_mess = n_mess + 1
    3045             : 
    3046       28824 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3047       23352 :                   IF (para_env%mepos == dest) THEN
    3048        4804 :                      IF (.NOT. ASSOCIATED(recv_buff(n_mess)%array)) CYCLE !nothing received: block was local
    3049       14620 :                      ALLOCATE (blk(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3050      309987 :                      blk(:, :) = recv_buff(n_mess)%array(:, :)
    3051       18275 :                      CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3052        3655 :                      DEALLOCATE (blk)
    3053             :                   END IF
    3054             :                END DO
    3055             :             END DO
    3056             :          END DO
    3057             : 
    3058             :          !clean-up
    3059       14908 :          DO i = 1, n_mess
    3060        9608 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    3061       14908 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    3062             :          END DO
    3063        5514 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    3064             :       END DO !b_img
    3065             : 
    3066         214 :       CALL timestop(handle)
    3067             : 
    3068         214 :    END SUBROUTINE gather_ks_matrix
    3069             : 
    3070             : ! **************************************************************************************************
    3071             : !> \brief copy all required 2c tensors from the main MPI group to the subgroups
    3072             : !> \param mat_2c_pot per-image DBCSR matrices created on the subgroup, filled with the filtered HFX potential
    3073             : !> \param t_2c_work (1:2) (RI | RI) tensors created here (not filled), with unsplit and split RI blocks resp.
    3074             : !> \param t_2c_ao_tmp (1) (AO | AO) tensor created here (not filled) with ri_data%bsizes_AO blocks
    3075             : !> \param ks_t_split (1:2) (AO | AO) tensors created here (not filled) with ri_data%bsizes_AO_split blocks
    3076             : !> \param ks_t_sub (nspins, nimg) tensors created here (not filled) from the t_2c_ao_tmp(1) template
    3077             : !> \param group_size forwarded to copy_2c_to_subgroup
    3078             : !> \param ngroups forwarded to copy_2c_to_subgroup
    3079             : !> \param para_env forwarded to copy_2c_to_subgroup (the main MPI group, as per the brief)
    3080             : !> \param para_env_sub used to build the 2d process grid and the DBCSR distribution of the created objects
    3081             : !> \param ri_data provides the block sizes, ncell_RI, kp_mat_2c_pot and filter_eps read by this routine
    3082             : ! **************************************************************************************************
    3083         214 :    SUBROUTINE get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
    3084             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3085             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3086             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work, t_2c_ao_tmp, ks_t_split
    3087             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t_sub
    3088             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3089             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3090             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3091             : 
    3092             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_tensors'
    3093             : 
    3094             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, iproc, &
    3095             :                                                             j, natom, nblks, nimg, nspins
    3096             :       INTEGER(int_8)                                     :: nze
    3097             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3098         214 :                                                             dist1, dist2
    3099             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3100         428 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3101         214 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3102             :       REAL(dp)                                           :: occ
    3103             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3104         642 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3105        2782 :       TYPE(dbt_type)                                     :: work, work_sub
    3106             : 
    3107         214 :       CALL timeset(routineN, handle)
    3108             : 
    3109             :       !Create the 2d pgrid on para_env_sub; pdims_2d is read back below to size the DBCSR process grid
    3110         214 :       pdims_2d = 0
    3111         214 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3112             : 
    3113         214 :       natom = SIZE(ri_data%bsizes_RI)
    3114         214 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3115         642 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3116         642 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3117        1508 :       DO i_RI = 1, ri_data%ncell_RI
    3118        3882 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3119        6686 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3120             :       END DO
    3121             : 
    3122             :       !nRI x nRI 2c tensors, using the RI block sizes replicated ncell_RI times (unsplit, then split)
    3123             :       CALL create_2c_tensor(t_2c_work(1), dist1, dist2, pgrid_2d, &
    3124             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3125             :                             name="(RI | RI)")
    3126         214 :       DEALLOCATE (dist1, dist2)
    3127             : 
    3128             :       CALL create_2c_tensor(t_2c_work(2), dist1, dist2, pgrid_2d, &
    3129             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3130         214 :                             name="(RI | RI)")
    3131         214 :       DEALLOCATE (dist1, dist2)
    3132             : 
    3133             :       !the AO based tensors (split blocks for ks_t_split, unsplit blocks for t_2c_ao_tmp and ks_t_sub)
    3134             :       CALL create_2c_tensor(ks_t_split(1), dist1, dist2, pgrid_2d, &
    3135             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3136             :                             name="(AO | AO)")
    3137         214 :       DEALLOCATE (dist1, dist2)
    3138         214 :       CALL dbt_create(ks_t_split(1), ks_t_split(2))
    3139             : 
    3140             :       CALL create_2c_tensor(t_2c_ao_tmp(1), dist1, dist2, pgrid_2d, &
    3141             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, &
    3142             :                             name="(AO | AO)")
    3143         214 :       DEALLOCATE (dist1, dist2)
    3144             : 
    3145         214 :       nspins = SIZE(ks_t_sub, 1)
    3146         214 :       nimg = SIZE(ks_t_sub, 2)
    3147        5514 :       DO i_img = 1, nimg
    3148       12386 :          DO i_spin = 1, nspins
    3149       12172 :             CALL dbt_create(t_2c_ao_tmp(1), ks_t_sub(i_spin, i_img))
    3150             :          END DO
    3151             :       END DO
    3152             : 
    3153             :       !Finally the HFX potential matrices
    3154             :       !For now, we do a convoluted thing where we go to tensors first, then back to matrices.
    3155             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3156             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3157             :                             name="(RI | RI)")
    3158         214 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3159             : 
    3160         856 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3161         214 :       iproc = 0
    3162         428 :       DO i = 0, pdims_2d(1) - 1
    3163         642 :          DO j = 0, pdims_2d(2) - 1
    3164         214 :             dbcsr_pgrid(i, j) = iproc
    3165         428 :             iproc = iproc + 1
    3166             :          END DO
    3167             :       END DO
    3168             : 
    3169             :       !We need to have the same exact 2d block dist as the tensors (dist1/dist2 are those of work_sub)
    3170         856 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3171         642 :       row_dist(:) = dist1(:)
    3172         642 :       col_dist(:) = dist2(:)
    3173             : 
    3174         428 :       ALLOCATE (RI_blk_size(natom))
    3175         642 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3176             : 
    3177             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3178         214 :                                   row_dist=row_dist, col_dist=col_dist)
    3179             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3180         214 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3181             : 
    3182        5514 :       DO i_img = 1, nimg
    3183        5300 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3184        5300 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3185        5300 :          CALL get_tensor_occupancy(work, nze, occ)
    3186        5300 :          IF (nze == 0) CYCLE
    3187             : 
    3188        4132 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3189        4132 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3190        4132 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3191        9646 :          CALL dbt_clear(work_sub)
    3192             :       END DO
    3193             : 
    3194         214 :       CALL dbt_destroy(work)
    3195         214 :       CALL dbt_destroy(work_sub)
    3196         214 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3197         214 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3198         214 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3199         214 :       CALL timestop(handle)
    3200             : 
    3201        1926 :    END SUBROUTINE get_subgroup_2c_tensors
    3202             : 
    3203             : ! **************************************************************************************************
    3204             : !> \brief copy all required 3c tensors from the main MPI group to the subgroups
    3205             : !> \param t_3c_int per-image (RI | AO AO) tensors created on the subgroup and filled with the 3c integrals
    3206             : !> \param t_3c_work_2 (1:3) stacked (AO RI | AO) work tensors created here (not filled)
    3207             : !> \param t_3c_work_3 (1:3) stacked (RI | AO AO) work tensors created here (not filled)
    3208             : !> \param t_3c_apc (nspins, nimg) input tensors; their data is moved out and they are destroyed here
    3209             : !> \param t_3c_apc_sub (nspins, nimg) subgroup tensors created here, receiving the data of t_3c_apc
    3210             : !> \param group_size forwarded to copy_3c_to_subgroup
    3211             : !> \param ngroups forwarded to copy_3c_to_subgroup
    3212             : !> \param para_env forwarded to copy_3c_to_subgroup (the main MPI group, as per the brief)
    3213             : !> \param para_env_sub used to build the process grids of the subgroup tensors
    3214             : !> \param ri_data provides block sizes, pgrids and integrals; kp_t_3c_int is allocated here if not already
    3215             : ! **************************************************************************************************
    3216         214 :    SUBROUTINE get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
    3217             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3218             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_int, t_3c_work_2, t_3c_work_3
    3219             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, t_3c_apc_sub
    3220             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3221             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3222             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3223             : 
    3224             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_tensors'
    3225             : 
    3226             :       INTEGER                                            :: batch_size, bfac, bo(2), handle, &
    3227             :                                                             handle2, i_blk, i_img, i_RI, i_spin, &
    3228             :                                                             ib, natom, nblks_AO, nblks_RI, nimg, &
    3229             :                                                             nspins
    3230             :       INTEGER(int_8)                                     :: nze
    3231         214 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3232         214 :                                                             bsizes_stack, bsizes_tmp, dist1, &
    3233         214 :                                                             dist2, dist3, dist_stack, idx_to_at
    3234             :       INTEGER, DIMENSION(3)                              :: pdims
    3235             :       REAL(dp)                                           :: occ
    3236        1926 :       TYPE(dbt_distribution_type)                        :: t_dist
    3237         642 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3238        5350 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3239             : 
    3240         214 :       CALL timeset(routineN, handle)
    3241             : 
    3242         214 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    3243         642 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3244        1508 :       DO i_RI = 1, ri_data%ncell_RI
    3245        6686 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3246             :       END DO
    3247             : 
    3248             :       !Preparing larger block sizes for efficient communication (fewer, bigger messages)
    3249             :       !we put 2 atoms per RI block (nblks_RI is reused from here on as the number of such merged blocks)
    3250         214 :       bfac = 2
    3251         214 :       natom = SIZE(ri_data%bsizes_RI)
    3252         214 :       nblks_RI = MAX(1, natom/bfac)
    3253         642 :       ALLOCATE (bsizes_tmp(nblks_RI))
    3254         428 :       DO i_blk = 1, nblks_RI
    3255         214 :          bo = get_limit(natom, nblks_RI, i_blk - 1)
    3256         856 :          bsizes_tmp(i_blk) = SUM(ri_data%bsizes_RI(bo(1):bo(2)))
    3257             :       END DO
    3258         642 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3259        1508 :       DO i_RI = 1, ri_data%ncell_RI
    3260        2802 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = bsizes_tmp(:)
    3261             :       END DO
    3262             : 
    3263         214 :       batch_size = ri_data%kp_stack_size
    3264         214 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3265         642 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3266        7062 :       DO ib = 1, batch_size
    3267       33558 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3268             :       END DO
    3269             : 
    3270             :       !Create the pgrid for the configuration corresponding to ri_data%t_3c_int_ctr_3
    3271         214 :       natom = SIZE(ri_data%bsizes_RI)
    3272         214 :       pdims = 0
    3273             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3274         856 :                             tensor_dims=[SIZE(bsizes_RI_ext_split), 1, batch_size*SIZE(ri_data%bsizes_AO_split)])
    3275             : 
    3276             :       !Create all required 3c tensors in that configuration (t_3c_int(1) is the template for all images)
    3277             :       CALL create_3c_tensor(t_3c_int(1), dist1, dist2, dist3, &
    3278             :                             pgrid, bsizes_RI_ext_split, ri_data%bsizes_AO_split, &
    3279         214 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(RI | AO AO)")
    3280         214 :       nimg = SIZE(t_3c_int)
    3281        5300 :       DO i_img = 2, nimg
    3282        5300 :          CALL dbt_create(t_3c_int(1), t_3c_int(i_img))
    3283             :       END DO
    3284             : 
    3285             :       !The stacked work tensors, in a distribution that matches that of t_3c_int (dist3 repeated per batch)
    3286         428 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3287        7062 :       DO ib = 1, batch_size
    3288       33558 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3289             :       END DO
    3290             : 
    3291         214 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3292             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3293         214 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3294         214 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3295         214 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3296         214 :       CALL dbt_distribution_destroy(t_dist)
    3297         214 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3298             : 
    3299             :       !For more efficient communication, we use intermediate tensors with larger block size
    3300             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3301             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3302         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3303         214 :       DEALLOCATE (dist1, dist2, dist3)
    3304             : 
    3305             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3306             :                             ri_data%pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3307         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3308         214 :       DEALLOCATE (dist1, dist2, dist3)
    3309             : 
    3310             :       !Finally get the integrals: moved from ri_data%kp_t_3c_int if allocated, else copied from t_3c_int_ctr_1
    3311         214 :       CALL timeset(routineN//"_ints", handle2)
    3312         214 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN
    3313        3800 :          DO i_img = 1, nimg
    3314        3800 :             CALL dbt_copy(ri_data%kp_t_3c_int(i_img), t_3c_int(i_img), move_data=.TRUE.)
    3315             :          END DO
    3316             :       ELSE
    3317        2414 :          ALLOCATE (ri_data%kp_t_3c_int(nimg))
    3318        1714 :          DO i_img = 1, nimg
    3319        1644 :             CALL dbt_create(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img))
    3320        1644 :             CALL get_tensor_occupancy(ri_data%t_3c_int_ctr_1(1, i_img), nze, occ)
    3321        1644 :             IF (nze == 0) CYCLE
    3322        1238 :             CALL dbt_copy(ri_data%t_3c_int_ctr_1(1, i_img), work_atom_block, order=[2, 1, 3])
    3323        1238 :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, ngroups, para_env)
    3324        2952 :             CALL dbt_copy(work_atom_block_sub, t_3c_int(i_img), move_data=.TRUE.)
    3325             :          END DO
    3326             :       END IF
    3327         214 :       CALL timestop(handle2)
    3328         214 :       CALL dbt_pgrid_destroy(pgrid)
    3329         214 :       CALL dbt_destroy(work_atom_block)
    3330         214 :       CALL dbt_destroy(work_atom_block_sub)
    3331             : 
    3332             :       !Do the same for the t_3c_ctr_2 configuration (pgrid and work_atom_block* are re-created for it)
    3333         214 :       pdims = 0
    3334             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3335         856 :                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), batch_size*SIZE(ri_data%bsizes_AO_split)])
    3336             : 
    3337             :       !For more efficient communication, we use intermediate tensors with larger block size
    3338             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3339             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3340         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3341         214 :       DEALLOCATE (dist1, dist2, dist3)
    3342             : 
    3343             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3344             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3345         214 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3346         214 :       DEALLOCATE (dist1, dist2, dist3)
    3347             : 
    3348             :       !template for t_3c_apc_sub (its dist1/dist2/dist3 are kept to build the t_3c_work_2 distribution)
    3349             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3350             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3351         214 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3352             : 
    3353             :       !create t_3c_work_2 tensors in a distribution that matches the above
    3354         428 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3355        7062 :       DO ib = 1, batch_size
    3356       33558 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3357             :       END DO
    3358             : 
    3359         214 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3360             :       CALL dbt_create(t_3c_work_2(1), "work_2_stack", t_dist, [1], [2, 3], &
    3361         214 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3362         214 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3363         214 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3364         214 :       CALL dbt_distribution_destroy(t_dist)
    3365         214 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3366             : 
    3367             :       !Finally copy data from t_3c_apc to the subgroups (every t_3c_apc tensor is destroyed, even if empty)
    3368         642 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3369         214 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3370         214 :       nspins = SIZE(t_3c_apc, 1)
    3371         214 :       CALL timeset(routineN//"_apc", handle2)
    3372        5514 :       DO i_img = 1, nimg
    3373       12172 :          DO i_spin = 1, nspins
    3374        6872 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3375        6872 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3376        6872 :             IF (nze == 0) CYCLE
    3377        6032 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3378             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3379        6032 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3380       18204 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3381             :          END DO
    3382       12386 :          DO i_spin = 1, nspins
    3383       12172 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3384             :          END DO
    3385             :       END DO
    3386         214 :       CALL timestop(handle2)
    3387         214 :       CALL dbt_pgrid_destroy(pgrid)
    3388         214 :       CALL dbt_destroy(tmp)
    3389         214 :       CALL dbt_destroy(work_atom_block)
    3390         214 :       CALL dbt_destroy(work_atom_block_sub)
    3391             : 
    3392         214 :       CALL timestop(handle)
    3393             : 
    3394         856 :    END SUBROUTINE get_subgroup_3c_tensors
    3395             : 
    3396             : ! **************************************************************************************************
    3397             : !> \brief copy all required 2c force tensors from the main MPI group to the subgroups
    3398             : !> \param t_2c_inv ...
    3399             : !> \param t_2c_bint ...
    3400             : !> \param t_2c_metric ...
    3401             : !> \param mat_2c_pot ...
    3402             : !> \param t_2c_work ...
    3403             : !> \param rho_ao_t ...
    3404             : !> \param rho_ao_t_sub ...
    3405             : !> \param t_2c_der_metric ...
    3406             : !> \param t_2c_der_metric_sub ...
    3407             : !> \param mat_der_pot ...
    3408             : !> \param mat_der_pot_sub ...
    3409             : !> \param group_size ...
    3410             : !> \param ngroups ...
    3411             : !> \param para_env ...
    3412             : !> \param para_env_sub ...
    3413             : !> \param ri_data ...
    3414             : !> \note Main MPI group tensors are deleted within this routine, for memory optimization
    3415             : ! **************************************************************************************************
    3416          84 :    SUBROUTINE get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
    3417          42 :                                      rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
    3418          42 :                                      mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
    3419             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_inv, t_2c_bint, t_2c_metric
    3420             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3421             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work
    3422             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
    3423             :                                                             t_2c_der_metric_sub
    3424             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot, mat_der_pot_sub
    3425             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3426             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3427             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3428             : 
    3429             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_derivs'
    3430             : 
    3431             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, i_xyz, &
    3432             :                                                             iatom, iproc, j, natom, nblks, nimg, &
    3433             :                                                             nspins
    3434             :       INTEGER(int_8)                                     :: nze
    3435             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3436          42 :                                                             dist1, dist2
    3437             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3438          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3439          42 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3440             :       REAL(dp)                                           :: occ
    3441             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3442         126 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3443         546 :       TYPE(dbt_type)                                     :: work, work_sub
    3444             : 
    3445          42 :       CALL timeset(routineN, handle)
    3446             : 
    3447             :       !Note: a fair portion of this routine is copied from the energy version of it
    3448             :       !Create the 2d pgrid
    3449          42 :       pdims_2d = 0
    3450          42 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3451             : 
    3452          42 :       natom = SIZE(ri_data%bsizes_RI)
    3453          42 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3454         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3455         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3456         294 :       DO i_RI = 1, ri_data%ncell_RI
    3457         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3458        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3459             :       END DO
    3460             : 
    3461             :       !nRI x nRI 2c tensors
    3462             :       CALL create_2c_tensor(t_2c_inv(1), dist1, dist2, pgrid_2d, &
    3463             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3464             :                             name="(RI | RI)")
    3465          42 :       DEALLOCATE (dist1, dist2)
    3466             : 
    3467          42 :       CALL dbt_create(t_2c_inv(1), t_2c_bint(1))
    3468          42 :       CALL dbt_create(t_2c_inv(1), t_2c_metric(1))
    3469          84 :       DO iatom = 2, natom
    3470          42 :          CALL dbt_create(t_2c_inv(1), t_2c_inv(iatom))
    3471          42 :          CALL dbt_create(t_2c_inv(1), t_2c_bint(iatom))
    3472          84 :          CALL dbt_create(t_2c_inv(1), t_2c_metric(iatom))
    3473             :       END DO
    3474          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(1))
    3475          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(2))
    3476          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(3))
    3477          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(4))
    3478             : 
    3479             :       CALL create_2c_tensor(t_2c_work(5), dist1, dist2, pgrid_2d, &
    3480             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3481          42 :                             name="(RI | RI)")
    3482          42 :       DEALLOCATE (dist1, dist2)
    3483             : 
    3484             :       !copy the data from the main group.
    3485         126 :       DO iatom = 1, natom
    3486          84 :          CALL copy_2c_to_subgroup(t_2c_inv(iatom), ri_data%t_2c_inv(1, iatom), group_size, ngroups, para_env)
    3487          84 :          CALL copy_2c_to_subgroup(t_2c_bint(iatom), ri_data%t_2c_int(1, iatom), group_size, ngroups, para_env)
    3488         126 :          CALL copy_2c_to_subgroup(t_2c_metric(iatom), ri_data%t_2c_pot(1, iatom), group_size, ngroups, para_env)
    3489             :       END DO
    3490             : 
    3491             :       !This includes the derivatives of the RI metric, for which there is one per atom
    3492         168 :       DO i_xyz = 1, 3
    3493         420 :          DO iatom = 1, natom
    3494         252 :             CALL dbt_create(t_2c_inv(1), t_2c_der_metric_sub(iatom, i_xyz))
    3495             :             CALL copy_2c_to_subgroup(t_2c_der_metric_sub(iatom, i_xyz), t_2c_der_metric(iatom, i_xyz), &
    3496         252 :                                      group_size, ngroups, para_env)
    3497         378 :             CALL dbt_destroy(t_2c_der_metric(iatom, i_xyz))
    3498             :          END DO
    3499             :       END DO
    3500             : 
    3501             :       !AO x AO 2c tensors
    3502             :       CALL create_2c_tensor(rho_ao_t_sub(1, 1), dist1, dist2, pgrid_2d, &
    3503             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3504             :                             name="(AO | AO)")
    3505          42 :       DEALLOCATE (dist1, dist2)
    3506          42 :       nspins = SIZE(rho_ao_t, 1)
    3507          42 :       nimg = SIZE(rho_ao_t, 2)
    3508             : 
    3509         980 :       DO i_img = 1, nimg
    3510        2084 :          DO i_spin = 1, nspins
    3511        1104 :             IF (.NOT. (i_img == 1 .AND. i_spin == 1)) &
    3512        1062 :                CALL dbt_create(rho_ao_t_sub(1, 1), rho_ao_t_sub(i_spin, i_img))
    3513             :             CALL copy_2c_to_subgroup(rho_ao_t_sub(i_spin, i_img), rho_ao_t(i_spin, i_img), &
    3514        1104 :                                      group_size, ngroups, para_env)
    3515        2042 :             CALL dbt_destroy(rho_ao_t(i_spin, i_img))
    3516             :          END DO
    3517             :       END DO
    3518             : 
    3519             :       !The RIxRI matrices, going through tensors
    3520             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3521             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3522             :                             name="(RI | RI)")
    3523          42 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3524             : 
    3525         168 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3526          42 :       iproc = 0
    3527          84 :       DO i = 0, pdims_2d(1) - 1
    3528         126 :          DO j = 0, pdims_2d(2) - 1
    3529          42 :             dbcsr_pgrid(i, j) = iproc
    3530          84 :             iproc = iproc + 1
    3531             :          END DO
    3532             :       END DO
    3533             : 
    3534             :       !We need to have the same exact 2d block dist as the tensors
    3535         168 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3536         126 :       row_dist(:) = dist1(:)
    3537         126 :       col_dist(:) = dist2(:)
    3538             : 
    3539          84 :       ALLOCATE (RI_blk_size(natom))
    3540         126 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3541             : 
    3542             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3543          42 :                                   row_dist=row_dist, col_dist=col_dist)
    3544             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3545          42 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3546             : 
    3547             :       !The HFX potential
    3548         980 :       DO i_img = 1, nimg
    3549         938 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3550         938 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3551         938 :          CALL get_tensor_occupancy(work, nze, occ)
    3552         938 :          IF (nze == 0) CYCLE
    3553             : 
    3554         662 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3555         662 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3556         662 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3557        1642 :          CALL dbt_clear(work_sub)
    3558             :       END DO
    3559             : 
    3560             :       !The derivatives of the HFX potential
    3561         168 :       DO i_xyz = 1, 3
    3562        2982 :          DO i_img = 1, nimg
    3563        2814 :             CALL dbcsr_create(mat_der_pot_sub(i_img, i_xyz), template=mat_2c_pot(1))
    3564        2814 :             CALL dbt_copy_matrix_to_tensor(mat_der_pot(i_img, i_xyz), work)
    3565        2814 :             CALL dbcsr_release(mat_der_pot(i_img, i_xyz))
    3566        2814 :             CALL get_tensor_occupancy(work, nze, occ)
    3567        2814 :             IF (nze == 0) CYCLE
    3568             : 
    3569        1986 :             CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3570        1986 :             CALL dbt_copy_tensor_to_matrix(work_sub, mat_der_pot_sub(i_img, i_xyz))
    3571        1986 :             CALL dbcsr_filter(mat_der_pot_sub(i_img, i_xyz), ri_data%filter_eps)
    3572        4926 :             CALL dbt_clear(work_sub)
    3573             :          END DO
    3574             :       END DO
    3575             : 
    3576          42 :       CALL dbt_destroy(work)
    3577          42 :       CALL dbt_destroy(work_sub)
    3578          42 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3579          42 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3580          42 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3581             : 
    3582          42 :       CALL timestop(handle)
    3583             : 
    3584         336 :    END SUBROUTINE get_subgroup_2c_derivs
    3585             : 
    3586             : ! **************************************************************************************************
    3587             : !> \brief copy all required 3c derivative tensors from the main MPI group to the subgroups
    3588             : !> \param t_3c_work_2 out: elements 1-3 are created with the stacked layout based on t_3c_apc_sub(1, 1)
    3589             : !> \param t_3c_work_3 out: elements 1-4 are created with the stacked layout based on ri_data%kp_t_3c_int(1)
    3590             : !> \param t_3c_der_AO main-group tensors (i_img, i_xyz); copied with move_data=.TRUE., then destroyed
    3591             : !> \param t_3c_der_AO_sub out: subgroup counterparts, created from the ri_data%kp_t_3c_int(1) template
    3592             : !> \param t_3c_der_RI main-group tensors (i_img, i_xyz); copied with move_data=.TRUE., then destroyed
    3593             : !> \param t_3c_der_RI_sub out: subgroup counterparts, created from the ri_data%kp_t_3c_int(1) template
    3594             : !> \param t_3c_apc main-group tensors (i_spin, i_img); copied with move_data=.TRUE., then destroyed
    3595             : !> \param t_3c_apc_sub out: subgroup counterparts, created with split AO and extended split RI block sizes
    3596             : !> \param t_3c_der_stack out: elements 1-6 are created with the same layout as t_3c_work_3(1)
    3597             : !> \param group_size forwarded to copy_3c_to_subgroup
    3598             : !> \param ngroups forwarded to copy_3c_to_subgroup
    3599             : !> \param para_env parallel environment forwarded to copy_3c_to_subgroup
    3600             : !> \param para_env_sub parallel environment on which the subgroup process grids are created
    3601             : !> \param ri_data provides block sizes, ncell_RI, nimg, kp_stack_size, process grids and kp_t_3c_int(1)
    3602             : !> \note the tensors containing the derivatives in the main MPI group are destroyed to save memory
    3603             : ! **************************************************************************************************
    3604          42 :    SUBROUTINE get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
    3605          42 :                                      t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, &
    3606          42 :                                      t_3c_der_stack, group_size, ngroups, para_env, para_env_sub, &
    3607             :                                      ri_data)
    3608             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_work_2, t_3c_work_3
    3609             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_AO, t_3c_der_AO_sub, &
    3610             :                                                             t_3c_der_RI, t_3c_der_RI_sub, &
    3611             :                                                             t_3c_apc, t_3c_apc_sub
    3612             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_der_stack
    3613             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3614             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3615             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3616             : 
    3617             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_derivs'
    3618             : 
    3619             :       INTEGER                                            :: batch_size, handle, i_img, i_RI, i_spin, &
    3620             :                                                             i_xyz, ib, nblks_AO, nblks_RI, nimg, &
    3621             :                                                             nspins, pdims(3)
    3622             :       INTEGER(int_8)                                     :: nze
    3623          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3624          42 :                                                             bsizes_stack, dist1, dist2, dist3, &
    3625          42 :                                                             dist_stack, idx_to_at
    3626             :       REAL(dp)                                           :: occ
    3627         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    3628         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3629        1050 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3630             : 
    3631          42 :       CALL timeset(routineN, handle)
    3632             : 
    3633             :       !We use intermediate tensors with larger block size for more optimized communication
    3634          42 :       nblks_RI = SIZE(ri_data%bsizes_RI)
    3635         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3636         294 :       DO i_RI = 1, ri_data%ncell_RI
    3637         798 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
    3638             :       END DO
    3639             : 
    3640          42 :       CALL dbt_get_info(ri_data%kp_t_3c_int(1), pdims=pdims)
    3641          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3642             : 
    3643             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3644             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3645          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3646          42 :       DEALLOCATE (dist1, dist2, dist3)
    3647             : 
    3648             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3649             :                             ri_data%pgrid_2, bsizes_RI_ext, ri_data%bsizes_AO, &
    3650          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3651          42 :       DEALLOCATE (dist1, dist2, dist3)
    3652          42 :       CALL dbt_pgrid_destroy(pgrid)
    3653             : 
    3654             :       !We use the 3c integrals on the subgroup as template for the derivatives
    3655          42 :       nimg = ri_data%nimg
    3656         168 :       DO i_xyz = 1, 3
    3657        2940 :          DO i_img = 1, nimg
    3658        2814 :             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_AO_sub(i_img, i_xyz))
    3659        2814 :             CALL get_tensor_occupancy(t_3c_der_AO(i_img, i_xyz), nze, occ)
    3660        2814 :             IF (nze == 0) CYCLE !nothing to transfer: the freshly created subgroup tensor is left as is
    3661             : 
    3662        1932 :             CALL dbt_copy(t_3c_der_AO(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
    3663             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
    3664        1932 :                                      group_size, ngroups, para_env)
    3665        4872 :             CALL dbt_copy(work_atom_block_sub, t_3c_der_AO_sub(i_img, i_xyz), move_data=.TRUE.)
    3666             :          END DO
    3667             : 
    3668        2940 :          DO i_img = 1, nimg
    3669        2814 :             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_RI_sub(i_img, i_xyz))
    3670        2814 :             CALL get_tensor_occupancy(t_3c_der_RI(i_img, i_xyz), nze, occ)
    3671        2814 :             IF (nze == 0) CYCLE !nothing to transfer: the freshly created subgroup tensor is left as is
    3672             : 
    3673        1920 :             CALL dbt_copy(t_3c_der_RI(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
    3674             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
    3675        1920 :                                      group_size, ngroups, para_env)
    3676        4860 :             CALL dbt_copy(work_atom_block_sub, t_3c_der_RI_sub(i_img, i_xyz), move_data=.TRUE.)
    3677             :          END DO
    3678             : 
    3679        2982 :          DO i_img = 1, nimg
    3680        2814 :             CALL dbt_destroy(t_3c_der_RI(i_img, i_xyz))
    3681        2940 :             CALL dbt_destroy(t_3c_der_AO(i_img, i_xyz))
    3682             :          END DO
    3683             :       END DO
    3684          42 :       CALL dbt_destroy(work_atom_block_sub)
    3685          42 :       CALL dbt_destroy(work_atom_block)
    3686             : 
    3687             :       !Deal with t_3c_apc
    3688          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split) !from here on nblks_RI counts the split RI blocks
    3689         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3690         294 :       DO i_RI = 1, ri_data%ncell_RI
    3691        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3692             :       END DO
    3693             : 
    3694          42 :       pdims = 0 !batch_size is only assigned further down, so kp_stack_size is read from ri_data below
    3695             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3696         168 :                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), ri_data%kp_stack_size*SIZE(ri_data%bsizes_AO_split)])
    3697             : 
    3698             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3699             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3700          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3701          42 :       DEALLOCATE (dist1, dist2, dist3)
    3702             : 
    3703             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3704             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3705          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3706          42 :       DEALLOCATE (dist1, dist2, dist3)
    3707             : 
    3708             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3709             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3710          42 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3711          42 :       DEALLOCATE (dist1, dist2, dist3)
    3712             : 
    3713         126 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3714          42 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3715          42 :       nspins = SIZE(t_3c_apc, 1)
    3716         980 :       DO i_img = 1, nimg
    3717        2042 :          DO i_spin = 1, nspins
    3718        1104 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3719        1104 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3720        1104 :             IF (nze == 0) CYCLE
    3721        1098 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3722             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3723        1098 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3724        3140 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3725             :          END DO
    3726        2084 :          DO i_spin = 1, nspins
    3727        2042 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3728             :          END DO
    3729             :       END DO
    3730          42 :       CALL dbt_destroy(tmp)
    3731          42 :       CALL dbt_destroy(work_atom_block)
    3732          42 :       CALL dbt_destroy(work_atom_block_sub)
    3733          42 :       CALL dbt_pgrid_destroy(pgrid)
    3734             : 
    3735             :       !t_3c_work_3 based on structure of 3c integrals/derivs
    3736          42 :       batch_size = ri_data%kp_stack_size
    3737          42 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3738         126 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3739        1386 :       DO ib = 1, batch_size
    3740        6826 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3741             :       END DO
    3742             : 
    3743         294 :       ALLOCATE (dist1(ri_data%ncell_RI*nblks_RI), dist2(nblks_AO), dist3(nblks_AO))
    3744             :       CALL dbt_get_info(ri_data%kp_t_3c_int(1), proc_dist_1=dist1, proc_dist_2=dist2, &
    3745          42 :                         proc_dist_3=dist3, pdims=pdims)
    3746             : 
    3747         126 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3748        1386 :       DO ib = 1, batch_size
    3749        6826 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:) !repeat the 3rd-index distribution per batch
    3750             :       END DO
    3751             : 
    3752          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3753          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3754             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3755          42 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3756          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3757          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3758          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(4))
    3759          42 :       CALL dbt_distribution_destroy(t_dist)
    3760          42 :       CALL dbt_pgrid_destroy(pgrid)
    3761          42 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3762             : 
    3763             :       !the derivatives are stacked in the same way
    3764          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(1))
    3765          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(2))
    3766          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(3))
    3767          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(4))
    3768          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(5))
    3769          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(6))
    3770             : 
    3771             :       !t_3c_work_2 based on structure of t_3c_apc
    3772         294 :       ALLOCATE (dist1(nblks_AO), dist2(ri_data%ncell_RI*nblks_RI), dist3(nblks_AO))
    3773             :       CALL dbt_get_info(t_3c_apc_sub(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, &
    3774          42 :                         proc_dist_3=dist3, pdims=pdims)
    3775             : 
    3776         126 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3777        1386 :       DO ib = 1, batch_size
    3778        6826 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:) !repeat the 3rd-index distribution per batch
    3779             :       END DO
    3780             : 
    3781          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3782          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3783             :       CALL dbt_create(t_3c_work_2(1), "work_3_stack", t_dist, [1], [2, 3], &
    3784          42 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3785          42 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3786          42 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3787          42 :       CALL dbt_distribution_destroy(t_dist)
    3788          42 :       CALL dbt_pgrid_destroy(pgrid)
    3789          42 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3790             : 
    3791          42 :       CALL timestop(handle)
    3792             : 
    3793          84 :    END SUBROUTINE get_subgroup_3c_derivs
    3794             : 
    3795             : ! **************************************************************************************************
    3796             : !> \brief A routine that reorders the t_3c_int tensors such that all items which are fully empty
    3797             : !>        are bunched together. This way, we can get much more efficient screening based on NZE
    3798             : !> \param t_3c_ints the ri_data%nimg tensors, reordered in place: non-empty ones first, empty ones last
    3799             : !> \param ri_data reads nimg; allocates and fills idx_to_img and img_to_idx; sets nimg_nze
    3800             : ! **************************************************************************************************
    3801          70 :    SUBROUTINE reorder_3c_ints(t_3c_ints, ri_data)
    3802             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_ints
    3803             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3804             : 
    3805             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_ints'
    3806             : 
    3807             :       INTEGER                                            :: handle, i_img, idx, idx_empty, idx_full, &
    3808             :                                                             nimg
    3809             :       INTEGER(int_8)                                     :: nze
    3810             :       REAL(dp)                                           :: occ
    3811          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_3c_tmp
    3812             : 
    3813          70 :       CALL timeset(routineN, handle)
    3814             : 
    3815          70 :       nimg = ri_data%nimg
    3816        2414 :       ALLOCATE (t_3c_tmp(nimg))
    3817        1714 :       DO i_img = 1, nimg
    3818        1644 :          CALL dbt_create(t_3c_ints(i_img), t_3c_tmp(i_img)) !temporary with the layout of its source tensor
    3819        1714 :          CALL dbt_copy(t_3c_ints(i_img), t_3c_tmp(i_img), move_data=.TRUE.)
    3820             :       END DO
    3821             : 
    3822             :       !Loop over the images, check if ints have NZE == 0, and put them at the start or end of the
    3823             :       !initial tensor array. Keep the mapping in an array
    3824         210 :       ALLOCATE (ri_data%idx_to_img(nimg))
    3825          70 :       idx_full = 0 !non-empty tensors fill slots 1, 2, ... from the front
    3826          70 :       idx_empty = nimg + 1 !empty tensors fill slots nimg, nimg - 1, ... from the back
    3827             : 
    3828        1714 :       DO i_img = 1, nimg
    3829        1644 :          CALL get_tensor_occupancy(t_3c_tmp(i_img), nze, occ)
    3830        1644 :          IF (nze == 0) THEN
    3831         406 :             idx_empty = idx_empty - 1
    3832         406 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_ints(idx_empty), move_data=.TRUE.)
    3833         406 :             ri_data%idx_to_img(idx_empty) = i_img !slot -> original image index
    3834             :          ELSE
    3835        1238 :             idx_full = idx_full + 1
    3836        1238 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_ints(idx_full), move_data=.TRUE.)
    3837        1238 :             ri_data%idx_to_img(idx_full) = i_img !slot -> original image index
    3838             :          END IF
    3839        3358 :          CALL dbt_destroy(t_3c_tmp(i_img))
    3840             :       END DO
    3841             : 
    3842             :       !store the highest slot index holding non-zero integrals (= number of non-empty tensors)
    3843          70 :       ri_data%nimg_nze = idx_full
    3844             : 
    3845         140 :       ALLOCATE (ri_data%img_to_idx(nimg))
    3846        1714 :       DO idx = 1, nimg
    3847        1714 :          ri_data%img_to_idx(ri_data%idx_to_img(idx)) = idx !inverse map: original image index -> slot
    3848             :       END DO
    3849             : 
    3850          70 :       CALL timestop(handle)
    3851             : 
    3852        1784 :    END SUBROUTINE reorder_3c_ints
    3853             : 
    3854             : ! **************************************************************************************************
    3855             : !> \brief A routine that reorders the 3c derivatives, the same way that the integrals are, also to
    3856             : !>        increase efficiency of screening
    3857             : !> \param t_3c_derivs tensors indexed (i_img, i_xyz), reordered in place along i_img via ri_data%img_to_idx
    3858             : !> \param ri_data reads nimg and img_to_idx; nimg_nze is raised if a non-empty derivative sits at a higher slot
    3859             : ! **************************************************************************************************
    3860          84 :    SUBROUTINE reorder_3c_derivs(t_3c_derivs, ri_data)
    3861             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_derivs
    3862             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3863             : 
    3864             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_derivs'
    3865             : 
    3866             :       INTEGER                                            :: handle, i_img, i_xyz, idx, nimg
    3867             :       INTEGER(int_8)                                     :: nze
    3868             :       REAL(dp)                                           :: occ
    3869          84 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_3c_tmp
    3870             : 
    3871          84 :       CALL timeset(routineN, handle)
    3872             : 
    3873          84 :       nimg = ri_data%nimg
    3874        2800 :       ALLOCATE (t_3c_tmp(nimg))
    3875        1960 :       DO i_img = 1, nimg
    3876        1960 :          CALL dbt_create(t_3c_derivs(1, 1), t_3c_tmp(i_img)) !all temporaries use t_3c_derivs(1, 1) as template
    3877             :       END DO
    3878             : 
    3879         336 :       DO i_xyz = 1, 3
    3880        5880 :          DO i_img = 1, nimg !stage 1: move every image of this direction into the temporaries
    3881        5880 :             CALL dbt_copy(t_3c_derivs(i_img, i_xyz), t_3c_tmp(i_img), move_data=.TRUE.)
    3882             :          END DO
    3883        5964 :          DO i_img = 1, nimg !stage 2: move each temporary into its reordered slot
    3884        5628 :             idx = ri_data%img_to_idx(i_img)
    3885        5628 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_derivs(idx, i_xyz), move_data=.TRUE.)
    3886        5628 :             CALL get_tensor_occupancy(t_3c_derivs(idx, i_xyz), nze, occ)
    3887        5880 :             IF (nze > 0) ri_data%nimg_nze = MAX(idx, ri_data%nimg_nze) !nimg_nze only ever grows here
    3888             :          END DO
    3889             :       END DO
    3890             : 
    3891        1960 :       DO i_img = 1, nimg
    3892        1960 :          CALL dbt_destroy(t_3c_tmp(i_img))
    3893             :       END DO
    3894             : 
    3895          84 :       CALL timestop(handle)
    3896             : 
    3897        2044 :    END SUBROUTINE reorder_3c_derivs
    3898             : 
    3899             : ! **************************************************************************************************
    3900             : !> \brief Get the sparsity pattern related to the non-symmetric AO basis overlap neighbor list
    3901             : !> \param pattern must conform to (natom, natom, nimg); on exit 0 for retained (iatom, jatom, image), else -1
    3902             : !> \param ri_data reads nimg and present_images
    3903             : !> \param qs_env provides kpoints, para_env and natom; also passed to get_opp_index
    3904             : ! **************************************************************************************************
    3905         256 :    SUBROUTINE get_sparsity_pattern(pattern, ri_data, qs_env)
    3906             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: pattern
    3907             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3908             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    3909             : 
    3910             :       INTEGER                                            :: iatom, j_img, jatom, mj_img, natom, nimg
    3911         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bins
    3912         256 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: tmp_pattern
    3913             :       INTEGER, DIMENSION(3)                              :: cell_j
    3914         256 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    3915         256 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    3916             :       TYPE(dft_control_type), POINTER                    :: dft_control
    3917             :       TYPE(kpoint_type), POINTER                         :: kpoints
    3918             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    3919             :       TYPE(neighbor_list_iterator_p_type), &
    3920         256 :          DIMENSION(:), POINTER                           :: nl_iterator
    3921             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    3922         256 :          POINTER                                         :: nl_2c
    3923             : 
    3924         256 :       NULLIFY (nl_2c, nl_iterator, kpoints, cell_to_index, dft_control, index_to_cell, para_env)
    3925             : 
    3926             :       !NOTE(review): dft_control and index_to_cell are retrieved below but never used in this routine
    3927         256 :       CALL get_qs_env(qs_env, kpoints=kpoints, dft_control=dft_control, para_env=para_env, natom=natom)
    3928         256 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell, sab_nl=nl_2c)
    3929             : 
    3930         256 :       nimg = ri_data%nimg
    3931       43922 :       pattern(:, :, :) = 0
    3932             : 
    3933             :       !We use the symmetric nl for all images that have an opposite cell
    3934         256 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    3935       11318 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    3936       11062 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    3937             : 
    3938       11062 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    3939       11062 :          IF (j_img > nimg .OR. j_img < 1) CYCLE !cell is not among the nimg images considered
    3940             : 
    3941        7515 :          mj_img = get_opp_index(j_img, qs_env) !index of the opposite image, as returned by get_opp_index
    3942        7515 :          IF (mj_img > nimg .OR. mj_img < 1) CYCLE !no opposite image: handled by the second loop
    3943             : 
    3944        7206 :          IF (ri_data%present_images(j_img) == 0) CYCLE !image flagged with 0 in ri_data%present_images
    3945             : 
    3946       11062 :          pattern(iatom, jatom, j_img) = 1
    3947             :       END DO
    3948         256 :       CALL neighbor_list_iterator_release(nl_iterator)
    3949             : 
    3950             :       !If there is no opposite cell present, then we take into account the non-symmetric nl
    3951         256 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=nl_2c)
    3952             : 
    3953         256 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    3954       14800 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    3955       14544 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    3956             : 
    3957       14544 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    3958       14544 :          IF (j_img > nimg .OR. j_img < 1) CYCLE !cell is not among the nimg images considered
    3959             : 
    3960        9597 :          mj_img = get_opp_index(j_img, qs_env)
    3961        9597 :          IF (mj_img .LE. nimg .AND. mj_img > 0) CYCLE !opposite image exists: already covered by the first loop
    3962             : 
    3963         321 :          IF (ri_data%present_images(j_img) == 0) CYCLE !image flagged with 0 in ri_data%present_images
    3964             : 
    3965       14544 :          pattern(iatom, jatom, j_img) = 1
    3966             :       END DO
    3967         256 :       CALL neighbor_list_iterator_release(nl_iterator)
    3968             : 
    3969       87588 :       CALL para_env%sum(pattern) !accumulate over all ranks; only zero/non-zero is tested below
    3970             : 
    3971             :       !If the opposite image is considered, then there is no need to compute diagonal twice
    3972        6238 :       DO j_img = 2, nimg !starts at 2, presumably because image 1 is its own opposite - TODO confirm
    3973       18202 :          DO iatom = 1, natom
    3974       17946 :             IF (pattern(iatom, iatom, j_img) .NE. 0) THEN
    3975        4148 :                mj_img = get_opp_index(j_img, qs_env)
    3976        4148 :                IF (mj_img > nimg .OR. mj_img < 1) CYCLE
    3977        4148 :                pattern(iatom, iatom, mj_img) = 0 !drop the (iatom, iatom) block of the opposite image
    3978             :             END IF
    3979             :          END DO
    3980             :       END DO
    3981             : 
    3982             :       ! We want to equilibrate the sparsity pattern such that there are same amount of blocks
    3983             :       ! for each atom i of i,j pairs
    3984         768 :       ALLOCATE (bins(natom))
    3985         768 :       bins(:) = 0 !bins(i): number of blocks assigned so far with atom i as first index
    3986             : 
    3987        1280 :       ALLOCATE (tmp_pattern(natom, natom, nimg))
    3988       43922 :       tmp_pattern(:, :, :) = 0
    3989        6494 :       DO j_img = 1, nimg
    3990       18970 :          DO jatom = 1, natom
    3991       43666 :             DO iatom = 1, natom
    3992       24952 :                IF (pattern(iatom, jatom, j_img) == 0) CYCLE
    3993        8374 :                mj_img = get_opp_index(j_img, qs_env)
    3994             : 
    3995             :                !Should we take the i,j,b or the j,i,-b atomic block?
    3996       20850 :                IF (mj_img > nimg .OR. mj_img < 1) THEN
    3997             :                   !No opposite image, no choice
    3998         214 :                   bins(iatom) = bins(iatom) + 1
    3999         214 :                   tmp_pattern(iatom, jatom, j_img) = 1
    4000             :                ELSE
    4001             : 
    4002        8160 :                   IF (bins(iatom) > bins(jatom)) THEN !jatom has fewer blocks so far: take j,i,-b
    4003        1646 :                      bins(jatom) = bins(jatom) + 1
    4004        1646 :                      tmp_pattern(jatom, iatom, mj_img) = 1
    4005             :                   ELSE !otherwise (ties included) keep i,j,b
    4006        6514 :                      bins(iatom) = bins(iatom) + 1
    4007        6514 :                      tmp_pattern(iatom, jatom, j_img) = 1
    4008             :                   END IF
    4009             :                END IF
    4010             :             END DO
    4011             :          END DO
    4012             :       END DO
    4013             : 
    4014             :       ! -1 => unoccupied, 0 => occupied
    4015       43922 :       pattern(:, :, :) = tmp_pattern(:, :, :) - 1
    4016             : 
    4017         512 :    END SUBROUTINE get_sparsity_pattern
    4017             : 
    4018             : ! **************************************************************************************************
    4019             : !> \brief Distribute the iatom, jatom, b_img triplets over the subgroups to spread the load;
    4020             : !>        the group id for each triplet is passed as the value of sparsity_pattern(i, j, b),
    4021             : !>        with -1 being an unoccupied block
    4022             : !> \param sparsity_pattern in: -1 for unoccupied blocks; out: 0-based subgroup index for all other blocks
    4023             : !> \param ngroups number of subgroups the blocks are distributed over
    4024             : !> \param ri_data iatom_to_subgroup is allocated and filled if not yet allocated; kp_cost and bsizes_AO give block costs
    4025             : ! **************************************************************************************************
    4026         256 :    SUBROUTINE get_sub_dist(sparsity_pattern, ngroups, ri_data)
    4027             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: sparsity_pattern
    4028             :       INTEGER, INTENT(IN)                                :: ngroups
    4029             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4030             : 
    4031             :       INTEGER                                            :: b_img, ctr, iat, iatom, igroup, jatom, &
    4032             :                                                             natom, nimg, ub
    4033         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: max_at_per_group
    4034             :       REAL(dp)                                           :: cost
    4035         256 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4036             : 
    4037         256 :       natom = SIZE(sparsity_pattern, 2)
    4038         256 :       nimg = SIZE(sparsity_pattern, 3)
    4039             : 
    4040             :       !To avoid unnecessary data replication across the subgroups, we want to have a limited number
    4041             :       !of subgroups with the data of a given iatom. At the minimum, all groups have 1 atom
    4042             :       !We assume that the cost associated to each iatom is roughly the same
    4043         256 :       IF (.NOT. ALLOCATED(ri_data%iatom_to_subgroup)) THEN
    4044         350 :          ALLOCATE (ri_data%iatom_to_subgroup(natom), max_at_per_group(ngroups))
    4045         150 :          DO iatom = 1, natom
    4046         100 :             NULLIFY (ri_data%iatom_to_subgroup(iatom)%array)
    4047         200 :             ALLOCATE (ri_data%iatom_to_subgroup(iatom)%array(ngroups))
    4048         350 :             ri_data%iatom_to_subgroup(iatom)%array(:) = .FALSE.
    4049             :          END DO
    4050             :          !ub = natom/ngroups rounded up, and at least 1 atom per group
    4051          50 :          ub = natom/ngroups
    4052          50 :          IF (ub*ngroups < natom) ub = ub + 1
    4053         150 :          max_at_per_group(:) = MAX(1, ub)
    4054             : 
    4055             :          !We want each atom to be present the same amount of times. Some groups might have more atoms
    4056             :          !than others to achieve this (slots are added round-robin until their total is a multiple of natom).
    4057             :          ctr = 0
    4058         150 :          DO WHILE (MODULO(SUM(max_at_per_group), natom) .NE. 0)
    4059           0 :             igroup = MODULO(ctr, ngroups) + 1
    4060           0 :             max_at_per_group(igroup) = max_at_per_group(igroup) + 1
    4061          50 :             ctr = ctr + 1
    4062             :          END DO
    4063             :          !Hand out atoms cyclically (1..natom, wrapping around) to the slots of each group
    4064             :          ctr = 0
    4065         150 :          DO igroup = 1, ngroups
    4066         250 :             DO iat = 1, max_at_per_group(igroup)
    4067         100 :                iatom = MODULO(ctr, natom) + 1
    4068         100 :                ri_data%iatom_to_subgroup(iatom)%array(igroup) = .TRUE.
    4069         200 :                ctr = ctr + 1
    4070             :             END DO
    4071             :          END DO
    4072             :       END IF
    4073             :       !bins(g) accumulates the cost assigned so far to group g; each block goes to the cheapest allowed group
    4074         768 :       ALLOCATE (bins(ngroups))
    4075         768 :       bins = 0.0_dp
    4076        6494 :       DO b_img = 1, nimg
    4077       18970 :          DO jatom = 1, natom
    4078       43666 :             DO iatom = 1, natom
    4079       24952 :                IF (sparsity_pattern(iatom, jatom, b_img) == -1) CYCLE
    4080       41870 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4081             :                !igroup is 0-based (MINLOC result minus 1), hence the igroup + 1 when indexing bins below
    4082             :                !Use cost information from previous SCF if available, otherwise the product of the AO block sizes
    4083      533682 :                IF (ANY(ri_data%kp_cost > EPSILON(0.0_dp))) THEN
    4084        6152 :                   cost = ri_data%kp_cost(iatom, jatom, b_img)
    4085             :                ELSE
    4086        2222 :                   cost = REAL(ri_data%bsizes_AO(iatom)*ri_data%bsizes_AO(jatom), dp)
    4087             :                END IF
    4088        8374 :                bins(igroup + 1) = bins(igroup + 1) + cost
    4089       37428 :                sparsity_pattern(iatom, jatom, b_img) = igroup
    4090             :             END DO
    4091             :          END DO
    4092             :       END DO
    4093             :       !NOTE(review): the ANY(kp_cost > EPSILON) test above does not depend on the loop indices and could be hoisted
    4094         256 :    END SUBROUTINE get_sub_dist
    4095             : 
    4096             : ! **************************************************************************************************
    4097             : !> \brief A routine that updates the sparsity pattern for force calculation, where all i,j,b combinations
    4098             : !>        are visited.
    4099             : !> \param force_pattern gets the 0-based subgroup index for blocks empty in scf_pattern whose (j, i, -b) block is occupied
    4100             : !> \param scf_pattern pattern used during the SCF, -1 marks unoccupied blocks (only read in this routine)
    4101             : !> \param ngroups number of subgroups (size of the load-balancing bins)
    4102             : !> \param ri_data provides iatom_to_subgroup (allowed groups per atom) and kp_cost (block costs)
    4103             : !> \param qs_env passed to get_opp_index to obtain the opposite image index of b_img
    4104             : ! **************************************************************************************************
    4105          42 :    SUBROUTINE update_pattern_to_forces(force_pattern, scf_pattern, ngroups, ri_data, qs_env)
    4106             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: force_pattern, scf_pattern
    4107             :       INTEGER, INTENT(IN)                                :: ngroups
    4108             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4109             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4110             : 
    4111             :       INTEGER                                            :: b_img, iatom, igroup, jatom, mb_img, &
    4112             :                                                             natom, nimg
    4113          42 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4114             : 
    4115          42 :       natom = SIZE(scf_pattern, 2)
    4116          42 :       nimg = SIZE(scf_pattern, 3)
    4117             :       !bins(g) accumulates the cost assigned so far to group g
    4118         126 :       ALLOCATE (bins(ngroups))
    4119         126 :       bins = 0.0_dp
    4120             : 
    4121         980 :       DO b_img = 1, nimg
    4122         938 :          mb_img = get_opp_index(b_img, qs_env)
    4123        2856 :          DO jatom = 1, natom
    4124        6566 :             DO iatom = 1, natom
    4125             :                !Important: same distribution as KS matrix, because reuse t_3c_apc
    4126       18760 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4127             :                !NOTE(review): igroup is only used after both CYCLE tests below, and bins is unchanged in between
    4128             :                !check that block not already treated (entries > -1 in scf_pattern are skipped)
    4129        3752 :                IF (scf_pattern(iatom, jatom, b_img) > -1) CYCLE
    4130             : 
    4131             :                !If not, take the cost of block j, i, -b (same energy contribution)
    4132        4466 :                IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
    4133        2210 :                   IF (scf_pattern(jatom, iatom, mb_img) == -1) CYCLE
    4134        1042 :                   bins(igroup + 1) = bins(igroup + 1) + ri_data%kp_cost(jatom, iatom, mb_img)
    4135        1042 :                   force_pattern(iatom, jatom, b_img) = igroup
    4136             :                END IF
    4137             :             END DO
    4138             :          END DO
    4139             :       END DO
    4140             : 
    4141          42 :    END SUBROUTINE update_pattern_to_forces
    4142             : 
    4143             : ! **************************************************************************************************
    4144             : !> \brief A routine that determines the extent of the KP RI-HFX periodic images, including for the
    4145             : !>        extension of the RI basis
    4146             : !> \param ri_data on exit: kp_RI_range, kp_image_range, kp_bump_rad, nimg, present_images, ncell_RI, img <-> RI cell maps
    4147             : !> \param qs_env source of the kinds, particles, kpoints (cell_to_index), para_env and the DFT%XC%HF%RI input section
    4148             : ! **************************************************************************************************
    4149          70 :    SUBROUTINE get_kp_and_ri_images(ri_data, qs_env)
    4150             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4151             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4152             : 
    4153             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_kp_and_ri_images'
    4154             : 
    4155             :       INTEGER :: cell_j(3), cell_k(3), handle, i_img, iatom, ikind, j_img, jatom, jcell, katom, &
    4156             :          kcell, kp_index_lbounds(3), kp_index_ubounds(3), natom, ngroups, nimg, nkind, pcoord(3), &
    4157             :          pdims(3)
    4158          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist_AO_1, dist_AO_2, dist_RI, &
    4159          70 :                                                             nRI_per_atom, present_img, RI_cells
    4160          70 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    4161             :       REAL(dp)                                           :: bump_fact, dij, dik, image_range, &
    4162             :                                                             RI_range, rij(3), rik(3)
    4163         490 :       TYPE(dbt_type)                                     :: t_dummy
    4164             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4165             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4166             :       TYPE(distribution_3d_type)                         :: dist_3d
    4167             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4168          70 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4169             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4170          70 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4171             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4172             :       TYPE(neighbor_list_3c_iterator_type)               :: nl_3c_iter
    4173             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4174             :       TYPE(neighbor_list_iterator_p_type), &
    4175          70 :          DIMENSION(:), POINTER                           :: nl_iterator
    4176             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4177          70 :          POINTER                                         :: nl_2c
    4178          70 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4179          70 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4180             :       TYPE(section_vals_type), POINTER                   :: hfx_section
    4181             : 
    4182          70 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, nl_iterator, dft_control, &
    4183          70 :                particle_set, kpoints, para_env, cell_to_index, hfx_section)
    4184             : 
    4185          70 :       CALL timeset(routineN, handle)
    4186             : 
    4187             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, &
    4188             :                       dft_control=dft_control, particle_set=particle_set, kpoints=kpoints, &
    4189          70 :                       para_env=para_env, natom=natom)
    4190          70 :       nimg = dft_control%nimages
    4191          70 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index)
    4192         280 :       kp_index_lbounds = LBOUND(cell_to_index)
    4193         280 :       kp_index_ubounds = UBOUND(cell_to_index)
    4194             :       !ngroups is only read here to be printed in the info block at the end of the routine
    4195          70 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
    4196          70 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
    4197             : 
    4198         496 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4199          70 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4200          70 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4201             : 
    4202             :       !In case of shortrange HFX potential, it is important to be consistent with the rest of the KP
    4203             :       !code, and use EPS_SCHWARZ to determine the range (rather than eps_filter_2c in normal RI-HFX)
    4204          70 :       IF (ri_data%hfx_pot%potential_type == do_potential_short) THEN
    4205           0 :          CALL erfc_cutoff(ri_data%eps_schwarz, ri_data%hfx_pot%omega, ri_data%hfx_pot%cutoff_radius)
    4206             :       END IF
    4207             : 
    4208             :       !Determine the range for contributing periodic images, and for the RI basis extension
    4209          70 :       ri_data%kp_RI_range = 0.0_dp
    4210          70 :       ri_data%kp_image_range = 0.0_dp
    4211         178 :       DO ikind = 1, nkind
    4212             :          !kp_RI_range: largest kind radius of the AO basis sets
    4213         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4214         108 :          CALL get_gto_basis_set(basis_set_AO(ikind)%gto_basis_set, kind_radius=RI_range)
    4215         108 :          ri_data%kp_RI_range = MAX(RI_range, ri_data%kp_RI_range)
    4216             :          !NOTE(review): the next call repeats the AO radii initialization done above with the same arguments
    4217         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4218         108 :          CALL init_interaction_radii_orb_basis(basis_set_RI(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4219         108 :          CALL get_gto_basis_set(basis_set_RI(ikind)%gto_basis_set, kind_radius=image_range)
    4220             :          !kp_image_range: largest value of twice the RI kind radius plus the scaled potential cutoff radius
    4221         108 :          image_range = 2.0_dp*image_range + cutoff_screen_factor*ri_data%hfx_pot%cutoff_radius
    4222         286 :          ri_data%kp_image_range = MAX(image_range, ri_data%kp_image_range)
    4223             :       END DO
    4224             : 
    4225          70 :       CALL section_vals_val_get(hfx_section, "KP_RI_BUMP_FACTOR", r_val=bump_fact)
    4226          70 :       ri_data%kp_bump_rad = bump_fact*ri_data%kp_RI_range
    4227             : 
    4228             :       !For the extent of the KP RI-HFX images, we are limited by the RI-HFX potential in
    4229             :       !(mu^0 sigma^a|P^0) (P^0|Q^b) (Q^b|nu^b lambda^a+c), if there is no contact between
    4230             :       !any P^0 and Q^b, then image b does not contribute
    4231             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4232          70 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4233             : 
    4234         210 :       ALLOCATE (present_img(nimg))
    4235        3120 :       present_img = 0
    4236          70 :       ri_data%nimg = 0
    4237          70 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    4238        1568 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    4239        1498 :          CALL get_iterator_info(nl_iterator, r=rij, cell=cell_j)
    4240             : 
    4241        5992 :          dij = NORM2(rij)
    4242             :          !NOTE(review): cell_j is not checked against the bounds of cell_to_index here, unlike in the 3c loop below
    4243        1498 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4244        1498 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    4245             : 
    4246        1466 :          IF (dij > ri_data%kp_image_range) CYCLE
    4247             :          !ri_data%nimg ends up as the highest image index found within range; present_img flags each such image
    4248        1466 :          ri_data%nimg = MAX(j_img, ri_data%nimg)
    4249        1498 :          present_img(j_img) = 1
    4250             : 
    4251             :       END DO
    4252          70 :       CALL neighbor_list_iterator_release(nl_iterator)
    4253          70 :       CALL release_neighbor_list_sets(nl_2c)
    4254          70 :       CALL para_env%max(ri_data%nimg)
    4255          70 :       IF (ri_data%nimg > nimg) &
    4256           0 :          CPABORT("Make sure the smallest exponent of the RI-HFX basis is larger than that of the ORB basis.")
    4257             : 
    4258             :       !Keep track of which images will not contribute, so that can be ignored before calculation
    4259          70 :       CALL para_env%sum(present_img)
    4260         210 :       ALLOCATE (ri_data%present_images(ri_data%nimg))
    4261        1714 :       ri_data%present_images = 0
    4262        1714 :       DO i_img = 1, ri_data%nimg
    4263        1714 :          IF (present_img(i_img) > 0) ri_data%present_images(i_img) = 1
    4264             :       END DO
    4265             :       !t_dummy is only created to obtain the block distributions for dist_3d, and is destroyed right after
    4266             :       CALL create_3c_tensor(t_dummy, dist_AO_1, dist_AO_2, dist_RI, &
    4267             :                             ri_data%pgrid, ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4268          70 :                             map1=[1, 2], map2=[3], name="(AO AO | RI)")
    4269             : 
    4270          70 :       CALL dbt_mp_environ_pgrid(ri_data%pgrid, pdims, pcoord)
    4271          70 :       CALL mp_comm_t3c%create(ri_data%pgrid%mp_comm_2d, 3, pdims)
    4272             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4273          70 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4274          70 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4275          70 :       CALL dbt_destroy(t_dummy)
    4276             : 
    4277             :       !For the extension of the RI basis P in (mu^0 sigma^a |P^i), we consider an atom if the distance,
    4278             :       !between mu^0 and P^i is smaller or equal to the kind radius of mu^0
    4279             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, &
    4280             :                                    ri_data%ri_metric, "HFX_3c_nl", qs_env, op_pos=2, sym_ij=.FALSE., &
    4281          70 :                                    own_dist=.TRUE.)
    4282             :       !RI_cells(i) is set to 1 if image i hosts RI functions within kp_RI_range of some atom
    4283         140 :       ALLOCATE (RI_cells(nimg))
    4284        3120 :       RI_cells = 0
    4285             :       !nRI_per_atom(i) accumulates the RI block sizes of the katoms found for (iatom == jatom, jcell == 1)
    4286         210 :       ALLOCATE (nRI_per_atom(natom))
    4287         210 :       nRI_per_atom = 0
    4288             : 
    4289          70 :       CALL neighbor_list_3c_iterator_create(nl_3c_iter, nl_3c)
    4290       58584 :       DO WHILE (neighbor_list_3c_iterate(nl_3c_iter) == 0)
    4291             :          CALL get_3c_iterator_info(nl_3c_iter, cell_k=cell_k, rik=rik, cell_j=cell_j, &
    4292       58514 :                                    iatom=iatom, jatom=jatom, katom=katom)
    4293      234056 :          dik = NORM2(rik)
    4294             :          !skip entries whose j cell lies outside the bounds of cell_to_index or maps to no valid image
    4295      409598 :          IF (ANY([cell_j(1), cell_j(2), cell_j(3)] < kp_index_lbounds) .OR. &
    4296             :              ANY([cell_j(1), cell_j(2), cell_j(3)] > kp_index_ubounds)) CYCLE
    4297             : 
    4298       58514 :          jcell = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4299       58514 :          IF (jcell > nimg .OR. jcell < 1) CYCLE
    4300             :          !same screening for the k cell
    4301      386199 :          IF (ANY([cell_k(1), cell_k(2), cell_k(3)] < kp_index_lbounds) .OR. &
    4302             :              ANY([cell_k(1), cell_k(2), cell_k(3)] > kp_index_ubounds)) CYCLE
    4303             : 
    4304       51245 :          kcell = cell_to_index(cell_k(1), cell_k(2), cell_k(3))
    4305       51245 :          IF (kcell > nimg .OR. kcell < 1) CYCLE
    4306             : 
    4307       43587 :          IF (dik > ri_data%kp_RI_range) CYCLE
    4308        5751 :          RI_cells(kcell) = 1
    4309             : 
    4310        5821 :          IF (jcell == 1 .AND. iatom == jatom) nRI_per_atom(iatom) = nRI_per_atom(iatom) + ri_data%bsizes_RI(katom)
    4311             :       END DO
    4312          70 :       CALL neighbor_list_3c_iterator_destroy(nl_3c_iter)
    4313          70 :       CALL neighbor_list_3c_destroy(nl_3c)
    4314          70 :       CALL para_env%sum(RI_cells)
    4315          70 :       CALL para_env%sum(nRI_per_atom)
    4316             :       !Number the contributing RI cells consecutively: img_to_RI_cell(img) = RI cell index, 0 if not contributing
    4317         140 :       ALLOCATE (ri_data%img_to_RI_cell(nimg))
    4318          70 :       ri_data%ncell_RI = 0
    4319        3120 :       ri_data%img_to_RI_cell = 0
    4320        3120 :       DO i_img = 1, nimg
    4321        3120 :          IF (RI_cells(i_img) > 0) THEN
    4322         436 :             ri_data%ncell_RI = ri_data%ncell_RI + 1
    4323         436 :             ri_data%img_to_RI_cell(i_img) = ri_data%ncell_RI
    4324             :          END IF
    4325             :       END DO
    4326             :       !Inverse map: RI_cell_to_img(RI cell index) = image index
    4327         210 :       ALLOCATE (ri_data%RI_cell_to_img(ri_data%ncell_RI))
    4328        3120 :       DO i_img = 1, nimg
    4329        3120 :          IF (ri_data%img_to_RI_cell(i_img) > 0) ri_data%RI_cell_to_img(ri_data%img_to_RI_cell(i_img)) = i_img
    4330             :       END DO
    4331             : 
    4332             :       !Print some info (only where unit_nr > 0); the average number of sgf is an integer division
    4333          70 :       IF (ri_data%unit_nr > 0) THEN
    4334             :          WRITE (ri_data%unit_nr, FMT="(/T3,A,I29)") &
    4335          35 :             "KP-HFX_RI_INFO| Number of RI-KP parallel groups:", ngroups
    4336             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F31.3,A)") &
    4337          35 :             "KP-HFX_RI_INFO| RI basis extension radius:", ri_data%kp_RI_range*angstrom, " Ang"
    4338             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F12.3,A, F6.3, A)") &
    4339          35 :             "KP-HFX_RI_INFO| RI basis bump factor and bump radius:", bump_fact, " /", &
    4340          70 :             ri_data%kp_bump_rad*angstrom, " Ang"
    4341             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I16,A)") &
    4342          35 :             "KP-HFX_RI_INFO| The extended RI bases cover up to ", ri_data%ncell_RI, " unit cells"
    4343             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I18)") &
    4344         105 :             "KP-HFX_RI_INFO| Average number of sgf in extended RI bases:", SUM(nRI_per_atom)/natom
    4345             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F13.3,A)") &
    4346          35 :             "KP-HFX_RI_INFO| Consider all image cells within a radius of ", ri_data%kp_image_range*angstrom, " Ang"
    4347             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I27/)") &
    4348          35 :             "KP-HFX_RI_INFO| Number of image cells considered: ", ri_data%nimg
    4349          35 :          CALL m_flush(ri_data%unit_nr)
    4350             :       END IF
    4351             : 
    4352          70 :       CALL timestop(handle)
    4353             : 
    4354         840 :    END SUBROUTINE get_kp_and_ri_images
    4355             : 
    4356             : ! **************************************************************************************************
    4357             : !> \brief A routine that creates tensors structure for rho_ao and 3c_ints in a stacked format for
    4358             : !>        the efficient contractions of rho_sigma^0,lambda^c * (mu^0 sigma^a | P) => TAS tensors
    4359             : !> \param res_stack elements 1 and 2 are created from ints_stack(1) and ints_stack(2), respectively
    4360             : !> \param rho_stack (1) stacked 2c tensor with the distribution of rho_template, (2) stacked 2c tensor on pgrid_2d
    4361             : !> \param ints_stack (1) 3c tensor stacked in its 3rd dimension with the distribution of ints_template, (2) same on a new pgrid
    4362             : !> \param rho_template 2c tensor whose process distributions are replicated stack_size times
    4363             : !> \param ints_template 3c tensor providing block counts, process distributions and the block sizes of dimension 2
    4364             : !> \param stack_size number of times the AO block sizes are repeated in the stacked dimensions
    4365             : !> \param ri_data provides bsizes_AO_split, pgrid_2d and pgrid_1
    4366             : !> \param qs_env only used to retrieve para_env
    4367             : !> \note The result tensor has the exact same shape and distribution as the integral tensor
    4368             : ! **************************************************************************************************
    4369         256 :    SUBROUTINE get_stack_tensors(res_stack, rho_stack, ints_stack, rho_template, ints_template, &
    4370             :                                 stack_size, ri_data, qs_env)
    4371             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: res_stack, rho_stack, ints_stack
    4372             :       TYPE(dbt_type), INTENT(INOUT)                      :: rho_template, ints_template
    4373             :       INTEGER, INTENT(IN)                                :: stack_size
    4374             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4375             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4376             : 
    4377             :       INTEGER                                            :: is, nblks, nblks_3c(3), pdims_3d(3)
    4378         256 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_stack, dist1, &
    4379         256 :                                                             dist2, dist3, dist_stack1, &
    4380         256 :                                                             dist_stack2, dist_stack3
    4381        2304 :       TYPE(dbt_distribution_type)                        :: t_dist
    4382         768 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4383             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4384             : 
    4385         256 :       NULLIFY (para_env)
    4386             : 
    4387         256 :       CALL get_qs_env(qs_env, para_env=para_env)
    4388             :       !bsizes_stack: the split AO block sizes repeated stack_size times
    4389         256 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4390         768 :       ALLOCATE (bsizes_stack(stack_size*nblks))
    4391        3314 :       DO is = 1, stack_size
    4392       12476 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    4393             :       END DO
    4394             :       !Replicate the process distributions of rho_template in the same way
    4395        2304 :       ALLOCATE (dist1(nblks), dist2(nblks), dist_stack1(stack_size*nblks), dist_stack2(stack_size*nblks))
    4396         256 :       CALL dbt_get_info(rho_template, proc_dist_1=dist1, proc_dist_2=dist2)
    4397        3314 :       DO is = 1, stack_size
    4398       12220 :          dist_stack1((is - 1)*nblks + 1:is*nblks) = dist1(:)
    4399       12476 :          dist_stack2((is - 1)*nblks + 1:is*nblks) = dist2(:)
    4400             :       END DO
    4401             : 
    4402             :       !First 2c tensor matches the distribution of template
    4403             :       !It is stacked in both directions
    4404         256 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_2d, dist_stack1, dist_stack2)
    4405         256 :       CALL dbt_create(rho_stack(1), "RHO_stack", t_dist, [1], [2], bsizes_stack, bsizes_stack)
    4406         256 :       CALL dbt_distribution_destroy(t_dist)
    4407         256 :       DEALLOCATE (dist1, dist2, dist_stack1, dist_stack2)
    4408             : 
    4409             :       !Second 2c tensor has optimal distribution on the 2d pgrid
    4410         256 :       CALL create_2c_tensor(rho_stack(2), dist1, dist2, ri_data%pgrid_2d, bsizes_stack, bsizes_stack, name="RHO_stack")
    4411         256 :       DEALLOCATE (dist1, dist2)
    4412             :       !For the 3c tensors, only the distribution of the 3rd dimension is replicated stack_size times
    4413         256 :       CALL dbt_get_info(ints_template, nblks_total=nblks_3c)
    4414        1792 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)))
    4415        1280 :       ALLOCATE (dist_stack3(stack_size*nblks_3c(3)), bsizes_RI_ext(nblks_3c(2)))
    4416             :       CALL dbt_get_info(ints_template, proc_dist_1=dist1, proc_dist_2=dist2, &
    4417         256 :                         proc_dist_3=dist3, blk_size_2=bsizes_RI_ext)
    4418        3314 :       DO is = 1, stack_size
    4419       12476 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    4420             :       END DO
    4421             : 
    4422             :       !First 3c tensor matches the distribution of template
    4423         256 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_1, dist1, dist2, dist_stack3)
    4424             :       CALL dbt_create(ints_stack(1), "ints_stack", t_dist, [1, 2], [3], ri_data%bsizes_AO_split, &
    4425         256 :                       bsizes_RI_ext, bsizes_stack)
    4426         256 :       CALL dbt_distribution_destroy(t_dist)
    4427         256 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    4428             : 
    4429             :       !Second 3c tensor has optimal pgrid (pdims_3d = 0 is passed to dbt_pgrid_create along with the tensor dims)
    4430         256 :       pdims_3d = 0
    4431        1024 :       CALL dbt_pgrid_create(para_env, pdims_3d, pgrid, tensor_dims=[nblks_3c(1), nblks_3c(2), stack_size*nblks_3c(3)])
    4432             :       CALL create_3c_tensor(ints_stack(2), dist1, dist2, dist3, pgrid, ri_data%bsizes_AO_split, &
    4433         256 :                             bsizes_RI_ext, bsizes_stack, [1, 2], [3], name="ints_stack")
    4434         256 :       DEALLOCATE (dist1, dist2, dist3)
    4435         256 :       CALL dbt_pgrid_destroy(pgrid)
    4436             : 
    4437             :       !The result tensor has the same shape and dist as the integral tensor
    4438         256 :       CALL dbt_create(ints_stack(1), res_stack(1))
    4439         256 :       CALL dbt_create(ints_stack(2), res_stack(2))
    4440             : 
    4441         512 :    END SUBROUTINE get_stack_tensors
    4442             : 
    4443             : ! **************************************************************************************************
    4444             : !> \brief Fill the stack of 3c tensors according to the order in the images input
    4445             : !> \param t_3c_stack ...
    4446             : !> \param t_3c_in ...
    4447             : !> \param images ...
    4448             : !> \param stack_dim ...
    4449             : !> \param ri_data ...
    4450             : !> \param filter_at ...
    4451             : !> \param filter_dim ...
    4452             : !> \param idx_to_at ...
    4453             : !> \param img_bounds ...
    4454             : ! **************************************************************************************************
    4455       21893 :    SUBROUTINE fill_3c_stack(t_3c_stack, t_3c_in, images, stack_dim, ri_data, filter_at, filter_dim, &
    4456       21893 :                             idx_to_at, img_bounds)
    4457             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_stack
    4458             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_in
    4459             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4460             :       INTEGER, INTENT(IN)                                :: stack_dim
    4461             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4462             :       INTEGER, INTENT(IN), OPTIONAL                      :: filter_at, filter_dim
    4463             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: idx_to_at
    4464             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2)
    4465             : 
    4466             :       INTEGER                                            :: dest(3), i_img, idx, ind(3), lb, nblks, &
    4467             :                                                             nimg, offset, ub
    4468             :       LOGICAL                                            :: do_filter, found
    4469       21893 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4470             :       TYPE(dbt_iterator_type)                            :: iter
    4471             : 
    4472             :       !We loop over the a images from the ac_pairs, then copy the 3c ints to the correct spot in
    4473             :       !in the stack tensor (corresponding to pair index). Distributions match by construction
    4474       21893 :       nimg = ri_data%nimg
    4475       21893 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4476             : 
    4477       21893 :       do_filter = .FALSE.
    4478       21319 :       IF (PRESENT(filter_at) .AND. PRESENT(filter_dim) .AND. PRESENT(idx_to_at)) do_filter = .TRUE.
    4479             : 
    4480       21893 :       lb = 1
    4481       21893 :       ub = nimg
    4482       21893 :       offset = 0
    4483       21893 :       IF (PRESENT(img_bounds)) THEN
    4484       21893 :          lb = img_bounds(1)
    4485       21893 :          ub = img_bounds(2) - 1
    4486       21893 :          offset = lb - 1
    4487             :       END IF
    4488             : 
    4489      433835 :       DO idx = lb, ub
    4490      411942 :          i_img = images(idx)
    4491      411942 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4492             : 
    4493             : !$OMP PARALLEL DEFAULT(NONE) &
    4494             : !$OMP SHARED(idx,i_img,t_3c_in,t_3c_stack,nblks,stack_dim,filter_at,filter_dim,idx_to_at,do_filter,offset) &
    4495      433835 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4496             :          CALL dbt_iterator_start(iter, t_3c_in(i_img))
    4497             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4498             :             CALL dbt_iterator_next_block(iter, ind)
    4499             :             CALL dbt_get_block(t_3c_in(i_img), ind, blk, found)
    4500             :             IF (.NOT. found) CYCLE
    4501             : 
    4502             :             IF (do_filter) THEN
    4503             :                IF (.NOT. idx_to_at(ind(filter_dim)) == filter_at) CYCLE
    4504             :             END IF
    4505             : 
    4506             :             IF (stack_dim == 1) THEN
    4507             :                dest = [(idx - offset - 1)*nblks + ind(1), ind(2), ind(3)]
    4508             :             ELSE IF (stack_dim == 2) THEN
    4509             :                dest = [ind(1), (idx - offset - 1)*nblks + ind(2), ind(3)]
    4510             :             ELSE
    4511             :                dest = [ind(1), ind(2), (idx - offset - 1)*nblks + ind(3)]
    4512             :             END IF
    4513             : 
    4514             :             CALL dbt_put_block(t_3c_stack, dest, SHAPE(blk), blk)
    4515             :             DEALLOCATE (blk)
    4516             :          END DO
    4517             :          CALL dbt_iterator_stop(iter)
    4518             : !$OMP END PARALLEL
    4519             :       END DO !i_img
    4520       21893 :       CALL dbt_finalize(t_3c_stack)
    4521             : 
    4522       43786 :    END SUBROUTINE fill_3c_stack
    4523             : 
    4524             : ! **************************************************************************************************
    4525             : !> \brief Fill the stack of 2c tensors based on the content of images input
    4526             : !> \param t_2c_stack destination tensor; receives the copied blocks and is finalized before returning
    4527             : !> \param t_2c_in input tensors, indexed by the image indices stored in images
    4528             : !> \param images image index for each stack slot; entries equal to 0 or larger than ri_data%nimg are skipped
    4529             : !> \param stack_dim 1 stacks the slots along the first dimension, any other value along the second one
    4530             : !> \param ri_data provides nimg and bsizes_AO_split, whose size is the block offset between two slots
    4531             : !> \param img_bounds if present, only images(img_bounds(1):img_bounds(2)-1) is processed, first entry in slot 1
    4532             : !> \param shift slot used for the dimension that is not stacked (block offset (shift-1)*nblks), defaults to 1
    4533             : ! **************************************************************************************************
    4534       16412 :    SUBROUTINE fill_2c_stack(t_2c_stack, t_2c_in, images, stack_dim, ri_data, img_bounds, shift)
    4535             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_stack
    4536             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_in
    4537             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4538             :       INTEGER, INTENT(IN)                                :: stack_dim
    4539             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4540             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2), shift
    4541             : 
    4542             :       INTEGER                                            :: dest(2), i_img, idx, ind(2), lb, &
    4543             :                                                             my_shift, nblks, nimg, offset, ub
    4544             :       LOGICAL                                            :: found
    4545       16412 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    4546             :       TYPE(dbt_iterator_type)                            :: iter
    4547             : 
    4548             :       !We loop over the requested entries of images, then copy the blocks of the matching 2c tensor to the
    4549             :       !correct spot in the stack tensor (corresponding to the entry index). Distributions match by construction
    4550       16412 :       nimg = ri_data%nimg
    4551       16412 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4552             :       !Default range is 1..nimg; img_bounds restricts it to img_bounds(1)..img_bounds(2)-1 (upper bound excluded)
    4553       16412 :       lb = 1
    4554       16412 :       ub = nimg
    4555       16412 :       offset = 0
    4556       16412 :       IF (PRESENT(img_bounds)) THEN
    4557       16412 :          lb = img_bounds(1)
    4558       16412 :          ub = img_bounds(2) - 1
    4559       16412 :          offset = lb - 1
    4560             :       END IF
    4561             : 
    4562       16412 :       my_shift = 1
    4563       16412 :       IF (PRESENT(shift)) my_shift = shift
    4564             : 
    4565      204304 :       DO idx = lb, ub
    4566      187892 :          i_img = images(idx)
    4567      187892 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4568             : 
    4569             : !$OMP PARALLEL DEFAULT(NONE) SHARED(idx,i_img,t_2c_in,t_2c_stack,nblks,stack_dim,offset,my_shift) &
    4570      204304 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4571             :          CALL dbt_iterator_start(iter, t_2c_in(i_img))
    4572             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4573             :             CALL dbt_iterator_next_block(iter, ind)
    4574             :             CALL dbt_get_block(t_2c_in(i_img), ind, blk, found)
    4575             :             IF (.NOT. found) CYCLE
    4576             :             !The stacked dimension is offset by the slot of idx, the other dimension by the fixed my_shift slot
    4577             :             IF (stack_dim == 1) THEN
    4578             :                dest = [(idx - offset - 1)*nblks + ind(1), (my_shift - 1)*nblks + ind(2)]
    4579             :             ELSE
    4580             :                dest = [(my_shift - 1)*nblks + ind(1), (idx - offset - 1)*nblks + ind(2)]
    4581             :             END IF
    4582             : 
    4583             :             CALL dbt_put_block(t_2c_stack, dest, SHAPE(blk), blk)
    4584             :             DEALLOCATE (blk)
    4585             :          END DO
    4586             :          CALL dbt_iterator_stop(iter)
    4587             : !$OMP END PARALLEL
    4588             :       END DO !idx
    4589       16412 :       CALL dbt_finalize(t_2c_stack)
    4590             : 
    4591       32824 :    END SUBROUTINE fill_2c_stack
    4592             : 
    4593             : ! **************************************************************************************************
    4594             : !> \brief Unstacks a stacked 3c tensor containing t_3c_apc
    4595             : !> \param t_3c_apc destination tensor; receives the blocks of slot idx with their 3rd block index shifted back
    4596             : !> \param t_stacked source tensor, stacked along its 3rd dimension
    4597             : !> \param idx slot of the stack to extract (slot width is the number of blocks of t_3c_apc along dimension 3)
    4598             : ! **************************************************************************************************
    4599       18838 :    SUBROUTINE unstack_t_3c_apc(t_3c_apc, t_stacked, idx)
    4600             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_apc, t_stacked
    4601             :       INTEGER, INTENT(IN)                                :: idx
    4602             : 
    4603             :       INTEGER                                            :: current_idx
    4604             :       INTEGER, DIMENSION(3)                              :: ind, nblks_3c
    4605             :       LOGICAL                                            :: found
    4606       18838 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4607             :       TYPE(dbt_iterator_type)                            :: iter
    4608             : 
    4609             :       !Note: t_3c_apc and t_stacked must have the same distribution
    4610       18838 :       CALL dbt_get_info(t_3c_apc, nblks_total=nblks_3c)
    4611             : 
    4612       18838 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_apc,t_stacked,idx,nblks_3c) PRIVATE(iter,ind,blk,found,current_idx)
    4613             :       CALL dbt_iterator_start(iter, t_stacked)
    4614             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4615             :          CALL dbt_iterator_next_block(iter, ind)
    4616             : 
    4617             :          !tensor is stacked along the 3rd dimension: recover the slot of this block and skip all other slots
    4618             :          current_idx = (ind(3) - 1)/nblks_3c(3) + 1
    4619             :          IF (.NOT. idx == current_idx) CYCLE
    4620             : 
    4621             :          CALL dbt_get_block(t_stacked, ind, blk, found)
    4622             :          IF (.NOT. found) CYCLE
    4623             : 
    4624             :          CALL dbt_put_block(t_3c_apc, [ind(1), ind(2), ind(3) - (idx - 1)*nblks_3c(3)], SHAPE(blk), blk)
    4625             :          DEALLOCATE (blk)
    4626             :       END DO
    4627             :       CALL dbt_iterator_stop(iter)
    4628             : !$OMP END PARALLEL
    4629             :       !NOTE(review): no dbt_finalize(t_3c_apc) here, unlike fill_3c_stack and get_atom_3c_ints - confirm the callers do it
    4630       18838 :    END SUBROUTINE unstack_t_3c_apc
    4631             : 
    4632             : ! **************************************************************************************************
    4633             : !> \brief copies the 3c integrals corresponding to a single atom mu from the general (P^0| mu^0 sigma^a)
    4634             : !> \param t_3c_at destination tensor; receives the selected blocks at unchanged indices and is finalized
    4635             : !> \param t_3c_ints source tensor holding the integrals for all atoms
    4636             : !> \param iatom atom whose blocks are copied
    4637             : !> \param dim_at dimension of the block index that is tested against iatom
    4638             : !> \param idx_to_at lookup from a block index along dim_at to the value compared with iatom
    4639             : ! **************************************************************************************************
    4640           0 :    SUBROUTINE get_atom_3c_ints(t_3c_at, t_3c_ints, iatom, dim_at, idx_to_at)
    4641             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_at, t_3c_ints
    4642             :       INTEGER, INTENT(IN)                                :: iatom, dim_at
    4643             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: idx_to_at
    4644             : 
    4645             :       INTEGER, DIMENSION(3)                              :: ind
    4646             :       LOGICAL                                            :: found
    4647           0 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4648             :       TYPE(dbt_iterator_type)                            :: iter
    4649             : 
    4650           0 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_ints,t_3c_at,iatom,idx_to_at,dim_at) PRIVATE(iter,ind,blk,found)
    4651             :       CALL dbt_iterator_start(iter, t_3c_ints)
    4652             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4653             :          CALL dbt_iterator_next_block(iter, ind)
    4654             :          IF (.NOT. idx_to_at(ind(dim_at)) == iatom) CYCLE
    4655             :          !Only blocks that pass the atom test are fetched and copied
    4656             :          CALL dbt_get_block(t_3c_ints, ind, blk, found)
    4657             :          IF (.NOT. found) CYCLE
    4658             : 
    4659             :          CALL dbt_put_block(t_3c_at, ind, SHAPE(blk), blk)
    4660             :          DEALLOCATE (blk)
    4661             :       END DO
    4662             :       CALL dbt_iterator_stop(iter)
    4663             : !$OMP END PARALLEL
    4664           0 :       CALL dbt_finalize(t_3c_at)
    4665             : 
    4666           0 :    END SUBROUTINE get_atom_3c_ints
    4667             : 
    4668             : ! **************************************************************************************************
    4669             : !> \brief Precalculate the 3c and 2c derivatives tensors
    4670             : !> \param t_3c_der_RI (nimg, 3) tensors created here in the (RI | AO AO) layout; derivatives wrt the RI center
    4671             : !> \param t_3c_der_AO (nimg, 3) tensors created here in the (RI | AO AO) layout; derivatives wrt the AO center mu^0
    4672             : !> \param mat_der_pot (nimg, 3) RI x RI matrices created here; derivatives of the 2c integrals of ri_data%hfx_pot
    4673             : !> \param t_2c_der_metric (natom, 3) tensors created here from ri_data%t_2c_inv(1, 1); RI metric derivatives
    4674             : !> \param ri_data provides the image counts, block sizes, process grid, operators and filter thresholds used
    4675             : !> \param qs_env source of the kinds, particles, 2d distribution, dft control and parallel environment
    4676             : ! **************************************************************************************************
    4677          42 :    SUBROUTINE precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
    4678             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_RI, t_3c_der_AO
    4679             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot
    4680             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_2c_der_metric
    4681             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4682             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4683             : 
    4684             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'precalc_derivatives'
    4685             : 
    4686             :       INTEGER                                            :: handle, handle2, i_img, i_mem, i_RI, &
    4687             :                                                             i_xyz, iatom, n_mem, natom, nblks_RI, &
    4688             :                                                             ncell_RI, nimg, nkind, nthreads
    4689             :       INTEGER(int_8)                                     :: nze
    4690          42 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: bsizes_RI_ext, bsizes_RI_ext_split, dist_AO_1, &
    4691          84 :          dist_AO_2, dist_RI, dist_RI_ext, dummy_end, dummy_start, end_blocks, start_blocks
    4692             :       INTEGER, DIMENSION(3)                              :: pcoord, pdims
    4693          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_bsize, row_bsize
    4694             :       REAL(dp)                                           :: occ
    4695             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist
    4696             :       TYPE(dbcsr_type)                                   :: dbcsr_template
    4697          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_metric
    4698         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    4699         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4700         378 :       TYPE(dbt_type)                                     :: t_3c_template
    4701          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :, :)    :: t_3c_der_AO_prv, t_3c_der_RI_prv
    4702             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4703             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4704             :       TYPE(distribution_3d_type)                         :: dist_3d
    4705             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4706          42 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4707          42 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4708             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4709             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4710             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4711          42 :          POINTER                                         :: nl_2c
    4712          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4713          42 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4714             : 
    4715          42 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, particle_set, dft_control, para_env, row_bsize, col_bsize)
    4716             : 
    4717          42 :       CALL timeset(routineN, handle)
    4718             : 
    4719             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, natom=natom, &
    4720          42 :                       particle_set=particle_set, dft_control=dft_control, para_env=para_env)
    4721             : 
    4722          42 :       nimg = ri_data%nimg
    4723          42 :       ncell_RI = ri_data%ncell_RI
    4724             :       !Basis set lists for the RI and the orbital basis, one entry per kind
    4725         300 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4726          42 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4727          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_RI)
    4728          42 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4729          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_AO)
    4730             :       !Dealing with the 3c derivatives
    4731             :       !NOTE(review): omp_get_num_threads() is called outside of any parallel region below - confirm this is intended
    4732          42 :       nthreads = 1
    4733          42 : !$    nthreads = omp_get_num_threads()
    4734          42 :       pdims = 0
    4735         168 :       CALL dbt_pgrid_create(para_env, pdims, pgrid, tensor_dims=[MAX(1, natom/(ri_data%n_mem*nthreads)), natom, natom])
    4736             :       !The temporary tensor is only created to obtain the distribution vectors, it is destroyed right away
    4737             :       CALL create_3c_tensor(t_3c_template, dist_AO_1, dist_AO_2, dist_RI, pgrid, &
    4738             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4739          42 :                             map1=[1, 2], map2=[3], name="tmp")
    4740          42 :       CALL dbt_destroy(t_3c_template)
    4741             : 
    4742             :       !We stack the RI basis images. Keep consistent distribution
    4743          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    4744         126 :       ALLOCATE (dist_RI_ext(natom*ncell_RI))
    4745          84 :       ALLOCATE (bsizes_RI_ext(natom*ncell_RI))
    4746         126 :       ALLOCATE (bsizes_RI_ext_split(nblks_RI*ncell_RI))
    4747         294 :       DO i_RI = 1, ncell_RI
    4748         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    4749         756 :          dist_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = dist_RI(:)
    4750        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    4751             :       END DO
    4752             :       !Template (AO, AO, extended RI) for the private tensors that are handed to build_3c_derivatives
    4753          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist_AO_1, dist_AO_2, dist_RI_ext)
    4754             :       CALL dbt_create(t_3c_template, "KP_3c_der", t_dist, [1, 2], [3], &
    4755          42 :                       ri_data%bsizes_AO, ri_data%bsizes_AO, bsizes_RI_ext)
    4756          42 :       CALL dbt_distribution_destroy(t_dist)
    4757             : 
    4758        6972 :       ALLOCATE (t_3c_der_RI_prv(nimg, 1, 3), t_3c_der_AO_prv(nimg, 1, 3))
    4759         168 :       DO i_xyz = 1, 3
    4760        2982 :          DO i_img = 1, nimg
    4761        2814 :             CALL dbt_create(t_3c_template, t_3c_der_RI_prv(i_img, 1, i_xyz))
    4762        2940 :             CALL dbt_create(t_3c_template, t_3c_der_AO_prv(i_img, 1, i_xyz))
    4763             :          END DO
    4764             :       END DO
    4765          42 :       CALL dbt_destroy(t_3c_template)
    4766             : 
    4767          42 :       CALL dbt_mp_environ_pgrid(pgrid, pdims, pcoord)
    4768          42 :       CALL mp_comm_t3c%create(pgrid%mp_comm_2d, 3, pdims)
    4769             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4770          42 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4771          42 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4772          42 :       CALL dbt_pgrid_destroy(pgrid)
    4773             : 
    4774             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, ri_data%ri_metric, &
    4775          42 :                                    "HFX_3c_nl", qs_env, op_pos=2, sym_jk=.FALSE., own_dist=.TRUE.)
    4776             :       !Batches over the RI blocks: build_3c_derivatives is called once per batch through bounds_k
    4777          42 :       n_mem = ri_data%n_mem
    4778             :       CALL create_tensor_batches(ri_data%bsizes_RI, n_mem, dummy_start, dummy_end, &
    4779             :                                  start_blocks, end_blocks)
    4780          42 :       DEALLOCATE (dummy_start, dummy_end)
    4781             :       !Template of the output tensors: extended RI index first, split block sizes, on ri_data%pgrid_2
    4782             :       CALL create_3c_tensor(t_3c_template, dist_RI, dist_AO_1, dist_AO_2, ri_data%pgrid_2, &
    4783             :                             bsizes_RI_ext_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    4784          42 :                             map1=[1], map2=[2, 3], name="der (RI | AO AO)")
    4785         168 :       DO i_xyz = 1, 3
    4786        2982 :          DO i_img = 1, nimg
    4787        2814 :             CALL dbt_create(t_3c_template, t_3c_der_RI(i_img, i_xyz))
    4788        2940 :             CALL dbt_create(t_3c_template, t_3c_der_AO(i_img, i_xyz))
    4789             :          END DO
    4790             :       END DO
    4791             : 
    4792         116 :       DO i_mem = 1, n_mem
    4793             :          CALL build_3c_derivatives(t_3c_der_AO_prv, t_3c_der_RI_prv, ri_data%filter_eps, qs_env, &
    4794             :                                    nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, &
    4795             :                                    ri_data%ri_metric, der_eps=ri_data%eps_schwarz_forces, op_pos=2, &
    4796             :                                    do_kpoints=.TRUE., do_hfx_kpoints=.TRUE., &
    4797             :                                    bounds_k=[start_blocks(i_mem), end_blocks(i_mem)], &
    4798         222 :                                    RI_range=ri_data%kp_RI_range, img_to_RI_cell=ri_data%img_to_RI_cell)
    4799             : 
    4800          74 :          CALL timeset(routineN//"_cpy", handle2)
    4801             :          !We go from (mu^0 sigma^i | P^j) to (P^i| sigma^j mu^0) and finally to (P^i| mu^0 sigma^j)
    4802        1850 :          DO i_img = 1, nimg
    4803        7178 :             DO i_xyz = 1, 3
    4804             :                !derivative wrt mu^0 (empty private tensors are skipped, results of all batches are summed up)
    4805        5328 :                CALL get_tensor_occupancy(t_3c_der_AO_prv(i_img, 1, i_xyz), nze, occ)
    4806        5328 :                IF (nze > 0) THEN
    4807             :                   CALL dbt_copy(t_3c_der_AO_prv(i_img, 1, i_xyz), t_3c_template, &
    4808        3454 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4809        3454 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4810             :                   CALL dbt_copy(t_3c_template, t_3c_der_AO(i_img, i_xyz), &
    4811        3454 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4812             :                END IF
    4813             : 
    4814             :                !derivative wrt P^i (same treatment as above)
    4815        5328 :                CALL get_tensor_occupancy(t_3c_der_RI_prv(i_img, 1, i_xyz), nze, occ)
    4816       12432 :                IF (nze > 0) THEN
    4817             :                   CALL dbt_copy(t_3c_der_RI_prv(i_img, 1, i_xyz), t_3c_template, &
    4818        3452 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4819        3452 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4820             :                   CALL dbt_copy(t_3c_template, t_3c_der_RI(i_img, i_xyz), &
    4821        3452 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4822             :                END IF
    4823             :             END DO
    4824             :          END DO
    4825         190 :          CALL timestop(handle2)
    4826             :       END DO
    4827          42 :       CALL dbt_destroy(t_3c_template)
    4828             : 
    4829          42 :       CALL neighbor_list_3c_destroy(nl_3c)
    4830         168 :       DO i_xyz = 1, 3
    4831        2982 :          DO i_img = 1, nimg
    4832        2814 :             CALL dbt_destroy(t_3c_der_RI_prv(i_img, 1, i_xyz))
    4833        2940 :             CALL dbt_destroy(t_3c_der_AO_prv(i_img, 1, i_xyz))
    4834             :          END DO
    4835             :       END DO
    4836        5670 :       DEALLOCATE (t_3c_der_RI_prv, t_3c_der_AO_prv)
    4837             : 
    4838             :       !Reorder 3c derivatives to be consistent with ints
    4839          42 :       CALL reorder_3c_derivs(t_3c_der_RI, ri_data)
    4840          42 :       CALL reorder_3c_derivs(t_3c_der_AO, ri_data)
    4841             : 
    4842          42 :       CALL timeset(routineN//"_2c", handle2)
    4843             :       !The 2-center derivatives: RI x RI matrices without symmetry, with the block sizes of ri_data%bsizes_RI
    4844          42 :       CALL cp_dbcsr_dist2d_to_dist(dist_2d, dbcsr_dist)
    4845         126 :       ALLOCATE (row_bsize(SIZE(ri_data%bsizes_RI)))
    4846          84 :       ALLOCATE (col_bsize(SIZE(ri_data%bsizes_RI)))
    4847         126 :       row_bsize(:) = ri_data%bsizes_RI
    4848         126 :       col_bsize(:) = ri_data%bsizes_RI
    4849             : 
    4850             :       CALL dbcsr_create(dbcsr_template, "2c_der", dbcsr_dist, dbcsr_type_no_symmetry, &
    4851          42 :                         row_bsize, col_bsize)
    4852          42 :       CALL dbcsr_distribution_release(dbcsr_dist)
    4853          42 :       DEALLOCATE (col_bsize, row_bsize)
    4854             : 
    4855        3066 :       ALLOCATE (mat_der_metric(nimg, 3))
    4856         168 :       DO i_xyz = 1, 3
    4857        2982 :          DO i_img = 1, nimg
    4858        2814 :             CALL dbcsr_create(mat_der_pot(i_img, i_xyz), template=dbcsr_template)
    4859        2940 :             CALL dbcsr_create(mat_der_metric(i_img, i_xyz), template=dbcsr_template)
    4860             :          END DO
    4861             :       END DO
    4862          42 :       CALL dbcsr_release(dbcsr_template)
    4863             : 
    4864             :       !HFX potential derivatives
    4865             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4866          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4867             :       CALL build_2c_derivatives(mat_der_pot, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4868          42 :                                 basis_set_RI, basis_set_RI, ri_data%hfx_pot, do_kpoints=.TRUE.)
    4869          42 :       CALL release_neighbor_list_sets(nl_2c)
    4870             : 
    4871             :       !RI metric derivatives (NOTE(review): the list label below is the same as in the potential case - confirm)
    4872             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    4873          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4874             :       CALL build_2c_derivatives(mat_der_metric, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4875          42 :                                 basis_set_RI, basis_set_RI, ri_data%ri_metric, do_kpoints=.TRUE.)
    4876          42 :       CALL release_neighbor_list_sets(nl_2c)
    4877             : 
    4878             :       !Get into extended RI basis and tensor format; the dbcsr metric derivatives are released once converted
    4879         168 :       DO i_xyz = 1, 3
    4880         378 :          DO iatom = 1, natom
    4881         252 :             CALL dbt_create(ri_data%t_2c_inv(1, 1), t_2c_der_metric(iatom, i_xyz))
    4882             :             CALL get_ext_2c_int(t_2c_der_metric(iatom, i_xyz), mat_der_metric(:, i_xyz), &
    4883         378 :                                 iatom, iatom, 1, ri_data, qs_env)
    4884             :          END DO
    4885        2982 :          DO i_img = 1, nimg
    4886        2940 :             CALL dbcsr_release(mat_der_metric(i_img, i_xyz))
    4887             :          END DO
    4888             :       END DO
    4889          42 :       CALL timestop(handle2)
    4890             : 
    4891          42 :       CALL timestop(handle)
    4892             : 
    4893         252 :    END SUBROUTINE precalc_derivatives
    4894             : 
    4895             : ! **************************************************************************************************
    4896             : !> \brief Update the forces due to the derivative of a 2-center product d/dR (Q|R)
    4897             : !> \param force ...
    4898             : !> \param t_2c_contr A precontracted tensor containing sum_abcdPS (ab|P)(P|Q)^-1 (R|S)^-1 (S|cd) P_ac P_bd
    4899             : !> \param t_2c_der the d/dR (Q|R) tensor, in all 3 cartesian directions
    4900             : !> \param atom_of_kind ...
    4901             : !> \param kind_of ...
    4902             : !> \param img in which periodic image the second center of the tensor is
    4903             : !> \param pref ...
    4904             : !> \param ri_data ...
    4905             : !> \param qs_env ...
    4906             : !> \param work_virial ...
    4907             : !> \param cell ...
    4908             : !> \param particle_set ...
    4909             : !> \param diag ...
    4910             : !> \param offdiag ...
    4911             : !> \note IMPORTANT: t_tc_contr and t_2c_der need to have the same distribution. Atomic block sizes are
    4912             : !>                  assumed
    4913             : ! **************************************************************************************************
    4914        2905 :    SUBROUTINE get_2c_der_force(force, t_2c_contr, t_2c_der, atom_of_kind, kind_of, img, pref, &
    4915             :                                ri_data, qs_env, work_virial, cell, particle_set, diag, offdiag)
    4916             : 
    4917             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    4918             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_contr
    4919             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_der
    4920             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    4921             :       INTEGER, INTENT(IN)                                :: img
    4922             :       REAL(dp), INTENT(IN)                               :: pref
    4923             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4924             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4925             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    4926             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    4927             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    4928             :          POINTER                                         :: particle_set
    4929             :       LOGICAL, INTENT(IN), OPTIONAL                      :: diag, offdiag
    4930             : 
    4931             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_2c_der_force'
    4932             : 
    4933             :       INTEGER                                            :: handle, i_img, i_RI, i_xyz, iat, &
    4934             :                                                             iat_of_kind, ikind, j_img, j_RI, &
    4935             :                                                             j_xyz, jat, jat_of_kind, jkind, natom
    4936             :       INTEGER, DIMENSION(2)                              :: ind
    4937        2905 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    4938             :       LOGICAL                                            :: found, my_diag, my_offdiag, use_virial
    4939             :       REAL(dp)                                           :: new_force
    4940        2905 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :), TARGET     :: contr_blk, der_blk
    4941             :       REAL(dp), DIMENSION(3)                             :: scoord
    4942             :       TYPE(dbt_iterator_type)                            :: iter
    4943             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4944             : 
    4945        2905 :       NULLIFY (kpoints, index_to_cell)
    4946             : 
    4947             :       !Loop over the blocks of d/dR (Q|R), contract with the corresponding block of t_2c_contr and
    4948             :       !update the relevant force
    4949             : 
    4950        2905 :       CALL timeset(routineN, handle)
    4951             : 
    4952        2905 :       use_virial = .FALSE.
    4953        2905 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    4954             : 
    4955        2905 :       my_diag = .FALSE.
    4956        2905 :       IF (PRESENT(diag)) my_diag = diag
    4957             : 
    4958        2324 :       my_offdiag = .FALSE.
    4959        2324 :       IF (PRESENT(diag)) my_offdiag = offdiag
    4960             : 
    4961        2905 :       CALL get_qs_env(qs_env, kpoints=kpoints, natom=natom)
    4962        2905 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    4963             : 
    4964             : !$OMP PARALLEL DEFAULT(NONE) &
    4965             : !$OMP SHARED(t_2c_der,t_2c_contr,work_virial,force,use_virial,natom,index_to_cell,ri_data,img) &
    4966             : !$OMP SHARED(pref,atom_of_kind,kind_of,particle_set,cell,my_diag,my_offdiag) &
    4967             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk,contr_blk,found,new_force,i_RI,i_img,j_RI,j_img) &
    4968        2905 : !$OMP PRIVATE(iat,jat,iat_of_kind,jat_of_kind,ikind,jkind,scoord)
    4969             :       DO i_xyz = 1, 3
    4970             :          CALL dbt_iterator_start(iter, t_2c_der(i_xyz))
    4971             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4972             :             CALL dbt_iterator_next_block(iter, ind)
    4973             : 
    4974             :             !Only take forecs due to block diagonal or block off-diagonal, depending on arguments
    4975             :             IF ((my_diag .AND. .NOT. my_offdiag) .OR. (.NOT. my_diag .AND. my_offdiag)) THEN
    4976             :                IF (my_diag .AND. (ind(1) .NE. ind(2))) CYCLE
    4977             :                IF (my_offdiag .AND. (ind(1) == ind(2))) CYCLE
    4978             :             END IF
    4979             : 
    4980             :             CALL dbt_get_block(t_2c_der(i_xyz), ind, der_blk, found)
    4981             :             CPASSERT(found)
    4982             :             CALL dbt_get_block(t_2c_contr, ind, contr_blk, found)
    4983             : 
    4984             :             IF (found) THEN
    4985             : 
    4986             :                !an element of d/dR (Q|R) corresponds to 2 things because of translational invariance
    4987             :                !(Q'| R) = - (Q| R'), once wrt the center on Q, and once on R
    4988             :                new_force = pref*SUM(der_blk(:, :)*contr_blk(:, :))
    4989             : 
    4990             :                i_RI = (ind(1) - 1)/natom + 1
    4991             :                i_img = ri_data%RI_cell_to_img(i_RI)
    4992             :                iat = ind(1) - (i_RI - 1)*natom
    4993             :                iat_of_kind = atom_of_kind(iat)
    4994             :                ikind = kind_of(iat)
    4995             : 
    4996             :                j_RI = (ind(2) - 1)/natom + 1
    4997             :                j_img = ri_data%RI_cell_to_img(j_RI)
    4998             :                jat = ind(2) - (j_RI - 1)*natom
    4999             :                jat_of_kind = atom_of_kind(jat)
    5000             :                jkind = kind_of(jat)
    5001             : 
    5002             :                !Force on iatom (first center)
    5003             : !$OMP ATOMIC
    5004             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5005             :                                                           + new_force
    5006             : 
    5007             :                IF (use_virial) THEN
    5008             : 
    5009             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5010             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_img), dp)
    5011             : 
    5012             :                   DO j_xyz = 1, 3
    5013             : !$OMP ATOMIC
    5014             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5015             :                   END DO
    5016             :                END IF
    5017             : 
    5018             :                !Force on jatom (second center)
    5019             : !$OMP ATOMIC
    5020             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5021             :                                                           - new_force
    5022             : 
    5023             :                IF (use_virial) THEN
    5024             : 
    5025             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5026             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, j_img) + index_to_cell(:, img), dp)
    5027             : 
    5028             :                   DO j_xyz = 1, 3
    5029             : !$OMP ATOMIC
    5030             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) - new_force*scoord(j_xyz)
    5031             :                   END DO
    5032             :                END IF
    5033             : 
    5034             :                DEALLOCATE (contr_blk)
    5035             :             END IF
    5036             : 
    5037             :             DEALLOCATE (der_blk)
    5038             :          END DO !iter
    5039             :          CALL dbt_iterator_stop(iter)
    5040             : 
    5041             :       END DO !i_xyz
    5042             : !$OMP END PARALLEL
    5043        2905 :       CALL timestop(handle)
    5044             : 
    5045        5810 :    END SUBROUTINE get_2c_der_force
    5046             : 
    5047             : ! **************************************************************************************************
    5048             : !> \brief This routine calculates the force contribution from a trace over 3D tensors, i.e.
    5049             : !>        force = sum_ijk A_ijk B_ijk. The B tensor is (P^0| sigma^0 lambda^img), with P in the
    5050             : !>        extended RI basis. Note that all tensors are stacked along the 3rd dimension
    5051             : !> \param force the per-kind forces, of which the fock_4c component is updated
    5052             : !> \param t_3c_contr the A tensor, its blocks drive the loop and are contracted with the derivatives
    5053             : !> \param t_3c_der_1 derivative wrt the first center (RI), one tensor per cartesian direction
    5054             : !> \param t_3c_der_2 derivative wrt the second center (AO), one tensor per cartesian direction
    5055             : !> \param atom_of_kind for each atom, its index among the atoms of its own kind
    5056             : !> \param kind_of for each atom, the index of its kind
    5057             : !> \param idx_to_at_RI atom index of each RI block, once the block index is folded back into one cell
    5058             : !> \param idx_to_at_AO atom index of each AO block
    5059             : !> \param i_images image indices, entry lb_img-1+idx is the image of the 3rd center in stacked tensor idx
    5060             : !> \param lb_img position in i_images corresponding to the first stacked tensor
    5061             : !> \param pref prefactor multiplying each block contraction
    5062             : !> \param ri_data provides RI_cell_to_img and the split RI/AO block sizes (used for block counts)
    5063             : !> \param qs_env used to retrieve the kpoints index_to_cell mapping
    5064             : !> \param work_virial virial, only updated if work_virial, cell and particle_set are all present
    5065             : !> \param cell cell used to get scaled atomic coordinates (virial only)
    5066             : !> \param particle_set atomic positions (virial only)
    5067             : ! **************************************************************************************************
    5068        1525 :    SUBROUTINE get_force_from_3c_trace(force, t_3c_contr, t_3c_der_1, t_3c_der_2, atom_of_kind, kind_of, &
    5069        3050 :                                       idx_to_at_RI, idx_to_at_AO, i_images, lb_img, pref, &
    5070             :                                       ri_data, qs_env, work_virial, cell, particle_set)
    5071             : 
    5072             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    5073             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_contr
    5074             :       TYPE(dbt_type), DIMENSION(3), INTENT(INOUT)        :: t_3c_der_1, t_3c_der_2
    5075             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of, idx_to_at_RI, &
    5076             :                                                             idx_to_at_AO, i_images
    5077             :       INTEGER, INTENT(IN)                                :: lb_img
    5078             :       REAL(dp), INTENT(IN)                               :: pref
    5079             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    5080             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    5081             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    5082             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    5083             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    5084             :          POINTER                                         :: particle_set
    5085             : 
    5086             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_force_from_3c_trace'
    5087             : 
    5088             :       INTEGER :: handle, i_RI, i_xyz, iat, iat_of_kind, idx, ikind, j_xyz, jat, jat_of_kind, &
    5089             :          jkind, kat, kat_of_kind, kkind, nblks_AO, nblks_RI, RI_img
    5090             :       INTEGER, DIMENSION(3)                              :: ind
    5091        1525 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    5092             :       LOGICAL                                            :: found, found_1, found_2, use_virial
    5093             :       REAL(dp)                                           :: new_force
    5094        1525 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET  :: contr_blk, der_blk_1, der_blk_2, &
    5095        1525 :                                                             der_blk_3
    5096             :       REAL(dp), DIMENSION(3)                             :: scoord
    5097             :       TYPE(dbt_iterator_type)                            :: iter
    5098             :       TYPE(kpoint_type), POINTER                         :: kpoints
    5099             : 
    5100        1525 :       NULLIFY (kpoints, index_to_cell)
    5101             : 
    5102        1525 :       CALL timeset(routineN, handle)
    5103             : 
    5104        1525 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    5105        1525 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    5106             : 
    5107        1525 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    5108        1525 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    5109             : 
    5110        1525 :       use_virial = .FALSE.
    5111        1525 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    5112             : 
    5113             : !$OMP PARALLEL DEFAULT(NONE) &
    5114             : !$OMP SHARED(t_3c_der_1, t_3c_der_2,t_3c_contr,work_virial,force,use_virial,index_to_cell,i_images,lb_img) &
    5115             : !$OMP SHARED(pref,idx_to_at_AO,atom_of_kind,kind_of,particle_set,cell,idx_to_at_RI,ri_data,nblks_RI,nblks_AO) &
    5116             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk_1,contr_blk,found,new_force,iat,iat_of_kind,ikind,scoord) &
    5117        1525 : !$OMP PRIVATE(jat,kat,jat_of_kind,kat_of_kind,jkind,kkind,i_RI,RI_img,der_blk_2,der_blk_3,found_1,found_2,idx)
    5118             :       CALL dbt_iterator_start(iter, t_3c_contr)
    5119             :       DO WHILE (dbt_iterator_blocks_left(iter))
    5120             :          CALL dbt_iterator_next_block(iter, ind)
    5121             : 
    5122             :          CALL dbt_get_block(t_3c_contr, ind, contr_blk, found)
    5123             :          IF (found) THEN
    5124             :             !Derivative blocks that are not found are replaced by zero blocks with the shape of contr_blk
    5125             :             DO i_xyz = 1, 3
    5126             :                CALL dbt_get_block(t_3c_der_1(i_xyz), ind, der_blk_1, found_1)
    5127             :                IF (.NOT. found_1) THEN
    5128             :                   DEALLOCATE (der_blk_1)
    5129             :                   ALLOCATE (der_blk_1(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5130             :                   der_blk_1(:, :, :) = 0.0_dp
    5131             :                END IF
    5132             :                CALL dbt_get_block(t_3c_der_2(i_xyz), ind, der_blk_2, found_2)
    5133             :                IF (.NOT. found_2) THEN
    5134             :                   DEALLOCATE (der_blk_2)
    5135             :                   ALLOCATE (der_blk_2(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5136             :                   der_blk_2(:, :, :) = 0.0_dp
    5137             :                END IF
    5138             :                !The derivative wrt the third center is taken as minus the sum of the other two
    5139             :                ALLOCATE (der_blk_3(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5140             :                der_blk_3(:, :, :) = -(der_blk_1(:, :, :) + der_blk_2(:, :, :))
    5141             : 
    5142             :                !We assume the tensors are in the format (P^0| sigma^0 mu^a+c-b), with P a member of the
    5143             :                !extended RI basis set
    5144             : 
    5145             :                !Force for the first center (RI extended basis, zero cell)
    5146             :                new_force = pref*SUM(der_blk_1(:, :, :)*contr_blk(:, :, :))
    5147             : 
    5148             :                i_RI = (ind(1) - 1)/nblks_RI + 1
    5149             :                RI_img = ri_data%RI_cell_to_img(i_RI)
    5150             :                iat = idx_to_at_RI(ind(1) - (i_RI - 1)*nblks_RI)
    5151             :                iat_of_kind = atom_of_kind(iat)
    5152             :                ikind = kind_of(iat)
    5153             : 
    5154             : !$OMP ATOMIC
    5155             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5156             :                                                           + new_force
    5157             : 
    5158             :                IF (use_virial) THEN
    5159             : 
    5160             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5161             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, RI_img), dp)
    5162             : 
    5163             :                   DO j_xyz = 1, 3
    5164             : !$OMP ATOMIC
    5165             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5166             :                   END DO
    5167             :                END IF
    5168             : 
    5169             :                !Force with respect to the second center (AO basis, zero cell)
    5170             :                new_force = pref*SUM(der_blk_2(:, :, :)*contr_blk(:, :, :))
    5171             :                jat = idx_to_at_AO(ind(2))
    5172             :                jat_of_kind = atom_of_kind(jat)
    5173             :                jkind = kind_of(jat)
    5174             : 
    5175             : !$OMP ATOMIC
    5176             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5177             :                                                           + new_force
    5178             : 
    5179             :                IF (use_virial) THEN
    5180             : 
    5181             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5182             : 
    5183             :                   DO j_xyz = 1, 3
    5184             : !$OMP ATOMIC
    5185             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5186             :                   END DO
    5187             :                END IF
    5188             : 
    5189             :                !Force with respect to the third center (AO basis, apc_img - b_img)
    5190             :                !Note: tensors are stacked along the 3rd direction
    5191             :                new_force = pref*SUM(der_blk_3(:, :, :)*contr_blk(:, :, :))
    5192             :                idx = (ind(3) - 1)/nblks_AO + 1
    5193             :                kat = idx_to_at_AO(ind(3) - (idx - 1)*nblks_AO)
    5194             :                kat_of_kind = atom_of_kind(kat)
    5195             :                kkind = kind_of(kat)
    5196             : 
    5197             : !$OMP ATOMIC
    5198             :                force(kkind)%fock_4c(i_xyz, kat_of_kind) = force(kkind)%fock_4c(i_xyz, kat_of_kind) &
    5199             :                                                           + new_force
    5200             : 
    5201             :                IF (use_virial) THEN
    5202             :                   CALL real_to_scaled(scoord, pbc(particle_set(kat)%r, cell), cell)
    5203             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_images(lb_img - 1 + idx)), dp)
    5204             : 
    5205             :                   DO j_xyz = 1, 3
    5206             : !$OMP ATOMIC
    5207             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5208             :                   END DO
    5209             :                END IF
    5210             : 
    5211             :                DEALLOCATE (der_blk_1, der_blk_2, der_blk_3)
    5212             :             END DO !i_xyz
    5213             :             DEALLOCATE (contr_blk)
    5214             :          END IF !found
    5215             :       END DO !iter
    5216             :       CALL dbt_iterator_stop(iter)
    5217             : !$OMP END PARALLEL
    5218        1525 :       CALL timestop(handle)
    5219             : 
    5220        3050 :    END SUBROUTINE get_force_from_3c_trace
    5221             : 
    5222             : END MODULE

Generated by: LCOV version 1.15