LCOV - code coverage report
Current view: top level - src - hfx_ri_kp.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 2161 2191 98.6 %
Date: 2024-12-21 06:28:57 Functions: 38 40 95.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief RI-methods for HFX and K-points.
      10             : !> \author Augustin Bussy (01.2023)
      11             : ! **************************************************************************************************
      12             : 
      13             : MODULE hfx_ri_kp
      14             :    USE admm_types,                      ONLY: get_admm_env
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      16             :                                               get_atomic_kind_set
      17             :    USE basis_set_types,                 ONLY: get_gto_basis_set,&
      18             :                                               gto_basis_set_p_type
      19             :    USE bibliography,                    ONLY: Bussy2024,&
      20             :                                               cite_reference
      21             :    USE cell_types,                      ONLY: cell_type,&
      22             :                                               pbc,&
      23             :                                               real_to_scaled,&
      24             :                                               scaled_to_real
      25             :    USE cp_array_utils,                  ONLY: cp_1d_logical_p_type,&
      26             :                                               cp_2d_r_p_type,&
      27             :                                               cp_3d_r_p_type
      28             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_create,&
      29             :                                               cp_blacs_env_release,&
      30             :                                               cp_blacs_env_type
      31             :    USE cp_control_types,                ONLY: dft_control_type
      32             :    USE cp_dbcsr_api,                    ONLY: &
      33             :         dbcsr_add, dbcsr_clear, dbcsr_copy, dbcsr_create, dbcsr_distribution_get, &
      34             :         dbcsr_distribution_new, dbcsr_distribution_release, dbcsr_distribution_type, dbcsr_dot, &
      35             :         dbcsr_filter, dbcsr_finalize, dbcsr_get_block_p, dbcsr_get_info, &
      36             :         dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, dbcsr_iterator_start, &
      37             :         dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, dbcsr_put_block, dbcsr_release, &
      38             :         dbcsr_type, dbcsr_type_no_symmetry, dbcsr_type_symmetric
      39             :    USE cp_dbcsr_cholesky,               ONLY: cp_dbcsr_cholesky_decompose,&
      40             :                                               cp_dbcsr_cholesky_invert
      41             :    USE cp_dbcsr_cp2k_link,              ONLY: cp_dbcsr_alloc_block_from_nbl
      42             :    USE cp_dbcsr_diag,                   ONLY: cp_dbcsr_power
      43             :    USE cp_dbcsr_operations,             ONLY: cp_dbcsr_dist2d_to_dist
      44             :    USE dbt_api,                         ONLY: &
      45             :         dbt_batched_contract_finalize, dbt_batched_contract_init, dbt_clear, dbt_contract, &
      46             :         dbt_copy, dbt_copy_matrix_to_tensor, dbt_copy_tensor_to_matrix, dbt_create, dbt_destroy, &
      47             :         dbt_distribution_destroy, dbt_distribution_new, dbt_distribution_type, dbt_filter, &
      48             :         dbt_finalize, dbt_get_block, dbt_get_info, dbt_get_stored_coordinates, &
      49             :         dbt_iterator_blocks_left, dbt_iterator_next_block, dbt_iterator_start, dbt_iterator_stop, &
      50             :         dbt_iterator_type, dbt_mp_environ_pgrid, dbt_pgrid_create, dbt_pgrid_destroy, &
      51             :         dbt_pgrid_type, dbt_put_block, dbt_scale, dbt_type
      52             :    USE distribution_2d_types,           ONLY: distribution_2d_release,&
      53             :                                               distribution_2d_type
      54             :    USE hfx_ri,                          ONLY: get_idx_to_atom,&
      55             :                                               hfx_ri_pre_scf_calc_tensors
      56             :    USE hfx_types,                       ONLY: hfx_ri_type
      57             :    USE input_constants,                 ONLY: do_potential_short,&
      58             :                                               hfx_ri_do_2c_cholesky,&
      59             :                                               hfx_ri_do_2c_diag,&
      60             :                                               hfx_ri_do_2c_iter
      61             :    USE input_cp2k_hfx,                  ONLY: ri_pmat
      62             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      63             :                                               section_vals_type,&
      64             :                                               section_vals_val_get,&
      65             :                                               section_vals_val_set
      66             :    USE iterate_matrix,                  ONLY: invert_hotelling
      67             :    USE kinds,                           ONLY: dp,&
      68             :                                               int_8
      69             :    USE kpoint_types,                    ONLY: get_kpoint_info,&
      70             :                                               kpoint_type
      71             :    USE libint_2c_3c,                    ONLY: cutoff_screen_factor
      72             :    USE machine,                         ONLY: m_flush,&
      73             :                                               m_walltime
      74             :    USE mathlib,                         ONLY: erfc_cutoff
      75             :    USE message_passing,                 ONLY: mp_cart_type,&
      76             :                                               mp_para_env_type,&
      77             :                                               mp_request_type,&
      78             :                                               mp_waitall
      79             :    USE particle_methods,                ONLY: get_particle_set
      80             :    USE particle_types,                  ONLY: particle_type
      81             :    USE physcon,                         ONLY: angstrom
      82             :    USE qs_environment_types,            ONLY: get_qs_env,&
      83             :                                               qs_environment_type
      84             :    USE qs_force_types,                  ONLY: qs_force_type
      85             :    USE qs_integral_utils,               ONLY: basis_set_list_setup
      86             :    USE qs_interactions,                 ONLY: init_interaction_radii_orb_basis
      87             :    USE qs_kind_types,                   ONLY: qs_kind_type
      88             :    USE qs_neighbor_list_types,          ONLY: get_iterator_info,&
      89             :                                               neighbor_list_iterate,&
      90             :                                               neighbor_list_iterator_create,&
      91             :                                               neighbor_list_iterator_p_type,&
      92             :                                               neighbor_list_iterator_release,&
      93             :                                               neighbor_list_set_p_type,&
      94             :                                               release_neighbor_list_sets
      95             :    USE qs_scf_types,                    ONLY: qs_scf_env_type
      96             :    USE qs_tensors,                      ONLY: &
      97             :         build_2c_derivatives, build_2c_neighbor_lists, build_3c_derivatives, &
      98             :         build_3c_neighbor_lists, get_3c_iterator_info, get_tensor_occupancy, &
      99             :         neighbor_list_3c_destroy, neighbor_list_3c_iterate, neighbor_list_3c_iterator_create, &
     100             :         neighbor_list_3c_iterator_destroy
     101             :    USE qs_tensors_types,                ONLY: create_2c_tensor,&
     102             :                                               create_3c_tensor,&
     103             :                                               create_tensor_batches,&
     104             :                                               distribution_2d_create,&
     105             :                                               distribution_3d_create,&
     106             :                                               distribution_3d_type,&
     107             :                                               neighbor_list_3c_iterator_type,&
     108             :                                               neighbor_list_3c_type
     109             :    USE util,                            ONLY: get_limit
     110             :    USE virial_types,                    ONLY: virial_type
     111             : #include "./base/base_uses.f90"
     112             : 
     113             : !$ USE OMP_LIB, ONLY: omp_get_num_threads
     114             : 
     115             :    IMPLICIT NONE
     116             :    PRIVATE
     117             : 
     118             :    PUBLIC :: hfx_ri_update_ks_kp, hfx_ri_update_forces_kp
     119             : 
     120             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'hfx_ri_kp'
     121             : CONTAINS
     122             : 
     123             : ! NOTES: for a start, we do not seek performance, but accuracy. So in this first implementation,
     124             : !        we give little consideration to batching, load balance and such.
     125             : !        We also put everything here, even if there is some code replication with the original RI_HFX
     126             : !        We will only work in the RHO flavor
     127             : !        For now, we will also always assume that there is a single para_env, and that there is no
     128             : !        K-point subgroup. This might change in the future
     129             : 
     130             : ! **************************************************************************************************
     131             : !> \brief Initialize the ri_data for K-point. For now, we take the normal, usual existing ri_data
     132             : !>        and we adapt it to our needs
     133             : !> \param dbcsr_template DBCSR matrix used as template when creating ri_data%ks_t(1, 1)
     134             : !> \param ri_data RI data whose k-point tensors (3c, 2c, rho_ao_t, ks_t) and kp_cost get allocated
     135             : !> \param qs_env environment queried for dft_control, natom, para_env and nkind
     136             : ! **************************************************************************************************
     137          70 :    SUBROUTINE adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     138             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     139             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     140             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     141             : 
     142             :       INTEGER                                            :: i_img, i_RI, i_spin, iatom, natom, &
     143             :                                                             nblks_RI, nimg, nkind, nspins
     144          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, dist1, dist2, dist3
     145             :       TYPE(dft_control_type), POINTER                    :: dft_control
     146             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     147             : 
     148          70 :       NULLIFY (dft_control, para_env)
     149             : 
     150             :       !The main thing that we need to do is to allocate more space for the integrals, such that there
     151             :       !is room for each periodic image. Note that we only go in 1D, i.e. we store (mu^0 sigma^a|P^0),
     152             :       !and (P^0|Q^a) => the RI basis is always in the main cell.
     153             : 
     154             :       !Get kpoint info (NOTE(review): para_env and nkind are fetched here but not used below)
     155          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, para_env=para_env, nkind=nkind)
     156          70 :       nimg = ri_data%nimg
     157             : 
     158             :       !Along the RI direction we have basis elements spread across ncell_RI images.
     159          70 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
     160         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     161         506 :       DO i_RI = 1, ri_data%ncell_RI
     162        2344 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
     163             :       END DO
     164             : 
     165        4202 :       ALLOCATE (ri_data%t_3c_int_ctr_1(1, nimg)) !one 3c tensor per image
     166             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_1(1, 1), dist1, dist2, dist3, &
     167             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     168          70 :                             ri_data%bsizes_AO_split, [1, 2], [3], name="(AO RI | AO)")
     169             : 
     170        1716 :       DO i_img = 2, nimg
     171        1716 :          CALL dbt_create(ri_data%t_3c_int_ctr_1(1, 1), ri_data%t_3c_int_ctr_1(1, i_img))
     172             :       END DO
     173          70 :       DEALLOCATE (dist1, dist2, dist3)
     174             : 
     175         770 :       ALLOCATE (ri_data%t_3c_int_ctr_2(1, 1)) !a single 3c tensor, with the [1], [2, 3] index mapping
     176             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_2(1, 1), dist1, dist2, dist3, &
     177             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     178          70 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
     179          70 :       DEALLOCATE (dist1, dist2, dist3)
     180             : 
     181             :       !We use full block sizes (ri_data%bsizes_RI) for the 2c quantities, one tensor per atom
     182          70 :       DEALLOCATE (bsizes_RI_ext)
     183          70 :       nblks_RI = SIZE(ri_data%bsizes_RI)
     184         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     185         506 :       DO i_RI = 1, ri_data%ncell_RI
     186        1378 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
     187             :       END DO
     188             : 
     189        3010 :       ALLOCATE (ri_data%t_2c_inv(1, natom), ri_data%t_2c_int(1, natom), ri_data%t_2c_pot(1, natom))
     190             :       CALL create_2c_tensor(ri_data%t_2c_inv(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     191             :                             bsizes_RI_ext, bsizes_RI_ext, &
     192          70 :                             name="(RI | RI)")
     193          70 :       DEALLOCATE (dist1, dist2)
     194          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, 1))
     195          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, 1))
     196         140 :       DO iatom = 2, natom
     197          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_inv(1, iatom))
     198          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, iatom))
     199         140 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, iatom))
     200             :       END DO
     201             : 
     202         350 :       ALLOCATE (ri_data%kp_cost(natom, natom, nimg))
     203       12082 :       ri_data%kp_cost = 0.0_dp
     204             : 
     205             :       !We store the density and KS matrix in tensor format, one tensor per spin and image
     206          70 :       nspins = dft_control%nspins
     207        9218 :       ALLOCATE (ri_data%rho_ao_t(nspins, nimg), ri_data%ks_t(nspins, nimg))
     208             :       CALL create_2c_tensor(ri_data%rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     209             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     210          70 :                             name="(AO | AO)")
     211          70 :       DEALLOCATE (dist1, dist2)
     212             : 
     213          70 :       CALL dbt_create(dbcsr_template, ri_data%ks_t(1, 1))
     214             : 
     215          70 :       IF (nspins == 2) THEN
     216          24 :          CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(2, 1))
     217          24 :          CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(2, 1))
     218             :       END IF
     219        1716 :       DO i_img = 2, nimg
     220        3710 :          DO i_spin = 1, nspins
     221        1994 :             CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(i_spin, i_img))
     222        3640 :             CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(i_spin, i_img))
     223             :          END DO
     224             :       END DO
     225             : 
     226         210 :    END SUBROUTINE adapt_ri_data_to_kp
     227             : 
     228             : ! **************************************************************************************************
     229             : !> \brief The pre-scf steps for RI-HFX k-points calculation. Namely the calculation of the integrals
     230             : !> \param dbcsr_template DBCSR matrix forwarded to adapt_ri_data_to_kp
     231             : !> \param ri_data RI data; its KP specific content is first cleaned up, then rebuilt here
     232             : !> \param qs_env environment queried for dft_control, natom and nkind
     233             : ! **************************************************************************************************
     234          70 :    SUBROUTINE hfx_ri_pre_scf_kp(dbcsr_template, ri_data, qs_env)
     235             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     236             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     237             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     238             : 
     239             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'hfx_ri_pre_scf_kp'
     240             : 
     241             :       INTEGER                                            :: handle, i_img, iatom, natom, nimg, nkind
     242          70 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: t_2c_op_pot, t_2c_op_RI
     243          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: t_3c_int
     244             :       TYPE(dft_control_type), POINTER                    :: dft_control
     245             : 
     246          70 :       NULLIFY (dft_control)
     247             : 
     248          70 :       CALL timeset(routineN, handle)
     249             : 
     250          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, nkind=nkind)
     251             : 
     252          70 :       CALL cleanup_kp(ri_data)
     253             : 
     254             :       !We do all the checks on what we allow in this initial implementation
     255          70 :       IF (ri_data%flavor .NE. ri_pmat) CPABORT("K-points RI-HFX only with RHO flavor")
     256          70 :       IF (ri_data%same_op) ri_data%same_op = .FALSE. !force the full calculation with RI metric
     257          70 :       IF (ABS(ri_data%eps_pgf_orb - dft_control%qs_control%eps_pgf_orb) > 1.0E-16_dp) &
     258           0 :          CPABORT("RI%EPS_PGF_ORB and QS%EPS_PGF_ORB must be identical for RI-HFX k-points")
     259             : 
     260          70 :       CALL get_kp_and_ri_images(ri_data, qs_env)
     261          70 :       nimg = ri_data%nimg
     262             : 
     263             :       !Calculate the integrals (one 2c matrix per image for each operator, one 3c tensor per image)
     264        3712 :       ALLOCATE (t_2c_op_pot(nimg), t_2c_op_RI(nimg))
     265        4202 :       ALLOCATE (t_3c_int(1, nimg))
     266          70 :       CALL hfx_ri_pre_scf_calc_tensors(qs_env, ri_data, t_2c_op_RI, t_2c_op_pot, t_3c_int, do_kpoints=.TRUE.)
     267             : 
     268             :       !Make sure the internals have the k-point format
     269          70 :       CALL adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     270             : 
     271             :       !For each atom i, we calculate the inverse RI metric (P^0 | Q^0)^-1 without external bumping yet
     272             :       !Also store the off-diagonal integrals of the RI metric in case of forces, bumped from the left
     273         210 :       DO iatom = 1, natom
     274             :          CALL get_ext_2c_int(ri_data%t_2c_inv(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     275         140 :                              do_inverse=.TRUE.)
     276             :          !for the forces:
     277             :          !off-diagonal RI metric bumped from the left
     278             :          CALL get_ext_2c_int(ri_data%t_2c_int(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, &
     279         140 :                              qs_env, off_diagonal=.TRUE.)
     280         140 :          CALL apply_bump(ri_data%t_2c_int(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     281             : 
     282             :          !RI metric with bumped off-diagonal blocks (but not inverted), debumped from left and right
     283             :          CALL get_ext_2c_int(ri_data%t_2c_pot(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     284         140 :                              do_inverse=.TRUE., skip_inverse=.TRUE.)
     285             :          CALL apply_bump(ri_data%t_2c_pot(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., &
     286         210 :                          from_right=.TRUE., debump=.TRUE.)
     287             : 
     288             :       END DO
     289             : 
     290        1786 :       DO i_img = 1, nimg
     291        1786 :          CALL dbcsr_release(t_2c_op_RI(i_img))
     292             :       END DO
     293             : 
     294        3572 :       ALLOCATE (ri_data%kp_mat_2c_pot(1, nimg)) !keep a copy of the 2c potential matrix of each image
     295        1786 :       DO i_img = 1, nimg
     296        1716 :          CALL dbcsr_create(ri_data%kp_mat_2c_pot(1, i_img), template=t_2c_op_pot(i_img))
     297        1716 :          CALL dbcsr_copy(ri_data%kp_mat_2c_pot(1, i_img), t_2c_op_pot(i_img))
     298        1786 :          CALL dbcsr_release(t_2c_op_pot(i_img))
     299             :       END DO
     300             : 
     301             :       !Pre-contract all 3c integrals with the bumped inverse RI metric (P^0|Q^0)^-1,
     302             :       !and store in ri_data%t_3c_int_ctr_1
     303          70 :       CALL precontract_3c_ints(t_3c_int, ri_data, qs_env)
     304             : 
     305             :       !reorder the 3c integrals such that empty images are bunched up together
     306          70 :       CALL reorder_3c_ints(ri_data%t_3c_int_ctr_1(1, :), ri_data)
     307             : 
     308          70 :       CALL timestop(handle)
     309             : 
     310        1856 :    END SUBROUTINE hfx_ri_pre_scf_kp
     311             : 
     312             : ! **************************************************************************************************
     313             : !> \brief clean-up the KP specific data from ri_data (release/destroy content, then deallocate)
     314             : !> \param ri_data RI data whose allocated KP arrays, DBCSR matrices and tensors are freed
     315             : ! **************************************************************************************************
     316          70 :    SUBROUTINE cleanup_kp(ri_data)
     317             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     318             : 
     319             :       INTEGER                                            :: i, j
     320             : 
     321          70 :       IF (ALLOCATED(ri_data%kp_cost)) DEALLOCATE (ri_data%kp_cost) !these six are only deallocated
     322          70 :       IF (ALLOCATED(ri_data%idx_to_img)) DEALLOCATE (ri_data%idx_to_img)
     323          70 :       IF (ALLOCATED(ri_data%img_to_idx)) DEALLOCATE (ri_data%img_to_idx)
     324          70 :       IF (ALLOCATED(ri_data%present_images)) DEALLOCATE (ri_data%present_images)
     325          70 :       IF (ALLOCATED(ri_data%img_to_RI_cell)) DEALLOCATE (ri_data%img_to_RI_cell)
     326          70 :       IF (ALLOCATED(ri_data%RI_cell_to_img)) DEALLOCATE (ri_data%RI_cell_to_img)
     327             : 
     328          70 :       IF (ALLOCATED(ri_data%kp_mat_2c_pot)) THEN !DBCSR matrices: release each one, then deallocate
     329         540 :          DO j = 1, SIZE(ri_data%kp_mat_2c_pot, 2)
     330        1060 :             DO i = 1, SIZE(ri_data%kp_mat_2c_pot, 1)
     331        1040 :                CALL dbcsr_release(ri_data%kp_mat_2c_pot(i, j))
     332             :             END DO
     333             :          END DO
     334          20 :          DEALLOCATE (ri_data%kp_mat_2c_pot)
     335             :       END IF
     336             : 
     337          70 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN !tensors: destroy each one, then deallocate
     338         540 :          DO i = 1, SIZE(ri_data%kp_t_3c_int)
     339         540 :             CALL dbt_destroy(ri_data%kp_t_3c_int(i))
     340             :          END DO
     341         540 :          DEALLOCATE (ri_data%kp_t_3c_int)
     342             :       END IF
     343             : 
     344          70 :       IF (ALLOCATED(ri_data%t_2c_inv)) THEN !same destroy-then-deallocate pattern for all 2D arrays below
     345         160 :          DO j = 1, SIZE(ri_data%t_2c_inv, 2)
     346         250 :             DO i = 1, SIZE(ri_data%t_2c_inv, 1)
     347         180 :                CALL dbt_destroy(ri_data%t_2c_inv(i, j))
     348             :             END DO
     349             :          END DO
     350         160 :          DEALLOCATE (ri_data%t_2c_inv)
     351             :       END IF
     352             : 
     353          70 :       IF (ALLOCATED(ri_data%t_2c_int)) THEN
     354         160 :          DO j = 1, SIZE(ri_data%t_2c_int, 2)
     355         250 :             DO i = 1, SIZE(ri_data%t_2c_int, 1)
     356         180 :                CALL dbt_destroy(ri_data%t_2c_int(i, j))
     357             :             END DO
     358             :          END DO
     359         160 :          DEALLOCATE (ri_data%t_2c_int)
     360             :       END IF
     361             : 
     362          70 :       IF (ALLOCATED(ri_data%t_2c_pot)) THEN
     363         160 :          DO j = 1, SIZE(ri_data%t_2c_pot, 2)
     364         250 :             DO i = 1, SIZE(ri_data%t_2c_pot, 1)
     365         180 :                CALL dbt_destroy(ri_data%t_2c_pot(i, j))
     366             :             END DO
     367             :          END DO
     368         160 :          DEALLOCATE (ri_data%t_2c_pot)
     369             :       END IF
     370             : 
     371          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_1)) THEN
     372         640 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_1, 2)
     373        1210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_1, 1)
     374        1140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_1(i, j))
     375             :             END DO
     376             :          END DO
     377         640 :          DEALLOCATE (ri_data%t_3c_int_ctr_1)
     378             :       END IF
     379             : 
     380          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_2)) THEN
     381         140 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_2, 2)
     382         210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_2, 1)
     383         140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_2(i, j))
     384             :             END DO
     385             :          END DO
     386         140 :          DEALLOCATE (ri_data%t_3c_int_ctr_2)
     387             :       END IF
     388             : 
     389          70 :       IF (ALLOCATED(ri_data%rho_ao_t)) THEN
     390         640 :          DO j = 1, SIZE(ri_data%rho_ao_t, 2)
     391        1428 :             DO i = 1, SIZE(ri_data%rho_ao_t, 1)
     392        1358 :                CALL dbt_destroy(ri_data%rho_ao_t(i, j))
     393             :             END DO
     394             :          END DO
     395         858 :          DEALLOCATE (ri_data%rho_ao_t)
     396             :       END IF
     397             : 
     398          70 :       IF (ALLOCATED(ri_data%ks_t)) THEN
     399         640 :          DO j = 1, SIZE(ri_data%ks_t, 2)
     400        1428 :             DO i = 1, SIZE(ri_data%ks_t, 1)
     401        1358 :                CALL dbt_destroy(ri_data%ks_t(i, j))
     402             :             END DO
     403             :          END DO
     404         858 :          DEALLOCATE (ri_data%ks_t)
     405             :       END IF
     406             : 
     407          70 :    END SUBROUTINE cleanup_kp
     408             : 
     409             : ! **************************************************************************************************
     410             : !> \brief Update the KS matrices for each real-space image and compute the RI-HFX energy
     411             : !> \param qs_env QS environment; source of para_env, natom and the DFT%XC%HF%RI input section
     412             : !> \param ri_data RI-HFX data; its stored tensors (rho_ao_t, ks_t, kp_t_3c_int), kp_cost and counters are updated
     413             : !> \param ks_matrix KS matrices (spin, image) to which the HFX contribution of each image is added
     414             : !> \param ehfx on exit, 0.5*dbcsr_dot(KS, P) summed over all images and spins
     415             : !> \param rho_ao AO density matrices, passed to get_pmat_images to build the density tensors
     416             : !> \param geometry_did_change if .TRUE., hfx_ri_pre_scf_kp is called before the KS update
     417             : !> \param nspins number of spin channels looped over (also selects the prefactor, see hf_fraction)
     418             : !> \param hf_fraction scaling of the HFX contribution: 0.5*hf_fraction if nspins == 1, hf_fraction otherwise
     419             : ! **************************************************************************************************
     420         190 :    SUBROUTINE hfx_ri_update_ks_kp(qs_env, ri_data, ks_matrix, ehfx, rho_ao, &
     421             :                                   geometry_did_change, nspins, hf_fraction)
     422             : 
     423             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     424             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     425             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     426             :       REAL(KIND=dp), INTENT(OUT)                         :: ehfx
     427             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     428             :       LOGICAL, INTENT(IN)                                :: geometry_did_change
     429             :       INTEGER, INTENT(IN)                                :: nspins
     430             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     431             : 
     432             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_ks_kp'
     433             : 
     434             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_spin, iatom, &
     435             :          iblk, igroup, jatom, mb_img, n_batch_nze, natom, ngroups, nimg, nimg_nze
     436             :       INTEGER(int_8)                                     :: nflop, nze
     437         190 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_at, batch_ranges_nze, &
     438         190 :                                                             idx_to_at_AO
     439         190 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     440         190 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: sparsity_pattern
     441             :       LOGICAL                                            :: use_delta_p
     442             :       REAL(dp)                                           :: etmp, fac, occ, pfac, pref, t1, t2, t3, &
     443             :                                                             t4
     444             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     445             :       TYPE(dbcsr_type)                                   :: ks_desymm, rho_desymm, tmp
     446         190 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     447             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     448         190 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ks_t_split, t_2c_ao_tmp, t_2c_work, &
     449         190 :                                                             t_3c_int, t_3c_work_2, t_3c_work_3
     450         190 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: ks_t, ks_t_sub, t_3c_apc, t_3c_apc_sub
     451             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     452             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     453             : 
     454         190 :       NULLIFY (para_env, para_env_sub, blacs_env_sub, hfx_section, dbcsr_template)
     455             : 
     456         190 :       CALL cite_reference(Bussy2024)
     457             : 
     458         190 :       CALL timeset(routineN, handle)
     459             : 
     460         190 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natom)
     461             : 
     462         190 :       IF (nspins == 1) THEN
     463         112 :          fac = 0.5_dp*hf_fraction
     464             :       ELSE
     465          78 :          fac = 1.0_dp*hf_fraction
     466             :       END IF
     467             : 
     468         190 :       IF (geometry_did_change) THEN
     469          70 :          CALL hfx_ri_pre_scf_kp(ks_matrix(1, 1)%matrix, ri_data, qs_env)
     470             :       END IF
     471         190 :       nimg = ri_data%nimg
     472         190 :       nimg_nze = ri_data%nimg_nze
     473             : 
     474             :       !We need to calculate the KS matrix for each periodic cell with index b: F_mu^0,nu^b
     475             :       !F_mu^0,nu^b = -0.5 sum_a,c P_sigma^0,lambda^c (mu^0, sigma^a| P^0) V_P^0,Q^b (Q^b| nu^b lambda^a+c)
     476             :       !with V_P^0,Q^b = (P^0|R^0)^-1 * (R^0|S^b) * (S^b|Q^b)^-1
     477             : 
     478             :       !We use a local RI basis set for each atom in the system, which includes RI basis elements for
     479             :       !each neighboring atom standing within the KIND radius (decay of Gaussian with smallest exponent)
     480             : 
     481             :       !We also limit the number of periodic images we consider according to the HFX potential in the
     482             :       !RI basis, because if V_P^0,Q^b is zero everywhere, then image b can be ignored (RI basis less diffuse)
     483             : 
     484             :       !We manage to calculate each KS matrix doing a double loop on images, and a double loop on atoms
     485             :       !First, we pre-contract and store P_sigma^0,lambda^c (mu^0, sigma^a| P^0) (P^0|R^0)^-1 into T_mu^0,lambda^a+c,P^0
     486             :       !Then, we loop over b_img, iatom, jatom to get (R^0|S^b)
     487             :       !Finally, we do an additional loop over a+c images where we do (R^0|S^b) (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c)
     488             :       !and the final contraction with T_mu^0,lambda^a+c,P^0
     489             : 
     490             :       !Note that the 3-center integrals are pre-contracted with the RI metric, and that the same tensor can be used
     491             :       !(mu^0, sigma^a| P^0) (P^0|R^0)  <===> (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c) by relabelling the images
     492             : 
     493         190 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     494         190 :       CALL section_vals_val_get(hfx_section, "KP_USE_DELTA_P", l_val=use_delta_p)
     495             : 
     496             :       !By default, build the density tensor based on the difference of this SCF P and that of the prev. SCF
     497         190 :       pfac = -1.0_dp
     498         190 :       IF (.NOT. use_delta_p) pfac = 0.0_dp
     499         190 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, pfac, ri_data, qs_env)
     500             : 
     501       13412 :       ALLOCATE (ks_t(nspins, nimg))
     502        4994 :       DO i_img = 1, nimg
     503       11322 :          DO i_spin = 1, nspins
     504       11132 :             CALL dbt_create(ri_data%ks_t(1, 1), ks_t(i_spin, i_img))
     505             :          END DO
     506             :       END DO
     507             : 
     508         570 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     509         190 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     510             : 
     511             :       !First we calculate and store T^1_mu^0,lambda^a+c,P = P_mu^0,lambda^c * (mu_0 sigma^a | P^0) (P^0|R^0)^-1
     512             :       !To avoid doing nimg**2 tiny contractions that do not scale well with a large number of CPUs,
     513             :       !we instead do a single loop over the a+c image index. For each a+c, we get a list of allowed
     514             :       !combination of a,c indices. Then we build TAS tensors P_mu^0,lambda^c with all concerned c's
     515             :       !and (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 with all a's. Then we perform a single contraction with larger tensors,
     516             :       !where the sum over a,c is automatically taken care of
     517       13222 :       ALLOCATE (t_3c_apc(nspins, nimg))
     518        4994 :       DO i_img = 1, nimg
     519       11322 :          DO i_spin = 1, nspins
     520       11132 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     521             :          END DO
     522             :       END DO
     523         190 :       CALL contract_pmat_3c(t_3c_apc, ri_data%rho_ao_t, ri_data, qs_env)
     524             : 
     525         190 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     526         190 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     527         190 :       CALL section_vals_val_get(hfx_section, "KP_STACK_SIZE", i_val=batch_size)
     528         190 :       ri_data%kp_stack_size = batch_size
     529             : 
     530         190 :       IF (MOD(para_env%num_pe, ngroups) .NE. 0) THEN
     531           0 :          CPWARN("KP_NGROUPS must be an integer divisor of the total number of MPI ranks. It was set to 1.")
     532           0 :          ngroups = 1
     533           0 :          CALL section_vals_val_set(hfx_section, "KP_NGROUPS", i_val=ngroups)
     534             :       END IF
     535         190 :       IF ((MOD(ngroups, natom) .NE. 0) .AND. (MOD(natom, ngroups) .NE. 0) .AND. geometry_did_change) THEN
     536           0 :          IF (ngroups > 1) THEN
     537           0 :             CPWARN("Better load balancing is reached if NGROUPS is a multiple/divisor of the number of atoms")
     538             :          END IF
     539             :       END IF
     540         190 :       group_size = para_env%num_pe/ngroups
     541         190 :       igroup = para_env%mepos/group_size
     542             : 
     543         190 :       ALLOCATE (para_env_sub)
     544         190 :       CALL para_env_sub%from_split(para_env, igroup)
     545         190 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     546             : 
     547             :       ! The sparsity pattern of each iatom, jatom pair, on each b_img, and on which subgroup
     548         950 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     549         190 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     550         190 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     551             : 
     552             :       !Get all the required tensors in the subgroups
     553       24106 :       ALLOCATE (mat_2c_pot(nimg), ks_t_sub(nspins, nimg), t_2c_ao_tmp(1), ks_t_split(2), t_2c_work(3))
     554             :       CALL get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
     555         190 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     556             : 
     557       24296 :       ALLOCATE (t_3c_int(nimg), t_3c_apc_sub(nspins, nimg), t_3c_work_2(3), t_3c_work_3(3))
     558             :       CALL get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
     559         190 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     560             : 
     561             :       !We go atom by atom, therefore there is an automatic batching along that direction
     562             :       !Also, because we stack the 3c tensors nimg times, we naturally do some batching there too
     563         570 :       ALLOCATE (batch_ranges_at(natom + 1))
     564         190 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     565         190 :       iatom = 0
     566         928 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     567         928 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     568         380 :             iatom = iatom + 1
     569         380 :             batch_ranges_at(iatom) = iblk
     570             :          END IF
     571             :       END DO
     572             : 
     573         190 :       n_batch_nze = nimg_nze/batch_size
     574         190 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
     575         570 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
     576         404 :       DO i_batch = 1, n_batch_nze
     577         404 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     578             :       END DO
     579         190 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
     580             : 
     581         190 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     582         190 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     583         190 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     584         190 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     585             : 
     586         190 :       t1 = m_walltime()
     587       33818 :       ri_data%kp_cost(:, :, :) = 0.0_dp
     588         570 :       ALLOCATE (iapc_pairs(nimg, 2))
     589        4994 :       DO b_img = 1, nimg
     590        4804 :          CALL dbt_batched_contract_init(ks_t_split(1))
     591        4804 :          CALL dbt_batched_contract_init(ks_t_split(2))
     592       14412 :          DO jatom = 1, natom
     593       33628 :             DO iatom = 1, natom
     594       19216 :                IF (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup) CYCLE
     595        3232 :                pref = 1.0_dp
     596        3232 :                IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp
     597             : 
     598             :                !measure the cost of the given i, j, b configuration
     599        3232 :                t3 = m_walltime()
     600             : 
     601             :                !Get the proper HFX potential 2c integrals (R_i^0|S_j^b)
     602        3232 :                CALL timeset(routineN//"_2c", handle2)
     603             :                CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     604             :                                    blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     605        3232 :                                    dbcsr_template=dbcsr_template)
     606        3232 :                CALL dbt_copy(t_2c_work(1), t_2c_work(2), move_data=.TRUE.) !move to split blocks
     607        3232 :                CALL dbt_filter(t_2c_work(2), ri_data%filter_eps)
     608        3232 :                CALL timestop(handle2)
     609             : 
     610        3232 :                CALL dbt_batched_contract_init(t_2c_work(2))
     611        3232 :                CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env)
     612        3232 :                CALL timeset(routineN//"_3c", handle2)
     613             : 
     614             :                !Stack the (S^b|Q^b)^-1 * (Q^b| nu^b lambda^a+c) integrals over a+c and multiply by (R_i^0|S_j^b)
     615        7331 :                DO i_batch = 1, n_batch_nze
     616             :                   CALL fill_3c_stack(t_3c_work_3(3), t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
     617             :                                      filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
     618       12297 :                                      img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     619        4099 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1), move_data=.TRUE.)
     620             : 
     621             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_3c_work_3(1), &
     622             :                                     0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
     623             :                                     contract_1=[2], notcontract_1=[1], &
     624             :                                     contract_2=[1], notcontract_2=[2, 3], &
     625        4099 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     626        4099 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     627        4099 :                   CALL dbt_copy(t_3c_work_3(2), t_3c_work_2(2), order=[2, 1, 3], move_data=.TRUE.)
     628        4099 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1)) !NOTE(review): (3) was moved out above; presumably resets (1) - confirm
     629             : 
     630             :                   !Stack the P_sigma^a,lambda^a+c * (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 integrals over a+c and contract
     631             :                   !to get the final block of the KS matrix
     632       12586 :                   DO i_spin = 1, nspins
     633             :                      CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
     634             :                                         ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
     635       15765 :                                         img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     636        5255 :                      CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
     637        5255 :                      IF (nze == 0) CYCLE
     638        5235 :                      CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
     639             :                      CALL dbt_contract(-pref*fac, t_3c_work_2(1), t_3c_work_2(2), &
     640             :                                        1.0_dp, ks_t_split(i_spin), map_1=[1], map_2=[2], &
     641             :                                        contract_1=[2, 3], notcontract_1=[1], &
     642             :                                        contract_2=[2, 3], notcontract_2=[1], &
     643             :                                        filter_eps=ri_data%filter_eps, &
     644        5235 :                                        move_data=i_spin == nspins, flop=nflop)
     645       14589 :                      ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     646             :                   END DO
     647             :                END DO !i_batch
     648        3232 :                CALL timestop(handle2)
     649        3232 :                CALL dbt_batched_contract_finalize(t_2c_work(2))
     650             : 
     651        3232 :                t4 = m_walltime()
     652       35288 :                ri_data%kp_cost(iatom, jatom, b_img) = t4 - t3
     653             :             END DO !iatom
     654             :          END DO !jatom
     655        4804 :          CALL dbt_batched_contract_finalize(ks_t_split(1))
     656        4804 :          CALL dbt_batched_contract_finalize(ks_t_split(2))
     657             : 
     658       11322 :          DO i_spin = 1, nspins
     659        6328 :             CALL dbt_copy(ks_t_split(i_spin), t_2c_ao_tmp(1), move_data=.TRUE.)
     660       11132 :             CALL dbt_copy(t_2c_ao_tmp(1), ks_t_sub(i_spin, b_img), summation=.TRUE.)
     661             :          END DO
     662             :       END DO !b_img
     663         190 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
     664         190 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
     665         190 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
     666         190 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
     667         190 :       CALL para_env%sync()
     668         190 :       CALL para_env%sum(ri_data%dbcsr_nflop)
     669         190 :       CALL para_env%sum(ri_data%kp_cost)
     670         190 :       t2 = m_walltime()
     671         190 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
     672             : 
     673             :       !transfer KS tensor from subgroup to main group
     674         190 :       CALL gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
     675             : 
     676             :       !Keep the 3c integrals on the subgroups to avoid communication at next SCF step
     677        4994 :       DO i_img = 1, nimg
     678        4994 :          CALL dbt_copy(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img), move_data=.TRUE.)
     679             :       END DO
     680             : 
     681             :       !clean-up subgroup tensors
     682         190 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     683         190 :       CALL dbt_destroy(ks_t_split(1))
     684         190 :       CALL dbt_destroy(ks_t_split(2))
     685         190 :       CALL dbt_destroy(t_2c_work(1))
     686         190 :       CALL dbt_destroy(t_2c_work(2))
     687         190 :       CALL dbt_destroy(t_3c_work_2(1))
     688         190 :       CALL dbt_destroy(t_3c_work_2(2))
     689         190 :       CALL dbt_destroy(t_3c_work_2(3))
     690         190 :       CALL dbt_destroy(t_3c_work_3(1))
     691         190 :       CALL dbt_destroy(t_3c_work_3(2))
     692         190 :       CALL dbt_destroy(t_3c_work_3(3))
     693        4994 :       DO i_img = 1, nimg
     694        4804 :          CALL dbt_destroy(t_3c_int(i_img))
     695        4804 :          CALL dbcsr_release(mat_2c_pot(i_img))
     696       11322 :          DO i_spin = 1, nspins
     697        6328 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
     698       11132 :             CALL dbt_destroy(ks_t_sub(i_spin, i_img))
     699             :          END DO
     700             :       END DO
     701         190 :       IF (ASSOCIATED(dbcsr_template)) THEN
     702         190 :          CALL dbcsr_release(dbcsr_template)
     703         190 :          DEALLOCATE (dbcsr_template)
     704             :       END IF
     705             : 
     706             :       !End of subgroup parallelization
     707         190 :       CALL cp_blacs_env_release(blacs_env_sub)
     708         190 :       CALL para_env_sub%free()
     709         190 :       DEALLOCATE (para_env_sub)
     710             : 
     711             :       !Currently, rho_ao_t holds the density difference wrt the prev. SCF step (when KP_USE_DELTA_P is set).
     712             :       !ks_t also holds that diff, while only having half the blocks => need to add to prev ks_t and symmetrize
     713             :       !We need the full thing for the energy, on the next SCF step
     714         190 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     715         458 :       DO i_spin = 1, nspins
     716        6786 :          DO b_img = 1, nimg
     717        6328 :             CALL dbt_copy(ks_t(i_spin, b_img), ri_data%ks_t(i_spin, b_img), summation=.TRUE.)
     718             : 
     719             :             !desymmetrize
     720        6328 :             mb_img = get_opp_index(b_img, qs_env)
     721        6596 :             IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
     722        5708 :                CALL dbt_copy(ks_t(i_spin, mb_img), ri_data%ks_t(i_spin, b_img), order=[2, 1], summation=.TRUE.)
     723             :             END IF
     724             :          END DO
     725             :       END DO
     726        4994 :       DO b_img = 1, nimg
     727       11322 :          DO i_spin = 1, nspins
     728       11132 :             CALL dbt_destroy(ks_t(i_spin, b_img))
     729             :          END DO
     730             :       END DO
     731             : 
     732             :       !calculate the energy
     733         190 :       CALL dbt_create(ri_data%ks_t(1, 1), t_2c_ao_tmp(1))
     734         190 :       CALL dbcsr_create(tmp, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_symmetric)
     735         190 :       CALL dbcsr_create(ks_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     736         190 :       CALL dbcsr_create(rho_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     737         190 :       ehfx = 0.0_dp
     738        4994 :       DO i_img = 1, nimg
     739       11322 :          DO i_spin = 1, nspins
     740        6328 :             CALL dbt_filter(ri_data%ks_t(i_spin, i_img), ri_data%filter_eps)
     741        6328 :             CALL dbt_copy(ri_data%ks_t(i_spin, i_img), t_2c_ao_tmp(1))
     742        6328 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), ks_desymm)
     743        6328 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), tmp)
     744        6328 :             CALL dbcsr_add(ks_matrix(i_spin, i_img)%matrix, tmp, 1.0_dp, 1.0_dp)
     745             : 
     746        6328 :             CALL dbt_copy(ri_data%rho_ao_t(i_spin, i_img), t_2c_ao_tmp(1))
     747        6328 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), rho_desymm)
     748             : 
     749        6328 :             CALL dbcsr_dot(ks_desymm, rho_desymm, etmp)
     750        6328 :             ehfx = ehfx + 0.5_dp*etmp
     751             : 
     752       11132 :             IF (.NOT. use_delta_p) CALL dbt_clear(ri_data%ks_t(i_spin, i_img))
     753             :          END DO
     754             :       END DO
     755         190 :       CALL dbcsr_release(rho_desymm)
     756         190 :       CALL dbcsr_release(ks_desymm)
     757         190 :       CALL dbcsr_release(tmp)
     758         190 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     759             : 
     760         190 :       CALL timestop(handle)
     761             : 
     762       33346 :    END SUBROUTINE hfx_ri_update_ks_kp
     763             : 
     764             : ! **************************************************************************************************
     765             : !> \brief Update the K-points RI-HFX forces
     766             : !> \param qs_env ...
     767             : !> \param ri_data ...
     768             : !> \param nspins ...
     769             : !> \param hf_fraction ...
     770             : !> \param rho_ao ...
     771             : !> \param use_virial ...
     772             : !> \note Because this routine uses stored quantities calculated in the energy calculation, they should
     773             : !>       always be called by pairs, and with the same input densities
     774             : ! **************************************************************************************************
     775          42 :    SUBROUTINE hfx_ri_update_forces_kp(qs_env, ri_data, nspins, hf_fraction, rho_ao, use_virial)
     776             : 
     777             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     778             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     779             :       INTEGER, INTENT(IN)                                :: nspins
     780             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     781             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     782             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_virial
     783             : 
     784             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_forces_kp'
     785             : 
     786             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_loop, i_spin, &
     787             :          i_xyz, iatom, iblk, igroup, j_xyz, jatom, k_xyz, n_batch, natom, ngroups, nimg, nimg_nze
     788             :       INTEGER(int_8)                                     :: nflop, nze
     789          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atom_of_kind, batch_ranges_at, &
     790          42 :                                                             batch_ranges_nze, dist1, dist2, &
     791          42 :                                                             i_images, idx_to_at_AO, idx_to_at_RI, &
     792          42 :                                                             kind_of
     793          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     794          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: force_pattern, sparsity_pattern
     795             :       INTEGER, DIMENSION(2, 1)                           :: bounds_iat, bounds_jat
     796             :       LOGICAL                                            :: use_virial_prv
     797             :       REAL(dp)                                           :: fac, occ, pref, t1, t2
     798             :       REAL(dp), DIMENSION(3, 3)                          :: work_virial
     799          42 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     800             :       TYPE(cell_type), POINTER                           :: cell
     801             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     802          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     803          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_pot, mat_der_pot_sub
     804             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     805         714 :       TYPE(dbt_type)                                     :: t_2c_R, t_2c_R_split
     806          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_2c_bint, t_2c_binv, t_2c_der_pot, &
     807          84 :                                                             t_2c_inv, t_2c_metric, t_2c_work, &
     808          42 :                                                             t_3c_der_stack, t_3c_work_2, &
     809          42 :                                                             t_3c_work_3
     810          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :) :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
     811          84 :          t_2c_der_metric_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_AO, t_3c_der_AO_sub, t_3c_der_RI, &
     812          42 :          t_3c_der_RI_sub
     813             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     814          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     815          42 :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
     816             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     817             :       TYPE(virial_type), POINTER                         :: virial
     818             : 
     819          42 :       NULLIFY (para_env, para_env_sub, hfx_section, blacs_env_sub, dbcsr_template, force, atomic_kind_set, &
     820          42 :                virial, particle_set, cell)
     821             : 
     822          42 :       CALL timeset(routineN, handle)
     823             : 
     824          42 :       use_virial_prv = .FALSE.
     825          42 :       IF (PRESENT(use_virial)) use_virial_prv = use_virial
     826             : 
     827          42 :       IF (nspins == 1) THEN
     828          28 :          fac = 0.5_dp*hf_fraction
     829             :       ELSE
     830          14 :          fac = 1.0_dp*hf_fraction
     831             :       END IF
     832             : 
     833             :       CALL get_qs_env(qs_env, natom=natom, para_env=para_env, force=force, cell=cell, virial=virial, &
     834          42 :                       atomic_kind_set=atomic_kind_set, particle_set=particle_set)
     835          42 :       CALL get_atomic_kind_set(atomic_kind_set, kind_of=kind_of, atom_of_kind=atom_of_kind)
     836             : 
     837         126 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     838          42 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     839             : 
     840         126 :       ALLOCATE (idx_to_at_RI(SIZE(ri_data%bsizes_RI_split)))
     841          42 :       CALL get_idx_to_atom(idx_to_at_RI, ri_data%bsizes_RI_split, ri_data%bsizes_RI)
     842             : 
     843          42 :       nimg = ri_data%nimg
     844       11190 :       ALLOCATE (t_3c_der_RI(nimg, 3), t_3c_der_AO(nimg, 3), mat_der_pot(nimg, 3), t_2c_der_metric(natom, 3))
     845             : 
     846             :       !We assume that the integrals are available from the SCF
     847             :       !pre-calculate the derivs. 3c tensors as (P^0| sigma^a mu^0), with t_3c_der_AO holding deriv wrt mu^0
     848          42 :       CALL precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
     849             : 
     850             :       !Calculate the density matrix at each image
     851        2690 :       ALLOCATE (rho_ao_t(nspins, nimg))
     852             :       CALL create_2c_tensor(rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     853             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     854          42 :                             name="(AO | AO)")
     855          42 :       DEALLOCATE (dist1, dist2)
     856          42 :       IF (nspins == 2) CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(2, 1))
     857        1010 :       DO i_img = 2, nimg
     858        2130 :          DO i_spin = 1, nspins
     859        2088 :             CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(i_spin, i_img))
     860             :          END DO
     861             :       END DO
     862          42 :       CALL get_pmat_images(rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     863             : 
     864             :       !Contract integrals with the density matrix
     865        2690 :       ALLOCATE (t_3c_apc(nspins, nimg))
     866        1052 :       DO i_img = 1, nimg
     867        2228 :          DO i_spin = 1, nspins
     868        2186 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     869             :          END DO
     870             :       END DO
     871          42 :       CALL contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
     872             : 
     873             :       !Setup the subgroups
     874          42 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     875          42 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     876          42 :       group_size = para_env%num_pe/ngroups
     877          42 :       igroup = para_env%mepos/group_size
     878             : 
     879          42 :       ALLOCATE (para_env_sub)
     880          42 :       CALL para_env_sub%from_split(para_env, igroup)
     881          42 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     882             : 
     883             :       !Get the usual sparsity pattern
     884         210 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     885          42 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     886          42 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     887             : 
     888             :       !Get the 2-center quantities in the subgroups (note: main group derivs are deleted within)
     889           0 :       ALLOCATE (t_2c_inv(natom), mat_2c_pot(nimg), rho_ao_t_sub(nspins, nimg), t_2c_work(5), &
     890           0 :                 t_2c_der_metric_sub(natom, 3), mat_der_pot_sub(nimg, 3), t_2c_bint(natom), &
     891       10300 :                 t_2c_metric(natom), t_2c_binv(natom))
     892             :       CALL get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
     893             :                                   rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
     894          42 :                                   mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
     895          42 :       CALL dbt_create(t_2c_work(1), t_2c_R) !nRI x nRI
     896          42 :       CALL dbt_create(t_2c_work(5), t_2c_R_split) !nRI x nRI with split blocks
     897             : 
     898         504 :       ALLOCATE (t_2c_der_pot(3))
     899         168 :       DO i_xyz = 1, 3
     900         168 :          CALL dbt_create(t_2c_R, t_2c_der_pot(i_xyz))
     901             :       END DO
     902             : 
     903             :       !Get the 3-center quantities in the subgroups. The integrals and t_3c_apc already there
     904           0 :       ALLOCATE (t_3c_work_2(3), t_3c_work_3(4), t_3c_der_stack(6), t_3c_der_AO_sub(nimg, 3), &
     905       11312 :                 t_3c_der_RI_sub(nimg, 3), t_3c_apc_sub(nspins, nimg))
     906             :       CALL get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
     907             :                                   t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_stack, &
     908          42 :                                   group_size, ngroups, para_env, para_env_sub, ri_data)
     909             : 
     910             :       !Set up batched contraction (go atom by atom)
     911         126 :       ALLOCATE (batch_ranges_at(natom + 1))
     912          42 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     913          42 :       iatom = 0
     914         212 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     915         212 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     916          84 :             iatom = iatom + 1
     917          84 :             batch_ranges_at(iatom) = iblk
     918             :          END IF
     919             :       END DO
     920             : 
     921          42 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     922          42 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     923          42 :       CALL dbt_batched_contract_init(t_3c_work_3(3), batch_range_2=batch_ranges_at)
     924          42 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     925          42 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     926             : 
     927             :       !Preparing for the stacking of 3c tensors
     928          42 :       nimg_nze = ri_data%nimg_nze
     929          42 :       batch_size = ri_data%kp_stack_size
     930          42 :       n_batch = nimg_nze/batch_size
     931          42 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch = n_batch + 1
     932         126 :       ALLOCATE (batch_ranges_nze(n_batch + 1))
     933          94 :       DO i_batch = 1, n_batch
     934          94 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     935             :       END DO
     936          42 :       batch_ranges_nze(n_batch + 1) = nimg_nze + 1
     937             : 
     938             :       !Applying the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 from left and right
     939             :       !And keep the bump on LHS only version as well, with B*M^-1 = (M^-1*B)^T
     940         126 :       DO iatom = 1, natom
     941          84 :          CALL dbt_create(t_2c_inv(iatom), t_2c_binv(iatom))
     942          84 :          CALL dbt_copy(t_2c_inv(iatom), t_2c_binv(iatom))
     943          84 :          CALL apply_bump(t_2c_binv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     944         126 :          CALL apply_bump(t_2c_inv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
     945             :       END DO
     946             : 
     947          42 :       t1 = m_walltime()
     948          42 :       work_virial = 0.0_dp
     949         210 :       ALLOCATE (iapc_pairs(nimg, 2), i_images(nimg))
     950         210 :       ALLOCATE (force_pattern(natom, natom, nimg))
     951        7112 :       force_pattern(:, :, :) = -1
     952             :       !We proceed with 2 loops: one over the sparsity pattern from the SCF, one over the rest
     953             :       !We use the SCF cost model for the first loop, while we calculate the cost of the upcoming loop
     954         126 :       DO i_loop = 1, 2
     955        2104 :          DO b_img = 1, nimg
     956        6144 :             DO jatom = 1, natom
     957       14140 :                DO iatom = 1, natom
     958             : 
     959        8080 :                   pref = -0.5_dp*fac
     960        8080 :                   IF (i_loop == 1 .AND. (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     961        4619 :                   IF (i_loop == 2 .AND. (.NOT. force_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     962             : 
     963             :                   !Get the proper HFX potential 2c integrals (R_i^0|S_j^b), times (S_j^b|Q_j^b)^-1
     964        1098 :                   CALL timeset(routineN//"_2c_1", handle2)
     965             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     966             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     967        1098 :                                       dbcsr_template=dbcsr_template)
     968             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_inv(jatom), &
     969             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
     970             :                                     contract_1=[2], notcontract_1=[1], &
     971             :                                     contract_2=[1], notcontract_2=[2], &
     972        1098 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     973        1098 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(5), move_data=.TRUE.) !move to split blocks
     974        1098 :                   CALL dbt_filter(t_2c_work(5), ri_data%filter_eps)
     975        1098 :                   CALL timestop(handle2)
     976             : 
     977        1098 :                   CALL timeset(routineN//"_3c", handle2)
     978        5488 :                   bounds_iat(:, 1) = [SUM(ri_data%bsizes_AO(1:iatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:iatom))]
     979        5452 :                   bounds_jat(:, 1) = [SUM(ri_data%bsizes_AO(1:jatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:jatom))]
     980        1098 :                   CALL dbt_clear(t_2c_R_split)
     981             : 
     982        2435 :                   DO i_spin = 1, nspins
     983        2435 :                      CALL dbt_batched_contract_init(rho_ao_t_sub(i_spin, b_img))
     984             :                   END DO
     985             : 
     986        1098 :                   CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env, i_images) !i = a+c-b
     987        2667 :                   DO i_batch = 1, n_batch
     988             : 
     989             :                      !Stack the 3c derivatives to take the trace later on
     990        6276 :                      DO i_xyz = 1, 3
     991        4707 :                         CALL dbt_clear(t_3c_der_stack(i_xyz))
     992             :                         CALL fill_3c_stack(t_3c_der_stack(i_xyz), t_3c_der_RI_sub(:, i_xyz), &
     993             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
     994             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
     995       14121 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     996             : 
     997        4707 :                         CALL dbt_clear(t_3c_der_stack(3 + i_xyz))
     998             :                         CALL fill_3c_stack(t_3c_der_stack(3 + i_xyz), t_3c_der_AO_sub(:, i_xyz), &
     999             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
    1000             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
    1001       15690 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1002             :                      END DO
    1003             : 
    1004        4475 :                      DO i_spin = 1, nspins
    1005             :                         !stack the t_3c_apc tensors
    1006        1808 :                         CALL dbt_clear(t_3c_work_2(3))
    1007             :                         CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
    1008             :                                            ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
    1009        5424 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1010        1808 :                         CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
    1011        1808 :                         IF (nze == 0) CYCLE
    1012        1808 :                         CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
    1013             : 
    1014             :                         !Contract with the second density matrix: P_mu^0,nu^b * t_3c_apc,
    1015             :                         !where t_3c_apc = P_sigma^a,lambda^a+c (mu^0 P^0 sigma^a) *(P^0|R^0)^-1 (stacked along a+c)
    1016             :                         CALL dbt_contract(1.0_dp, rho_ao_t_sub(i_spin, b_img), t_3c_work_2(1), &
    1017             :                                           0.0_dp, t_3c_work_2(2), map_1=[1], map_2=[2, 3], &
    1018             :                                           contract_1=[1], notcontract_1=[2], &
    1019             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1020             :                                           bounds_1=bounds_iat, bounds_2=bounds_jat, &
    1021        1808 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1022        1808 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1023             : 
    1024        1808 :                         CALL get_tensor_occupancy(t_3c_work_2(2), nze, occ)
    1025        1808 :                         IF (nze == 0) CYCLE
    1026             : 
    1027             :                         !Contract with V_PQ so that we can take the trace with (Q^b|nu^b lambda^a+c)^(x)
    1028        1666 :                         CALL dbt_copy(t_3c_work_2(2), t_3c_work_3(1), order=[2, 1, 3], move_data=.TRUE.)
    1029        1666 :                         CALL dbt_batched_contract_init(t_2c_work(5))
    1030             :                         CALL dbt_contract(1.0_dp, t_2c_work(5), t_3c_work_3(1), &
    1031             :                                           0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
    1032             :                                           contract_1=[1], notcontract_1=[2], &
    1033             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1034        1666 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1035        1666 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1036        1666 :                         CALL dbt_batched_contract_finalize(t_2c_work(5))
    1037             : 
    1038             :                         !Contract with the 3c derivatives to get the force/virial
    1039        1666 :                         CALL dbt_copy(t_3c_work_3(2), t_3c_work_3(4), move_data=.TRUE.)
    1040        1666 :                         IF (use_virial_prv) THEN
    1041             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1042             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1043             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1044             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1045         257 :                                                         ri_data, qs_env, work_virial, cell, particle_set)
    1046             :                         ELSE
    1047             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1048             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1049             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1050             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1051        1409 :                                                         ri_data, qs_env)
    1052             :                         END IF
    1053        1666 :                         CALL dbt_clear(t_3c_work_3(4))
    1054             : 
    1055             :                         !Contract with the 3-center integrals in order to have a matrix R_PQ such that
    1056             :                         !we can take the trace sum_PQ R_PQ (P^0|Q^b)^(x)
    1057        1666 :                         IF (i_loop == 2) CYCLE
    1058             : 
    1059             :                         !Stack the 3c integrals
    1060             :                         CALL fill_3c_stack(t_3c_work_3(4), ri_data%kp_t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
    1061             :                                            filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
    1062        2634 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1063         878 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(3), move_data=.TRUE.)
    1064             : 
    1065         878 :                         CALL dbt_batched_contract_init(t_2c_R_split)
    1066             :                         CALL dbt_contract(1.0_dp, t_3c_work_3(1), t_3c_work_3(3), &
    1067             :                                           1.0_dp, t_2c_R_split, map_1=[1], map_2=[2], &
    1068             :                                           contract_1=[2, 3], notcontract_1=[1], &
    1069             :                                           contract_2=[2, 3], notcontract_2=[1], &
    1070         878 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1071         878 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1072         878 :                         CALL dbt_batched_contract_finalize(t_2c_R_split)
    1073        6063 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(1))
    1074             :                      END DO
    1075             :                   END DO
    1076        2435 :                   DO i_spin = 1, nspins
    1077        2435 :                      CALL dbt_batched_contract_finalize(rho_ao_t_sub(i_spin, b_img))
    1078             :                   END DO
    1079        1098 :                   CALL timestop(handle2)
    1080             : 
    1081        1098 :                   IF (i_loop == 2) CYCLE
    1082         579 :                   pref = 2.0_dp*pref
    1083         579 :                   IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp*pref
    1084             : 
    1085         579 :                   CALL timeset(routineN//"_2c_2", handle2)
    1086             :                   !Note that the derivatives are in atomic block format (not split)
    1087         579 :                   CALL dbt_copy(t_2c_R_split, t_2c_R, move_data=.TRUE.)
    1088             : 
    1089             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
    1090             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
    1091         579 :                                       dbcsr_template=dbcsr_template)
    1092             : 
    1093             :                   !We have to calculate: S^-1(iat) * R_PQ * S^-1(jat)    to trace with HFX pot der
    1094             :                   !                      + R_PQ * S^-1(jat) * pot^T      to trace with S^(x) (iat)
    1095             :                   !                      + pot^T * S^-1(iat) *R_PQ       to trace with S^(x) (jat)
    1096             : 
    1097             :                   !Because 3c tensors are all precontracted with the inverse RI metric,
    1098             :                   !t_2c_R is currently implicitly multiplied by S^-1(iat) from the left
    1099             :                   !and S^-1(jat) from the right, directly in the proper format for the trace
    1100             :                   !with the HFX potential derivative
    1101             : 
    1102             :                   !Trace with HFX pot deriv, that we need to build first
    1103        2316 :                   DO i_xyz = 1, 3
    1104             :                      CALL get_ext_2c_int(t_2c_der_pot(i_xyz), mat_der_pot_sub(:, i_xyz), iatom, jatom, &
    1105             :                                          b_img, ri_data, qs_env, blacs_env_ext=blacs_env_sub, &
    1106        2316 :                                          para_env_ext=para_env_sub, dbcsr_template=dbcsr_template)
    1107             :                   END DO
    1108             : 
    1109         579 :                   IF (use_virial_prv) THEN
    1110             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1111         113 :                                            b_img, pref, ri_data, qs_env, work_virial, cell, particle_set)
    1112             :                   ELSE
    1113             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1114         466 :                                            b_img, pref, ri_data, qs_env)
    1115             :                   END IF
    1116             : 
    1117        2316 :                   DO i_xyz = 1, 3
    1118        2316 :                      CALL dbt_clear(t_2c_der_pot(i_xyz))
    1119             :                   END DO
    1120             : 
    1121             :                   !R_PQ * S^-1(jat) * pot^T  (=A)
    1122             :                   CALL dbt_contract(1.0_dp, t_2c_metric(iatom), t_2c_R, & !get rid of implicit S^-1(iat)
    1123             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1124             :                                     contract_1=[2], notcontract_1=[1], &
    1125             :                                     contract_2=[1], notcontract_2=[2], &
    1126         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1127         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1128             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_work(1), &
    1129             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1130             :                                     contract_1=[2], notcontract_1=[1], &
    1131             :                                     contract_2=[2], notcontract_2=[1], &
    1132         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1133         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1134             : 
    1135             :                   !With the RI bump function, things get more complex. M = (S|P)_D + B*(S|P)_OD*B
    1136             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1137             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(iatom), &
    1138             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1139             :                                     contract_1=[2], notcontract_1=[1], &
    1140             :                                     contract_2=[1], notcontract_2=[2], &
    1141         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1142         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1143             : 
    1144             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1145             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1146             :                                     contract_1=[1], notcontract_1=[2], &
    1147             :                                     contract_2=[1], notcontract_2=[2], &
    1148         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1149         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1150             : 
    1151         579 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1152             :                   CALL get_2c_bump_forces(force, t_2c_work(4), iatom, atom_of_kind, kind_of, pref, &
    1153         579 :                                           ri_data, qs_env, work_virial)
    1154             : 
    1155             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1156             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(2), &
    1157             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1158             :                                     contract_1=[1], notcontract_1=[2], &
    1159             :                                     contract_2=[1], notcontract_2=[2], &
    1160         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1161         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1162             : 
    1163         579 :                   IF (use_virial_prv) THEN
    1164             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1165             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1166         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1167             :                   ELSE
    1168             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1169         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1170             :                   END IF
    1171             : 
    1172             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1173         579 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1174         579 :                   CALL apply_bump(t_2c_work(2), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1175             : 
    1176         579 :                   IF (use_virial_prv) THEN
    1177             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1178             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1179         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1180             :                   ELSE
    1181             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1182         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1183             :                   END IF
    1184             : 
    1185             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1186             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1187             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(iatom), &
    1188             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1189             :                                     contract_1=[2], notcontract_1=[1], &
    1190             :                                     contract_2=[1], notcontract_2=[2], &
    1191         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1192         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1193             : 
    1194             :                   CALL dbt_contract(1.0_dp, t_2c_bint(iatom), t_2c_work(4), &
    1195             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1196             :                                     contract_1=[1], notcontract_1=[2], &
    1197             :                                     contract_2=[1], notcontract_2=[2], &
    1198         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1199         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1200             : 
    1201             :                   CALL get_2c_bump_forces(force, t_2c_work(2), iatom, atom_of_kind, kind_of, -pref, &
    1202         579 :                                           ri_data, qs_env, work_virial)
    1203             : 
    1204             :                   ! pot^T * S^-1(iat) * R_PQ (=A)
    1205             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_R, &
    1206             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1207             :                                     contract_1=[1], notcontract_1=[2], &
    1208             :                                     contract_2=[1], notcontract_2=[2], &
    1209         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1210         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1211             : 
    1212             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_metric(jatom), & !get rid of implicit S^-1(jat)
    1213             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1214             :                                     contract_1=[2], notcontract_1=[1], &
    1215             :                                     contract_2=[1], notcontract_2=[2], &
    1216         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1217         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1218             : 
    1219             :                   !Do the same shenanigans with the S^(x) (jatom)
    1220             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1221             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(jatom), &
    1222             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1223             :                                     contract_1=[2], notcontract_1=[1], &
    1224             :                                     contract_2=[1], notcontract_2=[2], &
    1225         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1226         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1227             : 
    1228             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1229             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1230             :                                     contract_1=[1], notcontract_1=[2], &
    1231             :                                     contract_2=[1], notcontract_2=[2], &
    1232         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1233         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1234             : 
    1235         579 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1236             :                   CALL get_2c_bump_forces(force, t_2c_work(4), jatom, atom_of_kind, kind_of, pref, &
    1237         579 :                                           ri_data, qs_env, work_virial)
    1238             : 
    1239             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1240             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(2), &
    1241             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1242             :                                     contract_1=[1], notcontract_1=[2], &
    1243             :                                     contract_2=[1], notcontract_2=[2], &
    1244         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1245         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1246             : 
    1247         579 :                   IF (use_virial_prv) THEN
    1248             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1249             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1250         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1251             :                   ELSE
    1252             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1253         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1254             :                   END IF
    1255             : 
    1256             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1257         579 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1258         579 :                   CALL apply_bump(t_2c_work(2), jatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1259             : 
    1260         579 :                   IF (use_virial_prv) THEN
    1261             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1262             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1263         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1264             :                   ELSE
    1265             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1266         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1267             :                   END IF
    1268             : 
    1269             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1270             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1271             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(jatom), &
    1272             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1273             :                                     contract_1=[2], notcontract_1=[1], &
    1274             :                                     contract_2=[1], notcontract_2=[2], &
    1275         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1276         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1277             : 
    1278             :                   CALL dbt_contract(1.0_dp, t_2c_bint(jatom), t_2c_work(4), &
    1279             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1280             :                                     contract_1=[1], notcontract_1=[2], &
    1281             :                                     contract_2=[1], notcontract_2=[2], &
    1282         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1283         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1284             : 
    1285             :                   CALL get_2c_bump_forces(force, t_2c_work(2), jatom, atom_of_kind, kind_of, -pref, &
    1286         579 :                                           ri_data, qs_env, work_virial)
    1287             : 
    1288       14376 :                   CALL timestop(handle2)
    1289             :                END DO !iatom
    1290             :             END DO !jatom
    1291             :          END DO !b_img
    1292             : 
    1293         126 :          IF (i_loop == 1) THEN
    1294          42 :             CALL update_pattern_to_forces(force_pattern, sparsity_pattern, ngroups, ri_data, qs_env)
    1295             :          END IF
    1296             :       END DO !i_loop
    1297             : 
    1298          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
    1299          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
    1300          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(3))
    1301          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
    1302          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
    1303             : 
    1304          42 :       IF (use_virial_prv) THEN
    1305          32 :          DO k_xyz = 1, 3
    1306         104 :             DO j_xyz = 1, 3
    1307         312 :                DO i_xyz = 1, 3
    1308             :                   virial%pv_fock_4c(i_xyz, j_xyz) = virial%pv_fock_4c(i_xyz, j_xyz) &
    1309         288 :                                                     + work_virial(i_xyz, k_xyz)*cell%hmat(j_xyz, k_xyz)
    1310             :                END DO
    1311             :             END DO
    1312             :          END DO
    1313             :       END IF
    1314             : 
    1315             :       !End of subgroup parallelization
    1316          42 :       CALL cp_blacs_env_release(blacs_env_sub)
    1317          42 :       CALL para_env_sub%free()
    1318          42 :       DEALLOCATE (para_env_sub)
    1319             : 
    1320          42 :       CALL para_env%sync()
    1321          42 :       t2 = m_walltime()
    1322          42 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    1323             : 
    1324             :       !clean-up
    1325          42 :       IF (ASSOCIATED(dbcsr_template)) THEN
    1326          42 :          CALL dbcsr_release(dbcsr_template)
    1327          42 :          DEALLOCATE (dbcsr_template)
    1328             :       END IF
    1329          42 :       CALL dbt_destroy(t_2c_R)
    1330          42 :       CALL dbt_destroy(t_2c_R_split)
    1331          42 :       CALL dbt_destroy(t_2c_work(1))
    1332          42 :       CALL dbt_destroy(t_2c_work(2))
    1333          42 :       CALL dbt_destroy(t_2c_work(3))
    1334          42 :       CALL dbt_destroy(t_2c_work(4))
    1335          42 :       CALL dbt_destroy(t_2c_work(5))
    1336          42 :       CALL dbt_destroy(t_3c_work_2(1))
    1337          42 :       CALL dbt_destroy(t_3c_work_2(2))
    1338          42 :       CALL dbt_destroy(t_3c_work_2(3))
    1339          42 :       CALL dbt_destroy(t_3c_work_3(1))
    1340          42 :       CALL dbt_destroy(t_3c_work_3(2))
    1341          42 :       CALL dbt_destroy(t_3c_work_3(3))
    1342          42 :       CALL dbt_destroy(t_3c_work_3(4))
    1343          42 :       CALL dbt_destroy(t_3c_der_stack(1))
    1344          42 :       CALL dbt_destroy(t_3c_der_stack(2))
    1345          42 :       CALL dbt_destroy(t_3c_der_stack(3))
    1346          42 :       CALL dbt_destroy(t_3c_der_stack(4))
    1347          42 :       CALL dbt_destroy(t_3c_der_stack(5))
    1348          42 :       CALL dbt_destroy(t_3c_der_stack(6))
    1349         168 :       DO i_xyz = 1, 3
    1350         168 :          CALL dbt_destroy(t_2c_der_pot(i_xyz))
    1351             :       END DO
    1352         126 :       DO iatom = 1, natom
    1353          84 :          CALL dbt_destroy(t_2c_inv(iatom))
    1354          84 :          CALL dbt_destroy(t_2c_binv(iatom))
    1355          84 :          CALL dbt_destroy(t_2c_bint(iatom))
    1356          84 :          CALL dbt_destroy(t_2c_metric(iatom))
    1357         378 :          DO i_xyz = 1, 3
    1358         336 :             CALL dbt_destroy(t_2c_der_metric_sub(iatom, i_xyz))
    1359             :          END DO
    1360             :       END DO
    1361        1052 :       DO i_img = 1, nimg
    1362        1010 :          CALL dbcsr_release(mat_2c_pot(i_img))
    1363        2228 :          DO i_spin = 1, nspins
    1364        1176 :             CALL dbt_destroy(rho_ao_t_sub(i_spin, i_img))
    1365        2186 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
    1366             :          END DO
    1367             :       END DO
    1368         168 :       DO i_xyz = 1, 3
    1369        3198 :          DO i_img = 1, nimg
    1370        3030 :             CALL dbt_destroy(t_3c_der_RI_sub(i_img, i_xyz))
    1371        3030 :             CALL dbt_destroy(t_3c_der_AO_sub(i_img, i_xyz))
    1372        3156 :             CALL dbcsr_release(mat_der_pot_sub(i_img, i_xyz))
    1373             :          END DO
    1374             :       END DO
    1375             : 
    1376          42 :       CALL timestop(handle)
    1377             : 
    1378       18756 :    END SUBROUTINE hfx_ri_update_forces_kp
    1379             : 
    1380             : ! **************************************************************************************************
    1381             : !> \brief A routine that applies the RI bump matrix from the left and/or the right, given an input
    1382             : !>        matrix and the central RI atom. We assume atomic block sizes. At least one of from_left/from_right must be .TRUE.
    1383             : !> \param t_2c_inout 2-center tensor whose blocks are scaled in place, then filtered with ri_data%filter_eps
    1384             : !> \param atom_i index of the reference (central) atom; bump values depend on the distance to its position
    1385             : !> \param ri_data supplies ncell_RI, RI_cell_to_img, kp_RI_range (used as r1), kp_bump_rad (used as r0), filter_eps
    1386             : !> \param qs_env environment from which natom, cell, kpoints and particle_set are retrieved
    1387             : !> \param from_left if .TRUE., scale each block by the bump value of its row atom (default .FALSE.)
    1388             : !> \param from_right if .TRUE., scale each block by the bump value of its column atom (default .FALSE.)
    1389             : !> \param debump if .TRUE., divide by the bump value instead of multiplying by it (default .FALSE.)
    1390             : ! **************************************************************************************************
    1391        1746 :    SUBROUTINE apply_bump(t_2c_inout, atom_i, ri_data, qs_env, from_left, from_right, debump)
    1392             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_inout
    1393             :       INTEGER, INTENT(IN)                                :: atom_i
    1394             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1395             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1396             :       LOGICAL, INTENT(IN), OPTIONAL                      :: from_left, from_right, debump
    1397             : 
    1398             :       INTEGER                                            :: i_img, i_RI, iatom, ind(2), j_img, j_RI, &
    1399             :                                                             jatom, natom, nblks(2), nimg, nkind
    1400        1746 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1401        1746 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1402             :       LOGICAL                                            :: found, my_debump, my_left, my_right
    1403             :       REAL(dp)                                           :: bval, r0, r1, ri(3), rj(3), rref(3), &
    1404             :                                                             scoord(3)
    1405        1746 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1406             :       TYPE(cell_type), POINTER                           :: cell
    1407             :       TYPE(dbt_iterator_type)                            :: iter
    1408             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1409        1746 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1410        1746 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1411             : 
    1412        1746 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1413             : 
    1414             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1415        1746 :                       kpoints=kpoints, particle_set=particle_set)
    1416        1746 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1417             : 
    1418        1746 :       my_debump = .FALSE.
    1419        1746 :       IF (PRESENT(debump)) my_debump = debump
    1420             : 
    1421        1746 :       my_left = .FALSE.
    1422        1746 :       IF (PRESENT(from_left)) my_left = from_left
    1423             : 
    1424        1746 :       my_right = .FALSE.
    1425        1746 :       IF (PRESENT(from_right)) my_right = from_right
    1426        1746 :       CPASSERT(my_left .OR. my_right)
    1427             : 
    1428        1746 :       CALL dbt_get_info(t_2c_inout, nblks_total=nblks)
    1429        1746 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1430        1746 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1431             : 
    1432        1746 :       nimg = ri_data%nimg
    1433             : 
    1434             :       !Loop over the tensor blocks, map each block index to an (RI cell, atom) pair, and apply bump accordingly
    1435        1746 :       r1 = ri_data%kp_RI_range
    1436        1746 :       r0 = ri_data%kp_bump_rad
    1437        1746 :       rref = pbc(particle_set(atom_i)%r, cell)
    1438             : 
    1439             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_inout,natom,ri_data,cell,particle_set,index_to_cell,my_left, &
    1440             : !$OMP                               my_right,r0,r1,rref,my_debump) &
    1441        1746 : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,bval)
    1442             :       CALL dbt_iterator_start(iter, t_2c_inout)
    1443             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1444             :          CALL dbt_iterator_next_block(iter, ind)
    1445             :          CALL dbt_get_block(t_2c_inout, ind, blk, found)
    1446             :          IF (.NOT. found) CYCLE
    1447             : 
    1448             :          i_RI = (ind(1) - 1)/natom + 1
    1449             :          i_img = ri_data%RI_cell_to_img(i_RI)
    1450             :          iatom = ind(1) - (i_RI - 1)*natom
    1451             : 
    1452             :          CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    1453             :          CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    1454             : 
    1455             :          j_RI = (ind(2) - 1)/natom + 1
    1456             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1457             :          jatom = ind(2) - (j_RI - 1)*natom
    1458             : 
    1459             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1460             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1461             : 
    1462             :          IF (.NOT. my_debump) THEN
    1463             :             IF (my_left) blk(:, :) = blk(:, :)*bump(NORM2(ri - rref), r0, r1)
    1464             :             IF (my_right) blk(:, :) = blk(:, :)*bump(NORM2(rj - rref), r0, r1)
    1465             :          ELSE
    1466             :             !Note: by construction, the bump function is never quite zero, as its range is the same
    1467             :             !      as that of the extended RI basis (but we are safe: no division if bval <= EPSILON)
    1468             :             bval = bump(NORM2(ri - rref), r0, r1)
    1469             :             IF (my_left .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1470             :             bval = bump(NORM2(rj - rref), r0, r1)
    1471             :             IF (my_right .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1472             :          END IF
    1473             : 
    1474             :          CALL dbt_put_block(t_2c_inout, ind, SHAPE(blk), blk)
    1475             : 
    1476             :          DEALLOCATE (blk)
    1477             :       END DO
    1478             :       CALL dbt_iterator_stop(iter)
    1479             : !$OMP END PARALLEL
    1480        1746 :       CALL dbt_filter(t_2c_inout, ri_data%filter_eps)
    1481             : 
    1482        3492 :    END SUBROUTINE apply_bump
    1483             : 
    1484             : ! **************************************************************************************************
    1485             : !> \brief A routine that calculates the forces due to the derivative of the bump function
    1486             : !> \param force force array; the fock_4c components of the reference atom and of each block atom are updated
    1487             : !> \param t_2c_in 2-center tensor; only its diagonal blocks are read, and only their trace is used
    1488             : !> \param atom_i index of the reference atom, from whose position the bump distance is measured
    1489             : !> \param atom_of_kind for each atom, its index within its kind (second index of fock_4c)
    1490             : !> \param kind_of for each atom, its kind index (index into force)
    1491             : !> \param pref scalar prefactor multiplying every force and virial contribution
    1492             : !> \param ri_data supplies ncell_RI, RI_cell_to_img, kp_RI_range (used as r1) and kp_bump_rad (used as r0)
    1493             : !> \param qs_env environment from which natom, cell, kpoints and particle_set are retrieved
    1494             : !> \param work_virial 3x3 accumulator to which the virial contributions are always added
    1495             : ! **************************************************************************************************
    1496        2316 :    SUBROUTINE get_2c_bump_forces(force, t_2c_in, atom_i, atom_of_kind, kind_of, pref, ri_data, &
    1497             :                                  qs_env, work_virial)
    1498             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    1499             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_in
    1500             :       INTEGER, INTENT(IN)                                :: atom_i
    1501             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    1502             :       REAL(dp), INTENT(IN)                               :: pref
    1503             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1504             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1505             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT)           :: work_virial
    1506             : 
    1507             :       INTEGER :: i, i_img, i_RI, i_xyz, iat_of_kind, iatom, ikind, ind(2), j_img, j_RI, j_xyz, &
    1508             :          jat_of_kind, jatom, jkind, natom, nblks(2), nimg, nkind
    1509        2316 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1510        2316 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1511             :       LOGICAL                                            :: found
    1512             :       REAL(dp)                                           :: new_force, r0, r1, ri(3), rj(3), &
    1513             :                                                             rref(3), scoord(3), x
    1514        2316 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1515             :       TYPE(cell_type), POINTER                           :: cell
    1516             :       TYPE(dbt_iterator_type)                            :: iter
    1517             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1518        2316 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1519        2316 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1520             : 
    1521        2316 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1522             : 
    1523             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1524        2316 :                       kpoints=kpoints, particle_set=particle_set)
    1525        2316 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1526             : 
    1527        2316 :       CALL dbt_get_info(t_2c_in, nblks_total=nblks)
    1528        2316 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1529        2316 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1530             : 
    1531        2316 :       nimg = ri_data%nimg
    1532             : 
    1533             :       !Loop over the diagonal blocks (RI cells and atoms), and accumulate the bump derivative forces and virial
    1534        2316 :       r1 = ri_data%kp_RI_range
    1535        2316 :       r0 = ri_data%kp_bump_rad
    1536        2316 :       rref = pbc(particle_set(atom_i)%r, cell)
    1537             : 
    1538        2316 :       iat_of_kind = atom_of_kind(atom_i)
    1539        2316 :       ikind = kind_of(atom_i)
    1540             : 
    1541             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_in,natom,ri_data,cell,particle_set,index_to_cell,pref, &
    1542             : !$OMP force,r0,r1,rref,atom_of_kind,kind_of,iat_of_kind,ikind,work_virial) &
    1543             : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,jkind,jat_of_kind, &
    1544        2316 : !$OMP         new_force,i_xyz,i,x,j_xyz)
    1545             :       CALL dbt_iterator_start(iter, t_2c_in)
    1546             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1547             :          CALL dbt_iterator_next_block(iter, ind)
    1548             :          IF (ind(1) .NE. ind(2)) CYCLE !bump matrix is diagonal
    1549             : 
    1550             :          CALL dbt_get_block(t_2c_in, ind, blk, found)
    1551             :          IF (.NOT. found) CYCLE
    1552             : 
    1553             :          !bump is a function of x = SQRT((R - Rref)^2). We refer to R as jatom, and Rref as atom_i
    1554             :          j_RI = (ind(2) - 1)/natom + 1
    1555             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1556             :          jatom = ind(2) - (j_RI - 1)*natom
    1557             :          jat_of_kind = atom_of_kind(jatom)
    1558             :          jkind = kind_of(jatom)
    1559             : 
    1560             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1561             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1562             :          x = NORM2(rj - rref)
    1563             :          IF (x < r0 .OR. x > r1) CYCLE
    1564             : 
    1565             :          new_force = 0.0_dp
    1566             :          DO i = 1, SIZE(blk, 1)
    1567             :             new_force = new_force + blk(i, i)
    1568             :          END DO
    1569             :          new_force = pref*new_force*dbump(x, r0, r1)
    1570             : 
    1571             :          !x = SQRT((R - Rref)^2), so we multiply by dx/dR and dx/dRref
    1572             :          DO i_xyz = 1, 3
    1573             :             !Force acting on second atom
    1574             : !$OMP ATOMIC
    1575             :             force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) + &
    1576             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1577             : 
    1578             :             !virial acting on second atom
    1579             :             CALL real_to_scaled(scoord, rj, cell)
    1580             :             DO j_xyz = 1, 3
    1581             : !$OMP ATOMIC
    1582             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1583             :                                            + new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1584             :             END DO
    1585             : 
    1586             :             !Force acting on reference atom, defining the RI basis
    1587             : !$OMP ATOMIC
    1588             :             force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) - &
    1589             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1590             : 
    1591             :             !virial of ref atom
    1592             :             CALL real_to_scaled(scoord, rref, cell)
    1593             :             DO j_xyz = 1, 3
    1594             : !$OMP ATOMIC
    1595             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1596             :                                            - new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1597             :             END DO
    1598             :          END DO !i_xyz
    1599             : 
    1600             :          DEALLOCATE (blk)
    1601             :       END DO
    1602             :       CALL dbt_iterator_stop(iter)
    1603             : !$OMP END PARALLEL
    1604             : 
    1605        4632 :    END SUBROUTINE get_2c_bump_forces
    1606             : 
    1607             : ! **************************************************************************************************
    1608             : !> \brief The bump function as defined by Juerg
    1609             : !> \param x point at which the bump function is evaluated
    1610             : !> \param r0 lower bound: the function is 1 for x <= r0
    1611             : !> \param r1 upper bound: the function is 0 for x >= r1
    1612             : !> \return 1 for x <= r0, 0 for x >= r1, else -6r^5 + 15r^4 - 10r^3 + 1 with r = (x - r0)/(r1 - r0)
    1613             : ! **************************************************************************************************
    1614       25181 :    FUNCTION bump(x, r0, r1) RESULT(b)
    1615             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1616             :       REAL(dp)                                           :: b
    1617             : 
    1618             :       REAL(dp)                                           :: r
    1619             : 
    1620             :       !Head-Gordon (alternative form, kept for reference only)
    1621             :       !b = 1.0_dp/(1.0_dp+EXP((r1-r0)/(r1-x)-(r1-r0)/(x-r0)))
    1622             :       !Juerg (the form actually used; the two IF statements below override it outside of (r0, r1))
    1623       25181 :       r = (x - r0)/(r1 - r0)
    1624       25181 :       b = -6.0_dp*r**5 + 15.0_dp*r**4 - 10.0_dp*r**3 + 1.0_dp
    1625       25181 :       IF (x .GE. r1) b = 0.0_dp
    1626       25181 :       IF (x .LE. r0) b = 1.0_dp
    1627             : 
    1628       25181 :    END FUNCTION bump
    1629             : 
    1630             : ! **************************************************************************************************
    1631             : !> \brief The derivative of the bump function with respect to x
    1632             : !> \param x point at which the derivative is evaluated
    1633             : !> \param r0 lower bound: the derivative is 0 for x <= r0
    1634             : !> \param r1 upper bound: the derivative is 0 for x >= r1
    1635             : !> \return 0 for x <= r0 or x >= r1, else (-30r^4 + 60r^3 - 30r^2)/(r1 - r0) with r = (x - r0)/(r1 - r0)
    1636             : ! **************************************************************************************************
    1637         525 :    FUNCTION dbump(x, r0, r1) RESULT(b)
    1638             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1639             :       REAL(dp)                                           :: b
    1640             : 
    1641             :       REAL(dp)                                           :: r
    1642             : 
    1643         525 :       r = (x - r0)/(r1 - r0)
    1644         525 :       b = (-30.0_dp*r**4 + 60.0_dp*r**3 - 30.0_dp*r**2)/(r1 - r0)
    1645         525 :       IF (x .GE. r1) b = 0.0_dp
    1646         525 :       IF (x .LE. r0) b = 0.0_dp
    1647             : 
    1648         525 :    END FUNCTION dbump
    1649             : 
    1650             : ! **************************************************************************************************
    1651             : !> \brief return the cell index a+c corresponding to given cell index i and b, with i = a+c-b
    1652             : !> \param i_index index of cell i, used as column index into index_to_cell
    1653             : !> \param b_index index of cell b, used as column index into index_to_cell
    1654             : !> \param qs_env environment whose kpoints provide the cell_to_index and index_to_cell maps
    1655             : !> \return index of the cell i+b, or 0 if that cell lies outside the bounds of cell_to_index
    1656             : ! **************************************************************************************************
    1657      151439 :    FUNCTION get_apc_index_from_ib(i_index, b_index, qs_env) RESULT(apc_index)
    1658             :       INTEGER, INTENT(IN)                                :: i_index, b_index
    1659             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1660             :       INTEGER                                            :: apc_index
    1661             : 
    1662             :       INTEGER, DIMENSION(3)                              :: cell_apc
    1663      151439 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1664      151439 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1665             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1666             : 
    1667      151439 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1668      151439 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1669             : 
    1670             :       !i = a+c-b => a+c = i+b
    1671      605756 :       cell_apc(:) = index_to_cell(:, i_index) + index_to_cell(:, b_index)
    1672             : 
    1673     1037869 :       IF (ANY([cell_apc(1), cell_apc(2), cell_apc(3)] < LBOUND(cell_to_index)) .OR. &
    1674             :           ANY([cell_apc(1), cell_apc(2), cell_apc(3)] > UBOUND(cell_to_index))) THEN
    1675             : 
    1676             :          apc_index = 0
    1677             :       ELSE
    1678      132550 :          apc_index = cell_to_index(cell_apc(1), cell_apc(2), cell_apc(3))
    1679             :       END IF
    1680             : 
    1681      151439 :    END FUNCTION get_apc_index_from_ib
    1682             : 
    1683             : ! **************************************************************************************************
    1684             : !> \brief return the cell index i corresponding to the sum of cell_a and cell_c
    1685             : !> \param a_index index of cell a, used as column index into index_to_cell
    1686             : !> \param c_index index of cell c, used as column index into index_to_cell
    1687             : !> \param qs_env environment whose kpoints provide the cell_to_index and index_to_cell maps
    1688             : !> \return index of the cell a+c, or 0 if that cell lies outside the bounds of cell_to_index
    1689             : ! **************************************************************************************************
    1690           0 :    FUNCTION get_apc_index(a_index, c_index, qs_env) RESULT(i_index)
    1691             :       INTEGER, INTENT(IN)                                :: a_index, c_index
    1692             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1693             :       INTEGER                                            :: i_index
    1694             : 
    1695             :       INTEGER, DIMENSION(3)                              :: cell_i
    1696           0 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1697           0 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1698             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1699             : 
    1700           0 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1701           0 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1702             : 
    1703           0 :       cell_i(:) = index_to_cell(:, a_index) + index_to_cell(:, c_index)
    1704             : 
    1705           0 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1706             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1707             : 
    1708             :          i_index = 0
    1709             :       ELSE
    1710           0 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1711             :       END IF
    1712             : 
    1713           0 :    END FUNCTION get_apc_index
    1714             : 
    1715             : ! **************************************************************************************************
    1716             : !> \brief return the cell index i corresponding to the sum of cell_a + cell_c - cell_b
    1717             : !> \param apc_index index of the cell a+c, used as column index into index_to_cell
    1718             : !> \param b_index index of cell b, used as column index into index_to_cell
    1719             : !> \param qs_env environment from which the kpoint info (cell_to_index, index_to_cell) is taken
    1720             : !> \return the cell_to_index entry of cell (a+c)-b, or 0 if it is outside the bounds of cell_to_index
    1721             : ! **************************************************************************************************
    1722      510764 :    FUNCTION get_i_index(apc_index, b_index, qs_env) RESULT(i_index)
    1723             :       INTEGER, INTENT(IN)                                :: apc_index, b_index
    1724             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1725             :       INTEGER                                            :: i_index
    1726             : 
    1727             :       INTEGER, DIMENSION(3)                              :: cell_i
    1728      510764 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1729      510764 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1730             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1731             : 
    1732      510764 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1733      510764 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1734             :       !cell vector of (a+c) - b
    1735     2043056 :       cell_i(:) = index_to_cell(:, apc_index) - index_to_cell(:, b_index)
    1736             :       !only look up cell_to_index if all 3 components are within its bounds, otherwise return 0
    1737     3491940 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1738             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1739             : 
    1740             :          i_index = 0
    1741             :       ELSE
    1742      438348 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1743             :       END IF
    1744             : 
    1745      510764 :    END FUNCTION get_i_index
    1746             : 
    1747             : ! **************************************************************************************************
    1748             : !> \brief A routine that returns all allowed a,c pairs such that a+c images corresponds to the value
    1749             : !>        of the apc_index input. Takes into account that image a corresponds to 3c integrals, which
    1750             : !>        are ordered in their own way
    1751             : !> \param ac_pairs zeroed on entry; on exit ac_pairs(a, 1) = a and ac_pairs(a, 2) = index of cell c
    1752             : !> \param apc_index index of the target cell a+c
    1753             : !> \param ri_data provides idx_to_img, which maps the position a to the image index of cell a
    1754             : !> \param qs_env environment passed on to get_i_index
    1755             : ! **************************************************************************************************
    1756       15228 :    SUBROUTINE get_ac_pairs(ac_pairs, apc_index, ri_data, qs_env)
    1757             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: ac_pairs
    1758             :       INTEGER, INTENT(IN)                                :: apc_index
    1759             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1760             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1761             : 
    1762             :       INTEGER                                            :: a_index, actual_img, c_index, nimg
    1763             : 
    1764       15228 :       nimg = SIZE(ac_pairs, 1)
    1765             : 
    1766     1067212 :       ac_pairs(:, :) = 0
    1767             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(ac_pairs,nimg,ri_data,qs_env,apc_index) &
    1768       15228 : !$OMP PRIVATE(a_index,actual_img,c_index)
    1769             :       DO a_index = 1, nimg
    1770             :          actual_img = ri_data%idx_to_img(a_index)
    1771             :          !c = a+c - a, note that get_i_index returns 0 if that cell is out of bounds
    1772             :          c_index = get_i_index(apc_index, actual_img, qs_env)
    1773             :          ac_pairs(a_index, 1) = a_index
    1774             :          ac_pairs(a_index, 2) = c_index
    1775             :       END DO
    1776             : !$OMP END PARALLEL DO
    1777             : 
    1778       15228 :    END SUBROUTINE get_ac_pairs
    1779             : 
    1780             : ! **************************************************************************************************
    1781             : !> \brief A routine that returns all allowed i,a+c pairs such that, for the given value of b, we have
    1782             : !>        i = a+c-b. Takes into account that image i corresponds to the 3c ints, which are ordered in
    1783             : !>        their own way
    1784             : !> \param iapc_pairs on exit iapc_pairs(i, 1) = i and iapc_pairs(i, 2) = a+c index, rows without pair are 0
    1785             : !> \param b_index index of cell b
    1786             : !> \param ri_data provides idx_to_img, which maps the position i to the image index of cell i
    1787             : !> \param qs_env environment passed on to get_apc_index_from_ib
    1788             : !> \param actual_i_img if present, entry i is set to idx_to_img(i) for rows with a pair, 0 otherwise
    1789             : ! **************************************************************************************************
    1790        4330 :    SUBROUTINE get_iapc_pairs(iapc_pairs, b_index, ri_data, qs_env, actual_i_img)
    1791             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: iapc_pairs
    1792             :       INTEGER, INTENT(IN)                                :: b_index
    1793             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1794             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1795             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: actual_i_img
    1796             : 
    1797             :       INTEGER                                            :: actual_img, apc_index, i_index, nimg
    1798             : 
    1799        4330 :       nimg = SIZE(iapc_pairs, 1)
    1800       43396 :       IF (PRESENT(actual_i_img)) actual_i_img(:) = 0
    1801             : 
    1802      315868 :       iapc_pairs(:, :) = 0
    1803             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(iapc_pairs,nimg,ri_data,qs_env,b_index,actual_i_img) &
    1804        4330 : !$OMP PRIVATE(i_index,actual_img,apc_index)
    1805             :       DO i_index = 1, nimg
    1806             :          actual_img = ri_data%idx_to_img(i_index)
    1807             :          apc_index = get_apc_index_from_ib(actual_img, b_index, qs_env)
    1808             :          IF (apc_index == 0) CYCLE
    1809             :          iapc_pairs(i_index, 1) = i_index
    1810             :          iapc_pairs(i_index, 2) = apc_index
    1811             :          IF (PRESENT(actual_i_img)) actual_i_img(i_index) = actual_img
    1812             :       END DO
    1813             : 
    1814        4330 :    END SUBROUTINE get_iapc_pairs
    1815             : 
    1816             : ! **************************************************************************************************
    1817             : !> \brief A function that, given a cell index a, returns the index corresponding to -a, and zero
    1818             : !>        if out of bounds
    1819             : !> \param a_index index of cell a, used as column index into index_to_cell
    1820             : !> \param qs_env environment from which the kpoint info (cell_to_index, index_to_cell) is taken
    1821             : !> \return the cell_to_index entry of cell -a, or 0 if -a is outside the bounds of cell_to_index
    1822             : ! **************************************************************************************************
    1823       61137 :    FUNCTION get_opp_index(a_index, qs_env) RESULT(opp_index)
    1824             :       INTEGER, INTENT(IN)                                :: a_index
    1825             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1826             :       INTEGER                                            :: opp_index
    1827             : 
    1828             :       INTEGER, DIMENSION(3)                              :: opp_cell
    1829       61137 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1830       61137 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1831             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1832             : 
    1833       61137 :       NULLIFY (kpoints, cell_to_index, index_to_cell)
    1834             : 
    1835       61137 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1836       61137 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1837             :       !cell vector of -a
    1838      244548 :       opp_cell(:) = -index_to_cell(:, a_index)
    1839             :       !only look up cell_to_index if all 3 components are within its bounds, otherwise return 0
    1840      427959 :       IF (ANY([opp_cell(1), opp_cell(2), opp_cell(3)] < LBOUND(cell_to_index)) .OR. &
    1841             :           ANY([opp_cell(1), opp_cell(2), opp_cell(3)] > UBOUND(cell_to_index))) THEN
    1842             : 
    1843             :          opp_index = 0
    1844             :       ELSE
    1845       61137 :          opp_index = cell_to_index(opp_cell(1), opp_cell(2), opp_cell(3))
    1846             :       END IF
    1847             : 
    1848       61137 :    END FUNCTION get_opp_index
    1849             : 
    1850             : ! **************************************************************************************************
    1851             : !> \brief A routine that returns the actual non-symmetric density matrix for each image, by Fourier
    1852             : !>        transforming the kpoint density matrix
    1853             : !> \param rho_ao_t (spin, image) tensors: scaled by scale_prev_p, then the new density is added to them
    1854             : !> \param rho_ao (spin, image) density matrices, read block-wise with row index <= column index
    1855             : !> \param scale_prev_p factor applied to the content of rho_ao_t before the new density is added
    1856             : !> \param ri_data provides the number of images (nimg) and the filtering threshold (filter_eps)
    1857             : !> \param qs_env environment providing kpoints, KS matrices (used as templates) and dft_control
    1858             : ! **************************************************************************************************
    1859         422 :    SUBROUTINE get_pmat_images(rho_ao_t, rho_ao, scale_prev_p, ri_data, qs_env)
    1860             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t
    1861             :       TYPE(dbcsr_p_type), DIMENSION(:, :), INTENT(INOUT) :: rho_ao
    1862             :       REAL(dp), INTENT(IN)                               :: scale_prev_p
    1863             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1864             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1865             : 
    1866             :       INTEGER                                            :: cell_j(3), i_img, i_spin, iatom, icol, &
    1867             :                                                             irow, j_img, jatom, mi_img, mj_img, &
    1868             :                                                             nimg, nspins
    1869         422 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1870             :       LOGICAL                                            :: found
    1871             :       REAL(dp)                                           :: fac
    1872         422 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock, pblock_desymm
    1873         422 :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_ks, rho_desymm
    1874        3798 :       TYPE(dbt_type)                                     :: tmp
    1875             :       TYPE(dft_control_type), POINTER                    :: dft_control
    1876             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1877             :       TYPE(neighbor_list_iterator_p_type), &
    1878         422 :          DIMENSION(:), POINTER                           :: nl_iterator
    1879             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    1880         422 :          POINTER                                         :: sab_nl, sab_nl_nosym
    1881             :       TYPE(qs_scf_env_type), POINTER                     :: scf_env
    1882             : 
    1883         422 :       NULLIFY (rho_desymm, kpoints, sab_nl_nosym, scf_env, matrix_ks, dft_control, &
    1884         422 :                sab_nl, nl_iterator, cell_to_index, pblock, pblock_desymm)
    1885             : 
    1886         422 :       CALL get_qs_env(qs_env, kpoints=kpoints, scf_env=scf_env, matrix_ks_kp=matrix_ks, dft_control=dft_control)
    1887         422 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=sab_nl_nosym, cell_to_index=cell_to_index, sab_nl=sab_nl)
    1888             :       !with ADMM, the auxiliary fit KS matrices are used instead (matrix_ks only serves as template)
    1889         422 :       IF (dft_control%do_admm) THEN
    1890         204 :          CALL get_admm_env(qs_env%admm_env, matrix_ks_aux_fit_kp=matrix_ks)
    1891             :       END IF
    1892             : 
    1893         422 :       nspins = SIZE(matrix_ks, 1)
    1894         422 :       nimg = ri_data%nimg
    1895             :       !work matrices without symmetry, one per spin and image, with blocks from the non-symmetric nl
    1896       26138 :       ALLOCATE (rho_desymm(nspins, nimg))
    1897       11040 :       DO i_img = 1, nimg
    1898       24872 :          DO i_spin = 1, nspins
    1899       13832 :             ALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1900             :             CALL dbcsr_create(rho_desymm(i_spin, i_img)%matrix, template=matrix_ks(i_spin, i_img)%matrix, &
    1901       13832 :                               matrix_type=dbcsr_type_no_symmetry)
    1902       24450 :             CALL cp_dbcsr_alloc_block_from_nbl(rho_desymm(i_spin, i_img)%matrix, sab_nl_nosym)
    1903             :          END DO
    1904             :       END DO
    1905         422 :       CALL dbt_create(rho_desymm(1, 1)%matrix, tmp)
    1906             : 
    1907             :       !We transform the symmetric typed (but not actually symmetric: P_ab^i = P_ba^-i) real-spaced density
    1908             :       !matrix into proper non-symmetric ones (using the same nl for consistency)
    1909         422 :       CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
    1910       18275 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    1911       17853 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    1912       17853 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    1913       17853 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    1914             :          !diagonal atom pairs get half weight, since P^j and the transpose of P^-j are summed below
    1915       12737 :          fac = 1.0_dp
    1916       12737 :          IF (iatom == jatom) fac = 0.5_dp
    1917       12737 :          mj_img = get_opp_index(j_img, qs_env)
    1918             :          !if no opposite image, then no sum of P^j + P^-j => need full diag
    1919       12737 :          IF (mj_img == 0) fac = 1.0_dp
    1920             : 
    1921       12737 :          irow = iatom
    1922       12737 :          icol = jatom
    1923       12737 :          IF (iatom > jatom) THEN
    1924             :             !because symmetric nl. Value for atom pair i,j is actually stored in j,i if i > j
    1925        3997 :             irow = jatom
    1926        3997 :             icol = iatom
    1927             :          END IF
    1928             : 
    1929       30064 :          DO i_spin = 1, nspins
    1930       16905 :             CALL dbcsr_get_block_p(rho_ao(i_spin, j_img)%matrix, irow, icol, pblock, found)
    1931       16905 :             IF (.NOT. found) CYCLE
    1932             : 
    1933             :             !distribution of symm and non-symm matrix match in that way
    1934       16905 :             CALL dbcsr_get_block_p(rho_desymm(i_spin, j_img)%matrix, iatom, jatom, pblock_desymm, found)
    1935       16905 :             IF (.NOT. found) CYCLE
    1936             :             !the block was read from the swapped position if iatom > jatom => transpose it back
    1937       68568 :             IF (iatom > jatom) THEN
    1938      656678 :                pblock_desymm(:, :) = fac*TRANSPOSE(pblock(:, :))
    1939             :             ELSE
    1940     1568460 :                pblock_desymm(:, :) = fac*pblock(:, :)
    1941             :             END IF
    1942             :          END DO
    1943             :       END DO
    1944         422 :       CALL neighbor_list_iterator_release(nl_iterator)
    1945             :       !accumulate into the output tensors: scale_prev_p*(previous content) + new density, then filter
    1946       11040 :       DO i_img = 1, nimg
    1947       24872 :          DO i_spin = 1, nspins
    1948       13832 :             CALL dbt_scale(rho_ao_t(i_spin, i_img), scale_prev_p)
    1949             : 
    1950       13832 :             CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, i_img)%matrix, tmp)
    1951       13832 :             CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), summation=.TRUE., move_data=.TRUE.)
    1952             : 
    1953             :             !symmetrize by adding transpose of opp img
    1954       13832 :             mi_img = get_opp_index(i_img, qs_env)
    1955       13832 :             IF (mi_img > 0 .AND. mi_img .LE. nimg) THEN
    1956       12464 :                CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, mi_img)%matrix, tmp)
    1957       12464 :                CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), order=[2, 1], summation=.TRUE., move_data=.TRUE.)
    1958             :             END IF
    1959       24450 :             CALL dbt_filter(rho_ao_t(i_spin, i_img), ri_data%filter_eps)
    1960             :          END DO
    1961             :       END DO
    1962             :       !clean-up of the work matrices and tensor
    1963       11040 :       DO i_img = 1, nimg
    1964       24872 :          DO i_spin = 1, nspins
    1965       13832 :             CALL dbcsr_release(rho_desymm(i_spin, i_img)%matrix)
    1966       24450 :             DEALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1967             :          END DO
    1968             :       END DO
    1969             : 
    1970         422 :       CALL dbt_destroy(tmp)
    1971         422 :       DEALLOCATE (rho_desymm)
    1972             : 
    1973         844 :    END SUBROUTINE get_pmat_images
    1974             : 
    1975             : ! **************************************************************************************************
    1976             : !> \brief A routine that, given a cell index b and atom indices ij, returns a 2c tensor with the HFX
    1977             : !>        potential (P_i^0|Q_j^b), within the extended RI basis
    1978             : !> \param t_2c_pot ...
    1979             : !> \param mat_orig ...
    1980             : !> \param atom_i ...
    1981             : !> \param atom_j ...
    1982             : !> \param img_b ...
    1983             : !> \param ri_data ...
    1984             : !> \param qs_env ...
    1985             : !> \param do_inverse ...
    1986             : !> \param para_env_ext ...
    1987             : !> \param blacs_env_ext ...
    1988             : !> \param dbcsr_template ...
    1989             : !> \param off_diagonal ...
    1990             : !> \param skip_inverse ...
    1991             : ! **************************************************************************************************
    1992        7318 :    SUBROUTINE get_ext_2c_int(t_2c_pot, mat_orig, atom_i, atom_j, img_b, ri_data, qs_env, do_inverse, &
    1993             :                              para_env_ext, blacs_env_ext, dbcsr_template, off_diagonal, skip_inverse)
    1994             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_pot
    1995             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_orig
    1996             :       INTEGER, INTENT(IN)                                :: atom_i, atom_j, img_b
    1997             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1998             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1999             :       LOGICAL, INTENT(IN), OPTIONAL                      :: do_inverse
    2000             :       TYPE(mp_para_env_type), OPTIONAL, POINTER          :: para_env_ext
    2001             :       TYPE(cp_blacs_env_type), OPTIONAL, POINTER         :: blacs_env_ext
    2002             :       TYPE(dbcsr_type), OPTIONAL, POINTER                :: dbcsr_template
    2003             :       LOGICAL, INTENT(IN), OPTIONAL                      :: off_diagonal, skip_inverse
    2004             : 
    2005             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_ext_2c_int'
    2006             : 
    2007             :       INTEGER :: blk, group, handle, handle2, i_img, i_RI, iatom, iblk, ikind, img_tot, j_img, &
    2008             :          j_RI, jatom, jblk, jkind, n_dependent, natom, nblks_RI, nimg, nkind
    2009        7318 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist1, dist2
    2010        7318 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: present_atoms_i, present_atoms_j
    2011             :       INTEGER, DIMENSION(3)                              :: cell_b, cell_i, cell_j, cell_tot
    2012        7318 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, col_dist_ext, ri_blk_size_ext, &
    2013        7318 :                                                             row_dist, row_dist_ext
    2014        7318 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell, pgrid
    2015        7318 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    2016             :       LOGICAL                                            :: do_inverse_prv, found, my_offd, &
    2017             :                                                             skip_inverse_prv, use_template
    2018             :       REAL(dp)                                           :: bfac, dij, r0, r1, threshold
    2019             :       REAL(dp), DIMENSION(3)                             :: ri, rij, rj, rref, scoord
    2020        7318 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock
    2021             :       TYPE(cell_type), POINTER                           :: cell
    2022             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
    2023             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist, dbcsr_dist_ext
    2024             :       TYPE(dbcsr_iterator_type)                          :: dbcsr_iter
    2025             :       TYPE(dbcsr_type)                                   :: work, work_tight, work_tight_inv
    2026       51226 :       TYPE(dbt_type)                                     :: t_2c_tmp
    2027             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    2028             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    2029        7318 :          DIMENSION(:), TARGET                            :: basis_set_RI
    2030             :       TYPE(kpoint_type), POINTER                         :: kpoints
    2031             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2032             :       TYPE(neighbor_list_iterator_p_type), &
    2033        7318 :          DIMENSION(:), POINTER                           :: nl_iterator
    2034             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    2035        7318 :          POINTER                                         :: nl_2c
    2036        7318 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    2037        7318 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    2038             : 
    2039        7318 :       NULLIFY (qs_kind_set, nl_2c, nl_iterator, cell, kpoints, cell_to_index, index_to_cell, dist_2d, &
    2040        7318 :                para_env, pblock, blacs_env, particle_set, col_dist, row_dist, pgrid, &
    2041        7318 :                col_dist_ext, row_dist_ext)
    2042             : 
    2043        7318 :       CALL timeset(routineN, handle)
    2044             : 
    2045             :       !Idea: run over the neighbor list once for i and once for j, and record in which cell the MIC
    2046             :       !      atoms are. Then loop over the atoms and only take the pairs that we need
    2047             : 
    2048             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    2049        7318 :                       kpoints=kpoints, para_env=para_env, blacs_env=blacs_env, particle_set=particle_set)
    2050        7318 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    2051             : 
    2052        7318 :       do_inverse_prv = .FALSE.
    2053        7318 :       IF (PRESENT(do_inverse)) do_inverse_prv = do_inverse
    2054         280 :       IF (do_inverse_prv) THEN
    2055         280 :          CPASSERT(atom_i == atom_j)
    2056             :       END IF
    2057             : 
    2058        7318 :       skip_inverse_prv = .FALSE.
    2059        7318 :       IF (PRESENT(skip_inverse)) skip_inverse_prv = skip_inverse
    2060             : 
    2061        7318 :       my_offd = .FALSE.
    2062        7318 :       IF (PRESENT(off_diagonal)) my_offd = off_diagonal
    2063             : 
    2064        7318 :       IF (PRESENT(para_env_ext)) para_env => para_env_ext
    2065        7318 :       IF (PRESENT(blacs_env_ext)) blacs_env => blacs_env_ext
    2066             : 
    2067        7318 :       nimg = SIZE(mat_orig)
    2068             : 
    2069        7318 :       CALL timeset(routineN//"_nl_iter", handle2)
    2070             : 
    2071             :       !create our own dist_2d in the subgroup
    2072       29272 :       ALLOCATE (dist1(natom), dist2(natom))
    2073       21954 :       DO iatom = 1, natom
    2074       14636 :          dist1(iatom) = MOD(iatom, blacs_env%num_pe(1))
    2075       21954 :          dist2(iatom) = MOD(iatom, blacs_env%num_pe(2))
    2076             :       END DO
    2077        7318 :       CALL distribution_2d_create(dist_2d, dist1, dist2, nkind, particle_set, blacs_env_ext=blacs_env)
    2078             : 
    2079       33055 :       ALLOCATE (basis_set_RI(nkind))
    2080        7318 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    2081             : 
    2082             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    2083        7318 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    2084             : 
    2085       43908 :       ALLOCATE (present_atoms_i(natom, nimg), present_atoms_j(natom, nimg))
    2086      746851 :       present_atoms_i = 0
    2087      746851 :       present_atoms_j = 0
    2088             : 
    2089        7318 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    2090      291602 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    2091             :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, r=rij, cell=cell_j, &
    2092      284284 :                                 ikind=ikind, jkind=jkind)
    2093             : 
    2094     1137136 :          dij = NORM2(rij)
    2095             : 
    2096      284284 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    2097      284284 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    2098             : 
    2099      282070 :          IF (iatom == atom_i .AND. dij .LE. ri_data%kp_RI_range) present_atoms_i(jatom, j_img) = 1
    2100      289388 :          IF (iatom == atom_j .AND. dij .LE. ri_data%kp_RI_range) present_atoms_j(jatom, j_img) = 1
    2101             :       END DO
    2102        7318 :       CALL neighbor_list_iterator_release(nl_iterator)
    2103        7318 :       CALL release_neighbor_list_sets(nl_2c)
    2104        7318 :       CALL distribution_2d_release(dist_2d)
    2105        7318 :       CALL timestop(handle2)
    2106             : 
    2107        7318 :       CALL para_env%sum(present_atoms_i)
    2108        7318 :       CALL para_env%sum(present_atoms_j)
    2109             : 
    2110             :       !Need to build a work matrix with matching distribution to mat_orig
    2111             :       !If template is provided, use it. If not, we create it.
    2112        7318 :       use_template = .FALSE.
    2113        7318 :       IF (PRESENT(dbcsr_template)) THEN
    2114        6646 :          IF (ASSOCIATED(dbcsr_template)) use_template = .TRUE.
    2115             :       END IF
    2116             : 
    2117             :       IF (use_template) THEN
    2118        6414 :          CALL dbcsr_create(work, template=dbcsr_template)
    2119             :       ELSE
    2120         904 :          CALL dbcsr_get_info(mat_orig(1), distribution=dbcsr_dist)
    2121         904 :          CALL dbcsr_distribution_get(dbcsr_dist, row_dist=row_dist, col_dist=col_dist, group=group, pgrid=pgrid)
    2122        3616 :          ALLOCATE (row_dist_ext(ri_data%ncell_RI*natom), col_dist_ext(ri_data%ncell_RI*natom))
    2123        1808 :          ALLOCATE (ri_blk_size_ext(ri_data%ncell_RI*natom))
    2124        6428 :          DO i_RI = 1, ri_data%ncell_RI
    2125       27620 :             row_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = row_dist(:)
    2126       27620 :             col_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = col_dist(:)
    2127       17476 :             RI_blk_size_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2128             :          END DO
    2129             : 
    2130             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2131         904 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2132             :          CALL dbcsr_create(work, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2133         904 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2134         904 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2135         904 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2136             : 
    2137        2712 :          IF (PRESENT(dbcsr_template)) THEN
    2138         232 :             ALLOCATE (dbcsr_template)
    2139         232 :             CALL dbcsr_create(dbcsr_template, template=work)
    2140             :          END IF
    2141             :       END IF !use_template
    2142             : 
    2143       29272 :       cell_b(:) = index_to_cell(:, img_b)
    2144      253829 :       DO i_img = 1, nimg
    2145      246511 :          i_RI = ri_data%img_to_RI_cell(i_img)
    2146      246511 :          IF (i_RI == 0) CYCLE
    2147      195364 :          cell_i(:) = index_to_cell(:, i_img)
    2148     2062970 :          DO j_img = 1, nimg
    2149     2006811 :             j_RI = ri_data%img_to_RI_cell(j_img)
    2150     2006811 :             IF (j_RI == 0) CYCLE
    2151     1693732 :             cell_j(:) = index_to_cell(:, j_img)
    2152     1693732 :             cell_tot = cell_j - cell_i + cell_b
    2153             : 
    2154     2924091 :             IF (ANY([cell_tot(1), cell_tot(2), cell_tot(3)] < LBOUND(cell_to_index)) .OR. &
    2155             :                 ANY([cell_tot(1), cell_tot(2), cell_tot(3)] > UBOUND(cell_to_index))) CYCLE
    2156      387929 :             img_tot = cell_to_index(cell_tot(1), cell_tot(2), cell_tot(3))
    2157      387929 :             IF (img_tot > nimg .OR. img_tot < 1) CYCLE
    2158             : 
    2159      263923 :             CALL dbcsr_iterator_start(dbcsr_iter, mat_orig(img_tot))
    2160      754095 :             DO WHILE (dbcsr_iterator_blocks_left(dbcsr_iter))
    2161      490172 :                CALL dbcsr_iterator_next_block(dbcsr_iter, row=iatom, column=jatom, blk=blk)
    2162      490172 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2163      184998 :                IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2164       80543 :                IF (my_offd .AND. (i_RI - 1)*natom + iatom == (j_RI - 1)*natom + jatom) CYCLE
    2165             : 
    2166       80254 :                CALL dbcsr_get_block_p(mat_orig(img_tot), iatom, jatom, pblock, found)
    2167       80254 :                IF (.NOT. found) CYCLE
    2168             : 
    2169      754095 :                CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2170             : 
    2171             :             END DO
    2172     2481741 :             CALL dbcsr_iterator_stop(dbcsr_iter)
    2173             : 
    2174             :          END DO !j_img
    2175             :       END DO !i_img
    2176        7318 :       CALL dbcsr_finalize(work)
    2177             : 
    2178        7318 :       IF (do_inverse_prv) THEN
    2179             : 
    2180         280 :          r1 = ri_data%kp_RI_range
    2181         280 :          r0 = ri_data%kp_bump_rad
    2182             : 
    2183             :          !Because there are a lot of empty rows/cols in work, we need to get rid of them for inversion
    2184       20872 :          nblks_RI = SUM(present_atoms_i)
    2185        1400 :          ALLOCATE (col_dist_ext(nblks_RI), row_dist_ext(nblks_RI), RI_blk_size_ext(nblks_RI))
    2186         280 :          iblk = 0
    2187        7144 :          DO i_img = 1, nimg
    2188        6864 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2189        6864 :             IF (i_RI == 0) CYCLE
    2190        5512 :             DO iatom = 1, natom
    2191        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2192        1156 :                iblk = iblk + 1
    2193        1156 :                col_dist_ext(iblk) = col_dist(iatom)
    2194        1156 :                row_dist_ext(iblk) = row_dist(iatom)
    2195       10352 :                RI_blk_size_ext(iblk) = ri_data%bsizes_RI(iatom)
    2196             :             END DO
    2197             :          END DO
    2198             : 
    2199             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2200         280 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2201             :          CALL dbcsr_create(work_tight, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2202         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2203             :          CALL dbcsr_create(work_tight_inv, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2204         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2205         280 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2206         280 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2207             : 
    2208             :          !We apply a bump function to the RI metric inverse for smooth RI basis extension:
    2209             :          ! S^-1 = B * ((P|Q)_D + B*(P|Q)_OD*B)^-1 * B, with D block-diagonal blocks and OD off-diagonal
    2210         280 :          rref = pbc(particle_set(atom_i)%r, cell)
    2211             : 
    2212         280 :          iblk = 0
    2213        7144 :          DO i_img = 1, nimg
    2214        6864 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2215        6864 :             IF (i_RI == 0) CYCLE
    2216        5512 :             DO iatom = 1, natom
    2217        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2218        1156 :                iblk = iblk + 1
    2219             : 
    2220        1156 :                CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    2221        4624 :                CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    2222             : 
    2223        1156 :                jblk = 0
    2224       42764 :                DO j_img = 1, nimg
    2225       34744 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2226       34744 :                   IF (j_RI == 0) CYCLE
    2227       29504 :                   DO jatom = 1, natom
    2228       17344 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2229        5580 :                      jblk = jblk + 1
    2230             : 
    2231        5580 :                      CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    2232       22320 :                      CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    2233             : 
    2234        5580 :                      CALL dbcsr_get_block_p(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock, found)
    2235        5580 :                      IF (.NOT. found) CYCLE
    2236             : 
    2237        2508 :                      bfac = 1.0_dp
    2238       14088 :                      IF (iblk .NE. jblk) bfac = bump(NORM2(ri - rref), r0, r1)*bump(NORM2(rj - rref), r0, r1)
    2239     5075024 :                      CALL dbcsr_put_block(work_tight, iblk, jblk, bfac*pblock(:, :))
    2240             :                   END DO
    2241             :                END DO
    2242             :             END DO
    2243             :          END DO
    2244         280 :          CALL dbcsr_finalize(work_tight)
    2245         280 :          CALL dbcsr_clear(work)
    2246             : 
    2247         280 :          IF (.NOT. skip_inverse_prv) THEN
    2248         140 :             SELECT CASE (ri_data%t2c_method)
    2249             :             CASE (hfx_ri_do_2c_iter)
    2250           0 :                threshold = MAX(ri_data%filter_eps, 1.0e-12_dp)
    2251           0 :                CALL invert_hotelling(work_tight_inv, work_tight, threshold=threshold, silent=.FALSE.)
    2252             :             CASE (hfx_ri_do_2c_cholesky)
    2253         140 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2254         140 :                CALL cp_dbcsr_cholesky_decompose(work_tight_inv, para_env=para_env, blacs_env=blacs_env)
    2255             :                CALL cp_dbcsr_cholesky_invert(work_tight_inv, para_env=para_env, blacs_env=blacs_env, &
    2256         140 :                                              upper_to_full=.TRUE.)
    2257             :             CASE (hfx_ri_do_2c_diag)
    2258           0 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2259             :                CALL cp_dbcsr_power(work_tight_inv, -1.0_dp, ri_data%eps_eigval, n_dependent, &
    2260         140 :                                    para_env, blacs_env, verbose=ri_data%unit_nr_dbcsr > 0)
    2261             :             END SELECT
    2262             :          ELSE
    2263         140 :             CALL dbcsr_copy(work_tight_inv, work_tight)
    2264             :          END IF
    2265             : 
    2266             :          !move back data to standard extended RI pattern
    2267             :          !Note: we apply the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 later, because this matrix
    2268             :          !      is required for forces
    2269         280 :          iblk = 0
    2270        7144 :          DO i_img = 1, nimg
    2271        6864 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2272        6864 :             IF (i_RI == 0) CYCLE
    2273        5512 :             DO iatom = 1, natom
    2274        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2275        1156 :                iblk = iblk + 1
    2276             : 
    2277        1156 :                jblk = 0
    2278       42764 :                DO j_img = 1, nimg
    2279       34744 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2280       34744 :                   IF (j_RI == 0) CYCLE
    2281       29504 :                   DO jatom = 1, natom
    2282       17344 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2283        5580 :                      jblk = jblk + 1
    2284             : 
    2285        5580 :                      CALL dbcsr_get_block_p(work_tight_inv, iblk, jblk, pblock, found)
    2286        5580 :                      IF (.NOT. found) CYCLE
    2287             : 
    2288       54737 :                      CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2289             :                   END DO
    2290             :                END DO
    2291             :             END DO
    2292             :          END DO
    2293         280 :          CALL dbcsr_finalize(work)
    2294             : 
    2295         280 :          CALL dbcsr_release(work_tight)
    2296         560 :          CALL dbcsr_release(work_tight_inv)
    2297             :       END IF
    2298             : 
    2299        7318 :       CALL dbt_create(work, t_2c_tmp)
    2300        7318 :       CALL dbt_copy_matrix_to_tensor(work, t_2c_tmp)
    2301        7318 :       CALL dbt_copy(t_2c_tmp, t_2c_pot, move_data=.TRUE.)
    2302        7318 :       CALL dbt_filter(t_2c_pot, ri_data%filter_eps)
    2303             : 
    2304        7318 :       CALL dbt_destroy(t_2c_tmp)
    2305        7318 :       CALL dbcsr_release(work)
    2306             : 
    2307        7318 :       CALL timestop(handle)
    2308             : 
    2309       29272 :    END SUBROUTINE get_ext_2c_int
    2310             : 
    2311             : ! **************************************************************************************************
    2312             : !> \brief Pre-contract the density matrices with the 3-center integrals:
    2313             : !>        P_sigma^a,lambda^a+c (mu^0 sigma^a| P^0)
    2314             : !> \param t_3c_apc (spin, image) tensors receiving the result; contributions are summed into them
    2315             : !> \param rho_ao_t (spin, image) density matrix tensors, stacked batch-wise before the contraction
    2316             : !> \param ri_data reads nimg, nimg_nze, n_mem, t_3c_int_ctr_1, filter_eps; updates dbcsr_nflop/dbcsr_time
    2317             : !> \param qs_env used to fetch dft_control (nspins); forwarded to get_stack_tensors and get_ac_pairs
    2318             : ! **************************************************************************************************
    2319         232 :    SUBROUTINE contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
    2320             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, rho_ao_t
    2321             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2322             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2323             : 
    2324             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'contract_pmat_3c'
    2325             : 
    2326             :       INTEGER                                            :: apc_img, batch_size, handle, i_batch, &
    2327             :                                                             i_img, i_spin, j_batch, n_batch_img, &
    2328             :                                                             n_batch_nze, nimg, nimg_nze, nspins
    2329             :       INTEGER(int_8)                                     :: nflop, nze
    2330         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_img, batch_ranges_nze, &
    2331         232 :                                                             int_indices
    2332         232 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: ac_pairs
    2333             :       REAL(dp)                                           :: occ, t1, t2
    2334        2088 :       TYPE(dbt_type)                                     :: t_3c_tmp
    2335         232 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ints_stack, res_stack, rho_stack
    2336             :       TYPE(dft_control_type), POINTER                    :: dft_control
    2337             : 
    2338         232 :       CALL timeset(routineN, handle)
    2339             : 
    2340         232 :       CALL get_qs_env(qs_env, dft_control=dft_control)
    2341             : 
    2342         232 :       nimg = ri_data%nimg
    2343         232 :       nimg_nze = ri_data%nimg_nze
    2344         232 :       nspins = dft_control%nspins
    2345             : 
    2346         232 :       CALL dbt_create(t_3c_apc(1, 1), t_3c_tmp) !presumably t_3c_tmp takes the layout of t_3c_apc(1, 1)
    2347             : 
    2348         232 :       batch_size = nimg/ri_data%n_mem !NOTE(review): 0 if n_mem > nimg, then nimg/batch_size fails; confirm
    2349             : 
    2350             :       !batching over all images: batch i covers batch_ranges_img(i) to batch_ranges_img(i + 1) - 1
    2351         232 :       n_batch_img = nimg/batch_size
    2352         232 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch_img = n_batch_img + 1 !extra batch for the remainder
    2353         696 :       ALLOCATE (batch_ranges_img(n_batch_img + 1))
    2354         810 :       DO i_batch = 1, n_batch_img
    2355         810 :          batch_ranges_img(i_batch) = (i_batch - 1)*batch_size + 1
    2356             :       END DO
    2357         232 :       batch_ranges_img(n_batch_img + 1) = nimg + 1 !one-past-the-end bound of the last batch
    2358             : 
    2359             :       !batching over images with non-zero 3c integrals, same layout as batch_ranges_img
    2360         232 :       n_batch_nze = nimg_nze/batch_size
    2361         232 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1 !extra batch for the remainder
    2362         696 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
    2363         752 :       DO i_batch = 1, n_batch_nze
    2364         752 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
    2365             :       END DO
    2366         232 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1 !one-past-the-end bound of the last batch
    2367             : 
    2368             :       !Create the stack tensors in the appropriate distribution
    2369        7192 :       ALLOCATE (rho_stack(2), ints_stack(2), res_stack(2))
    2370             :       CALL get_stack_tensors(res_stack, rho_stack, ints_stack, rho_ao_t(1, 1), &
    2371         232 :                              ri_data%t_3c_int_ctr_1(1, 1), batch_size, ri_data, qs_env)
    2372             : 
    2373        1160 :       ALLOCATE (ac_pairs(nimg, 2), int_indices(nimg_nze))
    2374        4554 :       DO i_img = 1, nimg_nze
    2375        4554 :          int_indices(i_img) = i_img !identity mapping: integrals are stacked in their stored order
    2376             :       END DO
    2377             : 
    2378         232 :       t1 = m_walltime() !start of the interval added to ri_data%dbcsr_time below
    2379         752 :       DO j_batch = 1, n_batch_nze
    2380             :          !First batch is over the integrals. They are always in the same order, consistent with get_ac_pairs
    2381             :          CALL fill_3c_stack(ints_stack(1), ri_data%t_3c_int_ctr_1(1, :), int_indices, 3, ri_data, &
    2382        1560 :                             img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)])
    2383         520 :          CALL dbt_copy(ints_stack(1), ints_stack(2), move_data=.TRUE.) !ints_stack(2) is the contraction operand
    2384             : 
    2385        1492 :          DO i_spin = 1, nspins
    2386        3152 :             DO i_batch = 1, n_batch_img
    2387             :                !Second batch is over the P matrix. Here we fill the stacked rho tensors col by col
    2388       17120 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2389       15228 :                   CALL get_ac_pairs(ac_pairs, apc_img, ri_data, qs_env)
    2390             :                   CALL fill_2c_stack(rho_stack(1), rho_ao_t(i_spin, :), ac_pairs(:, 2), 1, ri_data, &
    2391             :                                      img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)], &
    2392       47576 :                                      shift=apc_img - batch_ranges_img(i_batch) + 1)
    2393             : 
    2394             :                END DO !apc_img
    2395        1892 :                CALL get_tensor_occupancy(rho_stack(1), nze, occ)
    2396        1892 :                IF (nze == 0) CYCLE !skip the contraction when the stacked rho tensor reports nze == 0
    2397        1650 :                CALL dbt_copy(rho_stack(1), rho_stack(2), move_data=.TRUE.)
    2398             : 
    2399             :                !The actual contraction: index 3 of ints_stack(2) with index 1 of rho_stack(2), into res_stack(2)
    2400        1650 :                CALL dbt_batched_contract_init(rho_stack(2))
    2401             :                CALL dbt_contract(1.0_dp, ints_stack(2), rho_stack(2), &
    2402             :                                  0.0_dp, res_stack(2), map_1=[1, 2], map_2=[3], &
    2403             :                                  contract_1=[3], notcontract_1=[1, 2], &
    2404             :                                  contract_2=[1], notcontract_2=[2], &
    2405        1650 :                                  filter_eps=ri_data%filter_eps, flop=nflop)
    2406        1650 :                ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop !accumulate the reported flop count
    2407        1650 :                CALL dbt_batched_contract_finalize(rho_stack(2))
    2408        1650 :                CALL dbt_copy(res_stack(2), res_stack(1), move_data=.TRUE.)
    2409             : 
    2410       18680 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2411             :                   !Destack the resulting tensor and put it in t_3c_apc with correct apc_img
    2412       14398 :                   CALL unstack_t_3c_apc(t_3c_tmp, res_stack(1), apc_img - batch_ranges_img(i_batch) + 1)
    2413       16290 :                   CALL dbt_copy(t_3c_tmp, t_3c_apc(i_spin, apc_img), summation=.TRUE., move_data=.TRUE.)
    2414             :                END DO
    2415             : 
    2416             :             END DO !i_batch
    2417             :          END DO !i_spin
    2418             :       END DO !j_batch
    2419         232 :       DEALLOCATE (batch_ranges_img)
    2420         232 :       DEALLOCATE (batch_ranges_nze)
    2421         232 :       t2 = m_walltime()
    2422         232 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1 !wall time spent in the batched loops above
    2423             : 
    2424         232 :       CALL dbt_destroy(rho_stack(1))
    2425         232 :       CALL dbt_destroy(rho_stack(2))
    2426         232 :       CALL dbt_destroy(ints_stack(1))
    2427         232 :       CALL dbt_destroy(ints_stack(2))
    2428         232 :       CALL dbt_destroy(res_stack(1))
    2429         232 :       CALL dbt_destroy(res_stack(2))
    2430         232 :       CALL dbt_destroy(t_3c_tmp)
    2431             : 
    2432         232 :       CALL timestop(handle)
    2433             : 
    2434        2320 :    END SUBROUTINE contract_pmat_3c
    2435             : 
    2436             : ! **************************************************************************************************
    2437             : !> \brief Pre-contract 3-center integrals with the bumped inverse RI metric, for each atom
    2438             : !> \param t_3c_int 3-center integral tensors; only t_3c_int(1, :) is used, and it is destroyed on exit
    2439             : !> \param ri_data reads t_2c_inv(1, iatom); results are summed into t_3c_int_ctr_1(1, i_img)
    2440             : !> \param qs_env used to fetch natom; forwarded to apply_bump
    2441             : ! **************************************************************************************************
    2442          70 :    SUBROUTINE precontract_3c_ints(t_3c_int, ri_data, qs_env)
    2443             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_int
    2444             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2445             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2446             : 
    2447             :       CHARACTER(len=*), PARAMETER :: routineN = 'precontract_3c_ints'
    2448             : 
    2449             :       INTEGER                                            :: batch_size, handle, i_batch, i_img, &
    2450             :                                                             i_RI, iatom, is, n_batch, natom, &
    2451             :                                                             nblks, nblks_3c(3), nimg
    2452             :       INTEGER(int_8)                                     :: nflop
    2453          70 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: batch_ranges, bsizes_RI_ext, bsizes_RI_ext_split, &
    2454          70 :          bsizes_stack, dist1, dist2, dist3, dist_stack3, idx_to_at_AO, int_indices
    2455         630 :       TYPE(dbt_distribution_type)                        :: t_dist
    2456       14630 :       TYPE(dbt_type)                                     :: t_2c_RI_tmp(2), t_3c_tmp(3)
    2457             : 
    2458          70 :       CALL timeset(routineN, handle)
    2459             : 
    2460          70 :       CALL get_qs_env(qs_env, natom=natom)
    2461             : 
    2462          70 :       nimg = ri_data%nimg
    2463         210 :       ALLOCATE (int_indices(nimg))
    2464        1786 :       DO i_img = 1, nimg
    2465        1786 :          int_indices(i_img) = i_img !identity mapping: images are stacked in their stored order
    2466             :       END DO
    2467             : 
    2468         210 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
    2469          70 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO) !presumably split block -> atom
    2470             : 
    2471          70 :       nblks = SIZE(ri_data%bsizes_RI_split) !number of split RI blocks
    2472         210 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    2473         210 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    2474         506 :       DO i_RI = 1, ri_data%ncell_RI !repeat the RI block sizes once per RI cell
    2475        1308 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2476        2344 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    2477             :       END DO
    2478             :       CALL create_2c_tensor(t_2c_RI_tmp(1), dist1, dist2, ri_data%pgrid_2d, &
    2479             :                             bsizes_RI_ext, bsizes_RI_ext, &
    2480             :                             name="(RI | RI)")
    2481          70 :       DEALLOCATE (dist1, dist2)
    2482             :       CALL create_2c_tensor(t_2c_RI_tmp(2), dist1, dist2, ri_data%pgrid_2d, &
    2483             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    2484             :                             name="(RI | RI)")
    2485          70 :       DEALLOCATE (dist1, dist2)
    2486             : 
    2487             :       !For more efficiency, we stack multiple images of the 3-center integrals into a single tensor
    2488          70 :       batch_size = nimg/ri_data%n_mem !NOTE(review): 0 if n_mem > nimg, then nimg/batch_size fails; confirm
    2489          70 :       n_batch = nimg/batch_size
    2490          70 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch = n_batch + 1 !extra batch for the remainder
    2491         210 :       ALLOCATE (batch_ranges(n_batch + 1))
    2492         246 :       DO i_batch = 1, n_batch
    2493         246 :          batch_ranges(i_batch) = (i_batch - 1)*batch_size + 1
    2494             :       END DO
    2495          70 :       batch_ranges(n_batch + 1) = nimg + 1 !one-past-the-end bound of the last batch
    2496             : 
    2497          70 :       nblks = SIZE(ri_data%bsizes_AO_split) !nblks now counts the split AO blocks
    2498         210 :       ALLOCATE (bsizes_stack(batch_size*nblks))
    2499         910 :       DO is = 1, batch_size !AO block sizes repeated batch_size times for the stacked dimension
    2500        3502 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    2501             :       END DO
    2502             : 
    2503          70 :       CALL dbt_get_info(t_3c_int(1, 1), nblks_total=nblks_3c)
    2504         630 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)), dist_stack3(batch_size*nblks_3c(3)))
    2505          70 :       CALL dbt_get_info(t_3c_int(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, proc_dist_3=dist3)
    2506         910 :       DO is = 1, batch_size !third-dimension distribution repeated for each stacked image
    2507        3502 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    2508             :       END DO
    2509             : 
    2510          70 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid, dist1, dist2, dist_stack3)
    2511             :       CALL dbt_create(t_3c_tmp(1), "ints_stack", t_dist, [1], [2, 3], bsizes_RI_ext_split, &
    2512          70 :                       ri_data%bsizes_AO_split, bsizes_stack)
    2513          70 :       CALL dbt_distribution_destroy(t_dist)
    2514          70 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    2515             : 
    2516          70 :       CALL dbt_create(t_3c_tmp(1), t_3c_tmp(2)) !t_3c_tmp(2) receives the contraction result below
    2517          70 :       CALL dbt_create(t_3c_int(1, 1), t_3c_tmp(3)) !t_3c_tmp(3) holds one unstacked image at a time
    2518             : 
    2519         210 :       DO iatom = 1, natom
    2520         140 :          CALL dbt_copy(ri_data%t_2c_inv(1, iatom), t_2c_RI_tmp(1))
    2521         140 :          CALL apply_bump(t_2c_RI_tmp(1), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    2522         140 :          CALL dbt_copy(t_2c_RI_tmp(1), t_2c_RI_tmp(2), move_data=.TRUE.) !t_2c_RI_tmp(2) is the contraction operand
    2523             : 
    2524         140 :          CALL dbt_batched_contract_init(t_2c_RI_tmp(2))
    2525         492 :          DO i_batch = 1, n_batch
    2526             : 
    2527             :             CALL fill_3c_stack(t_3c_tmp(1), t_3c_int(1, :), int_indices, 3, ri_data, &
    2528             :                                img_bounds=[batch_ranges(i_batch), batch_ranges(i_batch + 1)], &
    2529        1056 :                                filter_at=iatom, filter_dim=2, idx_to_at=idx_to_at_AO)
    2530             : 
    2531             :             CALL dbt_contract(1.0_dp, t_2c_RI_tmp(2), t_3c_tmp(1), &
    2532             :                               0.0_dp, t_3c_tmp(2), map_1=[1], map_2=[2, 3], &
    2533             :                               contract_1=[2], notcontract_1=[1], &
    2534             :                               contract_2=[1], notcontract_2=[2, 3], &
    2535         352 :                               filter_eps=ri_data%filter_eps, flop=nflop)
    2536         352 :             ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop !accumulate the reported flop count
    2537             : 
    2538        3784 :             DO i_img = batch_ranges(i_batch), batch_ranges(i_batch + 1) - 1
    2539        3432 :                CALL unstack_t_3c_apc(t_3c_tmp(3), t_3c_tmp(2), i_img - batch_ranges(i_batch) + 1)
    2540             :                CALL dbt_copy(t_3c_tmp(3), ri_data%t_3c_int_ctr_1(1, i_img), summation=.TRUE., &
    2541        3784 :                              order=[2, 1, 3], move_data=.TRUE.)
    2542             :             END DO
    2543         492 :             CALL dbt_clear(t_3c_tmp(1)) !empty the stack before the next batch is filled
    2544             :          END DO
    2545         210 :          CALL dbt_batched_contract_finalize(t_2c_RI_tmp(2))
    2546             : 
    2547             :       END DO
    2548          70 :       CALL dbt_destroy(t_2c_RI_tmp(1))
    2549          70 :       CALL dbt_destroy(t_2c_RI_tmp(2))
    2550          70 :       CALL dbt_destroy(t_3c_tmp(1))
    2551          70 :       CALL dbt_destroy(t_3c_tmp(2))
    2552          70 :       CALL dbt_destroy(t_3c_tmp(3))
    2553             : 
    2554        1786 :       DO i_img = 1, nimg !the input integrals are destroyed once they have been pre-contracted
    2555        1786 :          CALL dbt_destroy(t_3c_int(1, i_img))
    2556             :       END DO
    2557             : 
    2558          70 :       CALL timestop(handle)
    2559             : 
    2560         350 :    END SUBROUTINE precontract_3c_ints
    2561             : 
    2562             : ! **************************************************************************************************
    2563             : !> \brief Copy the data of a 2D tensor living in the main MPI group to a sub-group, given the proc
    2564             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2565             : !> \param t2c_sub ...
    2566             : !> \param t2c_main ...
    2567             : !> \param group_size ...
    2568             : !> \param ngroups ...
    2569             : !> \param para_env ...
    2570             : ! **************************************************************************************************
    2571        8000 :    SUBROUTINE copy_2c_to_subgroup(t2c_sub, t2c_main, group_size, ngroups, para_env)
    2572             :       TYPE(dbt_type), INTENT(INOUT)                      :: t2c_sub, t2c_main
    2573             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2574             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2575             : 
    2576             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iblk, &
    2577             :                                                             igroup, iproc, ir, is, jblk, n_batch, &
    2578             :                                                             nocc, tag
    2579        8000 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2
    2580        8000 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: block_dest, block_source
    2581        8000 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: current_dest
    2582             :       INTEGER, DIMENSION(2)                              :: ind, nblks
    2583             :       LOGICAL                                            :: found
    2584        8000 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2585        8000 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2586             :       TYPE(dbt_iterator_type)                            :: iter
    2587        8000 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2588             : 
    2589             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2590             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2591             :       !         to have blocks pre-reserved though
    2592             : 
    2593        8000 :       CALL dbt_get_info(t2c_main, nblks_total=nblks)
    2594             : 
    2595             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2596       32000 :       ALLOCATE (block_source(nblks(1), nblks(2)))
    2597      166468 :       block_source = -1
    2598        8000 :       nocc = 0
    2599        8000 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t2c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2600             :       CALL dbt_iterator_start(iter, t2c_main)
    2601             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2602             :          CALL dbt_iterator_next_block(iter, ind)
    2603             :          CALL dbt_get_block(t2c_main, ind, blk, found)
    2604             :          IF (.NOT. found) CYCLE
    2605             : 
    2606             :          block_source(ind(1), ind(2)) = para_env%mepos
    2607             : !$OMP ATOMIC
    2608             :          nocc = nocc + 1
    2609             :          DEALLOCATE (blk)
    2610             :       END DO
    2611             :       CALL dbt_iterator_stop(iter)
    2612             : !$OMP END PARALLEL
    2613             : 
    2614        8000 :       CALL para_env%sum(nocc)
    2615        8000 :       CALL para_env%sum(block_source)
    2616      166468 :       block_source = block_source + para_env%num_pe - 1
    2617        8000 :       IF (nocc == 0) RETURN
    2618             : 
    2619             :       !Loop over the sub tensor, get the block destination
    2620        7820 :       igroup = para_env%mepos/group_size
    2621       23460 :       ALLOCATE (block_dest(nblks(1), nblks(2)))
    2622      165208 :       block_dest = -1
    2623       29840 :       DO jblk = 1, nblks(2)
    2624      165208 :          DO iblk = 1, nblks(1)
    2625      135368 :             IF (block_source(iblk, jblk) == -1) CYCLE
    2626             : 
    2627       98868 :             CALL dbt_get_stored_coordinates(t2c_sub, [iblk, jblk], iproc)
    2628      157388 :             block_dest(iblk, jblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2629             :          END DO
    2630             :       END DO
    2631             : 
    2632       39100 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)))
    2633        7820 :       CALL dbt_get_info(t2c_main, blk_size_1=bsizes1, blk_size_2=bsizes2)
    2634             : 
    2635       39100 :       ALLOCATE (current_dest(nblks(1), nblks(2), 0:ngroups - 1))
    2636       23460 :       DO igroup = 0, ngroups - 1
    2637             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2638      330416 :          current_dest(:, :, igroup) = block_dest(:, :)
    2639       23460 :          CALL para_env%bcast(current_dest(:, :, igroup), source=igroup*group_size) !bcast from first proc in sub-group
    2640             :       END DO
    2641             : 
    2642             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2643        7820 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2644        7820 :       n_batch = (nocc*ngroups)/batch_size
    2645        7820 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2646             : 
    2647       15640 :       DO i_batch = 1, n_batch
    2648             :          !Loop over groups, blocks and send/receive
    2649      163104 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2650      163104 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2651             :          ir = 0
    2652             :          is = 0
    2653             :          i_msg = 0
    2654       29840 :          DO jblk = 1, nblks(2)
    2655      165208 :             DO iblk = 1, nblks(1)
    2656      428124 :                DO igroup = 0, ngroups - 1
    2657      270736 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2658             : 
    2659       65912 :                   i_msg = i_msg + 1
    2660       65912 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2661             : 
    2662             :                   !a unique tag per block, within this batch
    2663       65912 :                   tag = i_msg - (i_batch - 1)*batch_size
    2664             : 
    2665       65912 :                   found = .FALSE.
    2666       65912 :                   IF (para_env%mepos == block_source(iblk, jblk)) THEN
    2667       98868 :                      CALL dbt_get_block(t2c_main, [iblk, jblk], blk, found)
    2668             :                   END IF
    2669             : 
    2670             :                   !If blocks live on same proc, simply copy. Else MPI send/recv
    2671       65912 :                   IF (block_source(iblk, jblk) == current_dest(iblk, jblk, igroup)) THEN
    2672       98868 :                      IF (found) CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2673             :                   ELSE
    2674       32956 :                      IF (para_env%mepos == block_source(iblk, jblk) .AND. found) THEN
    2675       65912 :                         ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2676    21062150 :                         send_buff(tag)%array(:, :) = blk(:, :)
    2677       16478 :                         is = is + 1
    2678             :                         CALL para_env%isend(msgin=send_buff(tag)%array, dest=current_dest(iblk, jblk, igroup), &
    2679       16478 :                                             request=send_req(is), tag=tag)
    2680             :                      END IF
    2681             : 
    2682       32956 :                      IF (para_env%mepos == current_dest(iblk, jblk, igroup)) THEN
    2683       65912 :                         ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2684       16478 :                         ir = ir + 1
    2685             :                         CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk), &
    2686       16478 :                                             request=recv_req(ir), tag=tag)
    2687             :                      END IF
    2688             :                   END IF
    2689             : 
    2690      201280 :                   IF (found) DEALLOCATE (blk)
    2691             :                END DO
    2692             :             END DO
    2693             :          END DO
    2694             : 
    2695        7820 :          CALL mp_waitall(send_req(1:is))
    2696        7820 :          CALL mp_waitall(recv_req(1:ir))
    2697             :          !clean-up
    2698       73732 :          DO i = 1, batch_size
    2699       73732 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2700             :          END DO
    2701             : 
    2702             :          !Finally copy the data from the buffer to the sub-tensor
    2703             :          i_msg = 0
    2704       29840 :          DO jblk = 1, nblks(2)
    2705      165208 :             DO iblk = 1, nblks(1)
    2706      428124 :                DO igroup = 0, ngroups - 1
    2707      270736 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2708             : 
    2709       65912 :                   i_msg = i_msg + 1
    2710       65912 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2711             : 
    2712             :                   !a unique tag per block, within this batch
    2713       65912 :                   tag = i_msg - (i_batch - 1)*batch_size
    2714             : 
    2715       65912 :                   IF (para_env%mepos == current_dest(iblk, jblk, igroup) .AND. &
    2716      135368 :                       block_source(iblk, jblk) .NE. current_dest(iblk, jblk, igroup)) THEN
    2717             : 
    2718       65912 :                      ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk)))
    2719    21062150 :                      blk(:, :) = recv_buff(tag)%array(:, :)
    2720       82390 :                      CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2721       16478 :                      DEALLOCATE (blk)
    2722             :                   END IF
    2723             :                END DO
    2724             :             END DO
    2725             :          END DO
    2726             : 
    2727             :          !clean-up
    2728       73732 :          DO i = 1, batch_size
    2729       73732 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2730             :          END DO
    2731       15640 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2732             :       END DO !i_batch
    2733        7820 :       CALL dbt_finalize(t2c_sub)
    2734             : 
    2735       16000 :    END SUBROUTINE copy_2c_to_subgroup
    2736             : 
    2737             : ! **************************************************************************************************
    2738             : !> \brief Copy the data of a 3D tensor living in the main MPI group to a sub-group, given the proc
    2739             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2740             : !> \param t3c_sub destination tensor of the subgroup; receives the blocks and is finalized at the end
    2741             : !> \param t3c_main source tensor of the main group; its blocks are read with dbt_get_block
    2742             : !> \param group_size rank stride of a subgroup: subgroup idx = mepos/group_size
    2743             : !> \param ngroups number of subgroups; every occupied block is considered for each of them
    2744             : !> \param para_env parallel environment of the main group (sum, bcast, isend and irecv calls)
    2745             : !> \param iatom_to_subgroup optional: entry iatom, array(igroup + 1) tells if subgroup igroup gets atom iatom
    2746             : !> \param dim_at optional: tensor dimension whose block index is mapped to an atom
    2747             : !> \param idx_to_at optional: atom index of each block along dim_at (size checked against nblks(dim_at))
    2748             : ! **************************************************************************************************
    2749       11690 :    SUBROUTINE copy_3c_to_subgroup(t3c_sub, t3c_main, group_size, ngroups, para_env, iatom_to_subgroup, &
    2750       11690 :                                   dim_at, idx_to_at)
    2751             :       TYPE(dbt_type), INTENT(INOUT)                      :: t3c_sub, t3c_main
    2752             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2753             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2754             :       TYPE(cp_1d_logical_p_type), DIMENSION(:), &
    2755             :          INTENT(INOUT), OPTIONAL                         :: iatom_to_subgroup
    2756             :       INTEGER, INTENT(IN), OPTIONAL                      :: dim_at
    2757             :       INTEGER, DIMENSION(:), OPTIONAL                    :: idx_to_at
    2758             : 
    2759             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iatom, &
    2760             :                                                             iblk, igroup, iproc, ir, is, jblk, &
    2761             :                                                             kblk, n_batch, nocc, tag
    2762       11690 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2, bsizes3
    2763       11690 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: block_dest, block_source
    2764       11690 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :)        :: current_dest
    2765             :       INTEGER, DIMENSION(3)                              :: ind, nblks
    2766             :       LOGICAL                                            :: filter_at, found
    2767       11690 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    2768       11690 :       TYPE(cp_3d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2769             :       TYPE(dbt_iterator_type)                            :: iter
    2770       11690 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2771             : 
    2772             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2773             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2774             :       !         to have blocks pre-reserved though
    2775             : 
    2776       11690 :       CALL dbt_get_info(t3c_main, nblks_total=nblks)
    2777             : 
    2778             :       !atom filter: only copy part of the 3c tensor to a given subgroup; needs all 3 optional arguments
    2779       11690 :       filter_at = .FALSE.
    2780       11690 :       IF (PRESENT(iatom_to_subgroup) .AND. PRESENT(dim_at) .AND. PRESENT(idx_to_at)) THEN
    2781        6614 :          filter_at = .TRUE.
    2782        6614 :          CPASSERT(nblks(dim_at) == SIZE(idx_to_at))
    2783             :       END IF
    2784             : 
    2785             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2786       58450 :       ALLOCATE (block_source(nblks(1), nblks(2), nblks(3)))
    2787      640562 :       block_source = -1
    2788       11690 :       nocc = 0
    2789       11690 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t3c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2790             :       CALL dbt_iterator_start(iter, t3c_main)
    2791             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2792             :          CALL dbt_iterator_next_block(iter, ind)
    2793             :          CALL dbt_get_block(t3c_main, ind, blk, found)
    2794             :          IF (.NOT. found) CYCLE
    2795             : 
    2796             :          block_source(ind(1), ind(2), ind(3)) = para_env%mepos
    2797             : !$OMP ATOMIC
    2798             :          nocc = nocc + 1
    2799             :          DEALLOCATE (blk)
    2800             :       END DO
    2801             :       CALL dbt_iterator_stop(iter)
    2802             : !$OMP END PARALLEL
    2803             :       !all ranks start from -1: after the sum, adding num_pe - 1 yields the owner rank (-1 if no owner)
    2804       11690 :       CALL para_env%sum(nocc)
    2805       11690 :       CALL para_env%sum(block_source)
    2806      640562 :       block_source = block_source + para_env%num_pe - 1
    2807       11690 :       IF (nocc == 0) RETURN
    2808             : 
    2809             :       !Loop over the sub tensor, get the block destination
    2810       11690 :       igroup = para_env%mepos/group_size
    2811       46760 :       ALLOCATE (block_dest(nblks(1), nblks(2), nblks(3)))
    2812      640562 :       block_dest = -1
    2813       35070 :       DO kblk = 1, nblks(3)
    2814      163314 :          DO jblk = 1, nblks(2)
    2815      628872 :             DO iblk = 1, nblks(1)
    2816      477248 :                IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2817             : 
    2818      491240 :                CALL dbt_get_stored_coordinates(t3c_sub, [iblk, jblk, kblk], iproc)
    2819      605492 :                block_dest(iblk, jblk, kblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2820             :             END DO
    2821             :          END DO
    2822             :       END DO
    2823             : 
    2824       81830 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)), bsizes3(nblks(3)))
    2825       11690 :       CALL dbt_get_info(t3c_main, blk_size_1=bsizes1, blk_size_2=bsizes2, blk_size_3=bsizes3)
    2826             : 
    2827       70140 :       ALLOCATE (current_dest(nblks(1), nblks(2), nblks(3), 0:ngroups - 1))
    2828       35070 :       DO igroup = 0, ngroups - 1
    2829             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2830     1281124 :          current_dest(:, :, :, igroup) = block_dest(:, :, :)
    2831       35070 :          CALL para_env%bcast(current_dest(:, :, :, igroup), source=igroup*group_size) !bcast from first proc in subgroup
    2832             :       END DO
    2833             : 
    2834             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2835       11690 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2836       11690 :       n_batch = (nocc*ngroups)/batch_size
    2837       11690 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2838             : 
    2839       23380 :       DO i_batch = 1, n_batch
    2840             :          !Loop over groups, blocks and send/receive
    2841      538000 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2842      538000 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2843             :          ir = 0
    2844             :          is = 0
    2845             :          i_msg = 0
    2846       35070 :          DO kblk = 1, nblks(3)
    2847      163314 :             DO jblk = 1, nblks(2)
    2848      628872 :                DO iblk = 1, nblks(1)
    2849     1559988 :                   DO igroup = 0, ngroups - 1
    2850      954496 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2851             : 
    2852      245620 :                      i_msg = i_msg + 1
    2853      245620 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2854             : 
    2855             :                      !a unique tag per block, within this batch
    2856      245620 :                      tag = i_msg - (i_batch - 1)*batch_size
    2857             :                      !filtered-out blocks still consume a message index (same numbering in the copy loop below)
    2858      245620 :                      IF (filter_at) THEN
    2859      693136 :                         ind(:) = [iblk, jblk, kblk]
    2860      173284 :                         iatom = idx_to_at(ind(dim_at))
    2861      173284 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE
    2862             :                      END IF
    2863             : 
    2864      158978 :                      found = .FALSE.
    2865      158978 :                      IF (para_env%mepos == block_source(iblk, jblk, kblk)) THEN
    2866      317956 :                         CALL dbt_get_block(t3c_main, [iblk, jblk, kblk], blk, found)
    2867             :                      END IF
    2868             : 
    2869             :                      !If blocks live on same proc, simply copy. Else MPI send/recv
    2870      158978 :                      IF (block_source(iblk, jblk, kblk) == current_dest(iblk, jblk, kblk, igroup)) THEN
    2871      329744 :                         IF (found) CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2872             :                      ELSE
    2873       76542 :                         IF (para_env%mepos == block_source(iblk, jblk, kblk) .AND. found) THEN
    2874      191355 :                            ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2875   187303822 :                            send_buff(tag)%array(:, :, :) = blk(:, :, :)
    2876       38271 :                            is = is + 1
    2877             :                            CALL para_env%isend(msgin=send_buff(tag)%array, &
    2878             :                                                dest=current_dest(iblk, jblk, kblk, igroup), &
    2879       38271 :                                                request=send_req(is), tag=tag)
    2880             :                         END IF
    2881             : 
    2882       76542 :                         IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup)) THEN
    2883      191355 :                            ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2884       38271 :                            ir = ir + 1
    2885             :                            CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk, kblk), &
    2886       38271 :                                                request=recv_req(ir), tag=tag)
    2887             :                         END IF
    2888             :                      END IF
    2889             : 
    2890      636226 :                      IF (found) DEALLOCATE (blk)
    2891             :                   END DO
    2892             :                END DO
    2893             :             END DO
    2894             :          END DO
    2895             : 
    2896       11690 :          CALL mp_waitall(send_req(1:is))
    2897       11690 :          CALL mp_waitall(recv_req(1:ir))
    2898             :          !clean-up
    2899      257310 :          DO i = 1, batch_size
    2900      257310 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2901             :          END DO
    2902             : 
    2903             :          !Finally copy the data from the buffer to the sub-tensor
    2904             :          i_msg = 0
    2905       35070 :          DO kblk = 1, nblks(3)
    2906      163314 :             DO jblk = 1, nblks(2)
    2907      628872 :                DO iblk = 1, nblks(1)
    2908     1559988 :                   DO igroup = 0, ngroups - 1
    2909      954496 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2910             : 
    2911      245620 :                      i_msg = i_msg + 1
    2912      245620 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2913             : 
    2914             :                      !a unique tag per block, within this batch
    2915      245620 :                      tag = i_msg - (i_batch - 1)*batch_size
    2916             : 
    2917      245620 :                      IF (filter_at) THEN
    2918      693136 :                         ind(:) = [iblk, jblk, kblk]
    2919      173284 :                         iatom = idx_to_at(ind(dim_at))
    2920      173284 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE
    2921             :                      END IF
    2922             : 
    2923      158978 :                      IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup) .AND. &
    2924      477248 :                          block_source(iblk, jblk, kblk) .NE. current_dest(iblk, jblk, kblk, igroup)) THEN
    2925             : 
    2926      191355 :                         ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2927   187303822 :                         blk(:, :, :) = recv_buff(tag)%array(:, :, :)
    2928      267897 :                         CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2929       38271 :                         DEALLOCATE (blk)
    2930             :                      END IF
    2931             :                   END DO
    2932             :                END DO
    2933             :             END DO
    2934             :          END DO
    2935             : 
    2936             :          !clean-up
    2937      257310 :          DO i = 1, batch_size
    2938      257310 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2939             :          END DO
    2940       23380 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2941             :       END DO !i_batch
    2942       11690 :       CALL dbt_finalize(t3c_sub)
    2943             : 
    2944       23380 :    END SUBROUTINE copy_3c_to_subgroup
    2945             : 
    2946             : ! **************************************************************************************************
    2947             : !> \brief A routine that gathers the pieces of the KS matrix across the subgroups and puts them in the
    2948             : !>        main group. Each b_img, iatom, jatom tuple is on a single CPU
    2949             : !> \param ks_t KS tensors of the main group (spin, image); receive the gathered blocks
    2950             : !> \param ks_t_sub KS tensors of the subgroups (spin, image); blocks are read from them
    2951             : !> \param group_size rank stride of a subgroup: main-group rank = rank in subgroup + igroup*group_size
    2952             : !> \param sparsity_pattern (iatom, jatom, image): index of the subgroup holding the block, negative if none
    2953             : !> \param para_env parallel environment of the main group, used for the isend/irecv calls
    2954             : !> \param ri_data provides the AO block sizes (bsizes_AO) used to size the buffers
    2955             : ! **************************************************************************************************
    2956         190 :    SUBROUTINE gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
    2957             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t, ks_t_sub
    2958             :       INTEGER, INTENT(IN)                                :: group_size
    2959             :       INTEGER, DIMENSION(:, :, :), INTENT(IN)            :: sparsity_pattern
    2960             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2961             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2962             : 
    2963             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'gather_ks_matrix'
    2964             : 
    2965             :       INTEGER                                            :: b_img, dest, handle, i, i_spin, iatom, &
    2966             :                                                             igroup, ir, is, jatom, n_mess, natom, &
    2967             :                                                             nimg, nspins, source, tag
    2968             :       LOGICAL                                            :: found
    2969         190 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2970         190 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2971         190 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2972             : 
    2973         190 :       CALL timeset(routineN, handle)
    2974             : 
    2975         190 :       nimg = SIZE(sparsity_pattern, 3)
    2976         190 :       natom = SIZE(sparsity_pattern, 2)
    2977         190 :       nspins = SIZE(ks_t, 1)
    2978             :       !one image at a time: count the messages first, to size the buffers and requests
    2979        4994 :       DO b_img = 1, nimg
    2980             :          n_mess = 0
    2981       11132 :          DO i_spin = 1, nspins
    2982       23788 :             DO jatom = 1, natom
    2983       44296 :                DO iatom = 1, natom
    2984       37968 :                   IF (sparsity_pattern(iatom, jatom, b_img) > -1) n_mess = n_mess + 1
    2985             :                END DO
    2986             :             END DO
    2987             :          END DO
    2988             : 
    2989       37136 :          ALLOCATE (send_buff(n_mess), recv_buff(n_mess))
    2990       41940 :          ALLOCATE (send_req(n_mess), recv_req(n_mess))
    2991        4804 :          ir = 0
    2992        4804 :          is = 0
    2993        4804 :          n_mess = 0
    2994        4804 :          tag = 0
    2995             : 
    2996       11132 :          DO i_spin = 1, nspins
    2997       23788 :             DO jatom = 1, natom
    2998       44296 :                DO iatom = 1, natom
    2999       25312 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    3000        8776 :                   n_mess = n_mess + 1
    3001        8776 :                   tag = tag + 1
    3002             : 
    3003             :                   !sending the message
    3004       26328 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3005       26328 :                   CALL dbt_get_stored_coordinates(ks_t_sub(i_spin, b_img), [iatom, jatom], source) !source within sub
    3006        8776 :                   igroup = sparsity_pattern(iatom, jatom, b_img)
    3007        8776 :                   source = source + igroup*group_size
    3008        8776 :                   IF (para_env%mepos == source) THEN
    3009       13164 :                      CALL dbt_get_block(ks_t_sub(i_spin, b_img), [iatom, jatom], blk, found)
    3010        4388 :                      IF (source == dest) THEN
    3011        3547 :                         IF (found) CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3012             :                      ELSE
    3013       13764 :                         ALLOCATE (send_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3014      282773 :                         send_buff(n_mess)%array(:, :) = 0.0_dp !zeros are sent if the block is not found
    3015        3441 :                         IF (found) THEN
    3016      194070 :                            send_buff(n_mess)%array(:, :) = blk(:, :)
    3017             :                         END IF
    3018        3441 :                         is = is + 1
    3019             :                         CALL para_env%isend(msgin=send_buff(n_mess)%array, dest=dest, &
    3020        3441 :                                             request=send_req(is), tag=tag)
    3021             :                      END IF
    3022        4388 :                      DEALLOCATE (blk)
    3023             :                   END IF
    3024             : 
    3025             :                   !receiving the message
    3026       21432 :                   IF (para_env%mepos == dest .AND. source .NE. dest) THEN
    3027       13764 :                      ALLOCATE (recv_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3028        3441 :                      ir = ir + 1
    3029             :                      CALL para_env%irecv(msgout=recv_buff(n_mess)%array, source=source, &
    3030        3441 :                                          request=recv_req(ir), tag=tag)
    3031             :                   END IF
    3032             :                END DO !iatom
    3033             :             END DO !jatom
    3034             :          END DO !ispin
    3035             : 
    3036        4804 :          CALL mp_waitall(send_req(1:is))
    3037        4804 :          CALL mp_waitall(recv_req(1:ir))
    3038             : 
    3039             :          !Copy the messages received into the KS matrix
    3040        4804 :          n_mess = 0
    3041       11132 :          DO i_spin = 1, nspins
    3042       23788 :             DO jatom = 1, natom
    3043       44296 :                DO iatom = 1, natom
    3044       25312 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    3045        8776 :                   n_mess = n_mess + 1
    3046             : 
    3047       26328 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3048       21432 :                   IF (para_env%mepos == dest) THEN
    3049        4388 :                      IF (.NOT. ASSOCIATED(recv_buff(n_mess)%array)) CYCLE
    3050       13764 :                      ALLOCATE (blk(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3051      282773 :                      blk(:, :) = recv_buff(n_mess)%array(:, :)
    3052       17205 :                      CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3053        3441 :                      DEALLOCATE (blk)
    3054             :                   END IF
    3055             :                END DO
    3056             :             END DO
    3057             :          END DO
    3058             : 
    3059             :          !clean-up
    3060       13580 :          DO i = 1, n_mess
    3061        8776 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    3062       13580 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    3063             :          END DO
    3064        4994 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    3065             :       END DO !b_img
    3066             : 
    3067         190 :       CALL timestop(handle)
    3068             : 
    3069         190 :    END SUBROUTINE gather_ks_matrix
    3070             : 
    3071             : ! **************************************************************************************************
    3072             : !> \brief create the 2c (RI|RI) and (AO|AO) tensors on the subgroup process grid, and copy the HFX potential matrices from the main MPI group to the subgroups
    3073             : !> \param mat_2c_pot on exit, mat_2c_pot(i_img) is created on the subgroup distribution; it receives a filtered copy of ri_data%kp_mat_2c_pot(1, i_img) unless that image is empty
    3074             : !> \param t_2c_work on exit, t_2c_work(1) is created with the extended RI block sizes and t_2c_work(2) with the extended split RI block sizes
    3075             : !> \param t_2c_ao_tmp on exit, t_2c_ao_tmp(1) is created as an (AO | AO) tensor with the ri_data%bsizes_AO block sizes
    3076             : !> \param ks_t_split on exit, ks_t_split(1) is created with the ri_data%bsizes_AO_split block sizes and ks_t_split(2) is created from ks_t_split(1)
    3077             : !> \param ks_t_sub on exit, ks_t_sub(i_spin, i_img) is created from t_2c_ao_tmp(1) for every spin and image
    3078             : !> \param group_size forwarded to copy_2c_to_subgroup
    3079             : !> \param ngroups forwarded to copy_2c_to_subgroup
    3080             : !> \param para_env forwarded to copy_2c_to_subgroup (main MPI group, as per the brief)
    3081             : !> \param para_env_sub used to create the 2d process grid and the DBCSR distribution of the subgroup
    3082             : !> \param ri_data provides the block sizes, ncell_RI, kp_mat_2c_pot and filter_eps read by this routine
    3083             : ! **************************************************************************************************
    3084         190 :    SUBROUTINE get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
    3085             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3086             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3087             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work, t_2c_ao_tmp, ks_t_split
    3088             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t_sub
    3089             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3090             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3091             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3092             : 
    3093             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_tensors'
    3094             : 
    3095             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, iproc, &
    3096             :                                                             j, natom, nblks, nimg, nspins
    3097             :       INTEGER(int_8)                                     :: nze
    3098             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3099         190 :                                                             dist1, dist2
    3100             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3101         380 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3102         190 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3103             :       REAL(dp)                                           :: occ
    3104             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3105         570 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3106        2470 :       TYPE(dbt_type)                                     :: work, work_sub
    3107             : 
    3108         190 :       CALL timeset(routineN, handle)
    3109             : 
    3110             :       !Create the 2d pgrid on the subgroup (pdims_2d is zeroed here and later used to size dbcsr_pgrid)
    3111         190 :       pdims_2d = 0
    3112         190 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3113             : 
    3114         190 :       natom = SIZE(ri_data%bsizes_RI)
    3115         190 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3116         570 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3117         570 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3118        1334 :       DO i_RI = 1, ri_data%ncell_RI
    3119        3432 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3120        6064 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3121             :       END DO
    3122             : 
    3123             :       !nRI x nRI 2c tensors, with the RI block sizes replicated ncell_RI times
    3124             :       CALL create_2c_tensor(t_2c_work(1), dist1, dist2, pgrid_2d, &
    3125             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3126             :                             name="(RI | RI)")
    3127         190 :       DEALLOCATE (dist1, dist2)
    3128             : 
    3129             :       CALL create_2c_tensor(t_2c_work(2), dist1, dist2, pgrid_2d, &
    3130             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3131         190 :                             name="(RI | RI)")
    3132         190 :       DEALLOCATE (dist1, dist2)
    3133             : 
    3134             :       !the AO based tensors
    3135             :       CALL create_2c_tensor(ks_t_split(1), dist1, dist2, pgrid_2d, &
    3136             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3137             :                             name="(AO | AO)")
    3138         190 :       DEALLOCATE (dist1, dist2)
    3139         190 :       CALL dbt_create(ks_t_split(1), ks_t_split(2))
    3140             : 
    3141             :       CALL create_2c_tensor(t_2c_ao_tmp(1), dist1, dist2, pgrid_2d, &
    3142             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, &
    3143             :                             name="(AO | AO)")
    3144         190 :       DEALLOCATE (dist1, dist2)
    3145             : 
    3146         190 :       nspins = SIZE(ks_t_sub, 1)
    3147         190 :       nimg = SIZE(ks_t_sub, 2)
    3148        4994 :       DO i_img = 1, nimg
    3149       11322 :          DO i_spin = 1, nspins
    3150       11132 :             CALL dbt_create(t_2c_ao_tmp(1), ks_t_sub(i_spin, i_img))
    3151             :          END DO
    3152             :       END DO
    3153             : 
    3154             :       !Finally the HFX potential matrices
    3155             :       !For now, we do a convoluted thing where we go to tensors first, then back to matrices.
    3156             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3157             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3158             :                             name="(RI | RI)")
    3159         190 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3160             : 
    3161         760 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3162         190 :       iproc = 0
    3163         380 :       DO i = 0, pdims_2d(1) - 1
    3164         570 :          DO j = 0, pdims_2d(2) - 1
    3165         190 :             dbcsr_pgrid(i, j) = iproc
    3166         380 :             iproc = iproc + 1
    3167             :          END DO
    3168             :       END DO
    3169             : 
    3170             :       !We need to have the same exact 2d block dist as the tensors (dist1/dist2 still hold the work_sub values from above)
    3171         760 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3172         570 :       row_dist(:) = dist1(:)
    3173         570 :       col_dist(:) = dist2(:)
    3174             : 
    3175         380 :       ALLOCATE (RI_blk_size(natom))
    3176         570 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3177             : 
    3178             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3179         190 :                                   row_dist=row_dist, col_dist=col_dist)
    3180             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3181         190 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3182             : 
    3183        4994 :       DO i_img = 1, nimg
    3184        4804 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3185        4804 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3186        4804 :          CALL get_tensor_occupancy(work, nze, occ)
    3187        4804 :          IF (nze == 0) CYCLE
    3188             : 
    3189        3708 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3190        3708 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3191        3708 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3192        8702 :          CALL dbt_clear(work_sub)
    3193             :       END DO
    3194             : 
    3195         190 :       CALL dbt_destroy(work)
    3196         190 :       CALL dbt_destroy(work_sub)
    3197         190 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3198         190 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3199         190 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3200         190 :       CALL timestop(handle)
    3201             : 
    3202        1710 :    END SUBROUTINE get_subgroup_2c_tensors
    3203             : 
    3204             : ! **************************************************************************************************
    3205             : !> \brief create the 3c tensors on the subgroup process grids, and copy the 3c integrals and the t_3c_apc data from the main MPI group to the subgroups
    3206             : !> \param t_3c_int on exit, (RI | AO AO) tensors on the subgroup pgrid, one per image, filled from ri_data%kp_t_3c_int if allocated, else from ri_data%t_3c_int_ctr_1
    3207             : !> \param t_3c_work_2 on exit, elements 1 to 3 are created as stacked "work_2_stack" tensors (AO split, extended split RI, stacked AO split block sizes)
    3208             : !> \param t_3c_work_3 on exit, elements 1 to 3 are created as stacked "work_3_stack" tensors (extended split RI, AO split, stacked AO split block sizes)
    3209             : !> \param t_3c_apc main group tensors; their data is moved out and every element is destroyed within this routine
    3210             : !> \param t_3c_apc_sub on exit, t_3c_apc_sub(i_spin, i_img) is created and receives the data of t_3c_apc(i_spin, i_img), unless the latter is empty
    3211             : !> \param group_size forwarded to copy_3c_to_subgroup
    3212             : !> \param ngroups forwarded to copy_3c_to_subgroup
    3213             : !> \param para_env forwarded to copy_3c_to_subgroup (main MPI group, as per the brief)
    3214             : !> \param para_env_sub used to create the process grids of the subgroup
    3215             : !> \param ri_data provides block sizes, ncell_RI, kp_stack_size, pgrid, pgrid_1, iatom_to_subgroup and the integrals; kp_t_3c_int is allocated here if it was not already
    3216             : ! **************************************************************************************************
    3217         190 :    SUBROUTINE get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
    3218             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3219             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_int, t_3c_work_2, t_3c_work_3
    3220             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, t_3c_apc_sub
    3221             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3222             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3223             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3224             : 
    3225             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_tensors'
    3226             : 
    3227             :       INTEGER                                            :: batch_size, bfac, bo(2), handle, &
    3228             :                                                             handle2, i_blk, i_img, i_RI, i_spin, &
    3229             :                                                             ib, natom, nblks_AO, nblks_RI, nimg, &
    3230             :                                                             nspins
    3231             :       INTEGER(int_8)                                     :: nze
    3232         190 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3233         190 :                                                             bsizes_stack, bsizes_tmp, dist1, &
    3234         190 :                                                             dist2, dist3, dist_stack, idx_to_at
    3235             :       INTEGER, DIMENSION(3)                              :: pdims
    3236             :       REAL(dp)                                           :: occ
    3237        1710 :       TYPE(dbt_distribution_type)                        :: t_dist
    3238         570 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3239        4750 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3240             : 
    3241         190 :       CALL timeset(routineN, handle)
    3242             : 
    3243         190 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    3244         570 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3245        1334 :       DO i_RI = 1, ri_data%ncell_RI
    3246        6064 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3247             :       END DO
    3248             : 
    3249             :       !Preparing larger block sizes for efficient communication (fewer, bigger messages)
    3250             :       !we put 2 atoms per RI block; note that nblks_RI is reused below and now counts these merged blocks
    3251         190 :       bfac = 2
    3252         190 :       natom = SIZE(ri_data%bsizes_RI)
    3253         190 :       nblks_RI = MAX(1, natom/bfac)
    3254         570 :       ALLOCATE (bsizes_tmp(nblks_RI))
    3255         380 :       DO i_blk = 1, nblks_RI
    3256         190 :          bo = get_limit(natom, nblks_RI, i_blk - 1)
    3257         760 :          bsizes_tmp(i_blk) = SUM(ri_data%bsizes_RI(bo(1):bo(2)))
    3258             :       END DO
    3259         570 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3260        1334 :       DO i_RI = 1, ri_data%ncell_RI
    3261        2478 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = bsizes_tmp(:)
    3262             :       END DO
    3263             : 
    3264         190 :       batch_size = ri_data%kp_stack_size
    3265         190 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3266         570 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3267        6270 :       DO ib = 1, batch_size
    3268       29886 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3269             :       END DO
    3270             : 
    3271             :       !Create the pgrid for the configuration corresponding to ri_data%t_3c_int_ctr_3
    3272         190 :       natom = SIZE(ri_data%bsizes_RI)
    3273         190 :       pdims = 0
    3274             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3275         760 :                             tensor_dims=[SIZE(bsizes_RI_ext_split), 1, batch_size*SIZE(ri_data%bsizes_AO_split)])
    3276             : 
    3277             :       !Create all required 3c tensors in that configuration
    3278             :       CALL create_3c_tensor(t_3c_int(1), dist1, dist2, dist3, &
    3279             :                             pgrid, bsizes_RI_ext_split, ri_data%bsizes_AO_split, &
    3280         190 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(RI | AO AO)")
    3281         190 :       nimg = SIZE(t_3c_int)
    3282        4804 :       DO i_img = 2, nimg
    3283        4804 :          CALL dbt_create(t_3c_int(1), t_3c_int(i_img))
    3284             :       END DO
    3285             : 
    3286             :       !The stacked work tensors, in a distribution that matches that of t_3c_int (dist3 is repeated batch_size times)
    3287         380 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3288        6270 :       DO ib = 1, batch_size
    3289       29886 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3290             :       END DO
    3291             : 
    3292         190 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3293             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3294         190 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3295         190 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3296         190 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3297         190 :       CALL dbt_distribution_destroy(t_dist)
    3298         190 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3299             : 
    3300             :       !For more efficient communication, we use intermediate tensors with larger block size
    3301             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3302             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3303         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3304         190 :       DEALLOCATE (dist1, dist2, dist3)
    3305             : 
    3306             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3307             :                             ri_data%pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3308         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3309         190 :       DEALLOCATE (dist1, dist2, dist3)
    3310             : 
    3311             :       !Finally copy the integrals into the subgroups (if ri_data%kp_t_3c_int is already allocated, its data is moved instead)
    3312         190 :       CALL timeset(routineN//"_ints", handle2)
    3313         190 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN
    3314        3208 :          DO i_img = 1, nimg
    3315        3208 :             CALL dbt_copy(ri_data%kp_t_3c_int(i_img), t_3c_int(i_img), move_data=.TRUE.)
    3316             :          END DO
    3317             :       ELSE
    3318        2486 :          ALLOCATE (ri_data%kp_t_3c_int(nimg))
    3319        1786 :          DO i_img = 1, nimg
    3320        1716 :             CALL dbt_create(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img))
    3321        1716 :             CALL get_tensor_occupancy(ri_data%t_3c_int_ctr_1(1, i_img), nze, occ)
    3322        1716 :             IF (nze == 0) CYCLE
    3323        1236 :             CALL dbt_copy(ri_data%t_3c_int_ctr_1(1, i_img), work_atom_block, order=[2, 1, 3])
    3324        1236 :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, ngroups, para_env)
    3325        3022 :             CALL dbt_copy(work_atom_block_sub, t_3c_int(i_img), move_data=.TRUE.)
    3326             :          END DO
    3327             :       END IF
    3328         190 :       CALL timestop(handle2)
    3329         190 :       CALL dbt_pgrid_destroy(pgrid)
    3330         190 :       CALL dbt_destroy(work_atom_block)
    3331         190 :       CALL dbt_destroy(work_atom_block_sub)
    3332             : 
    3333             :       !Do the same for the t_3c_ctr_2 configuration
    3334         190 :       pdims = 0
    3335             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3336         760 :                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), batch_size*SIZE(ri_data%bsizes_AO_split)])
    3337             : 
    3338             :       !For more efficient communication, we use intermediate tensors with larger block size
    3339             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3340             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3341         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3342         190 :       DEALLOCATE (dist1, dist2, dist3)
    3343             : 
    3344             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3345             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3346         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3347         190 :       DEALLOCATE (dist1, dist2, dist3)
    3348             : 
    3349             :       !template for t_3c_apc_sub (dist1, dist2, dist3 are kept for the stacked distribution below)
    3350             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3351             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3352         190 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3353             : 
    3354             :       !create t_3c_work_2 tensors in a distribution that matches the above
    3355         380 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3356        6270 :       DO ib = 1, batch_size
    3357       29886 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3358             :       END DO
    3359             : 
    3360         190 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3361             :       CALL dbt_create(t_3c_work_2(1), "work_2_stack", t_dist, [1], [2, 3], &
    3362         190 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3363         190 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3364         190 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3365         190 :       CALL dbt_distribution_destroy(t_dist)
    3366         190 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3367             : 
    3368             :       !Finally copy data from t_3c_apc to the subgroups; the main group t_3c_apc tensors are destroyed along the way
    3369         570 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3370         190 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3371         190 :       nspins = SIZE(t_3c_apc, 1)
    3372         190 :       CALL timeset(routineN//"_apc", handle2)
    3373        4994 :       DO i_img = 1, nimg
    3374       11132 :          DO i_spin = 1, nspins
    3375        6328 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3376        6328 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3377        6328 :             IF (nze == 0) CYCLE
    3378        5462 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3379             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3380        5462 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3381       16594 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3382             :          END DO
    3383       11322 :          DO i_spin = 1, nspins
    3384       11132 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3385             :          END DO
    3386             :       END DO
    3387         190 :       CALL timestop(handle2)
    3388         190 :       CALL dbt_pgrid_destroy(pgrid)
    3389         190 :       CALL dbt_destroy(tmp)
    3390         190 :       CALL dbt_destroy(work_atom_block)
    3391         190 :       CALL dbt_destroy(work_atom_block_sub)
    3392             : 
    3393         190 :       CALL timestop(handle)
    3394             : 
    3395         760 :    END SUBROUTINE get_subgroup_3c_tensors
    3396             : 
    3397             : ! **************************************************************************************************
    3398             : !> \brief copy all required 2c force tensors from the main MPI group to the subgroups
    3399             : !> \param t_2c_inv ...
    3400             : !> \param t_2c_bint ...
    3401             : !> \param t_2c_metric ...
    3402             : !> \param mat_2c_pot ...
    3403             : !> \param t_2c_work ...
    3404             : !> \param rho_ao_t ...
    3405             : !> \param rho_ao_t_sub ...
    3406             : !> \param t_2c_der_metric ...
    3407             : !> \param t_2c_der_metric_sub ...
    3408             : !> \param mat_der_pot ...
    3409             : !> \param mat_der_pot_sub ...
    3410             : !> \param group_size ...
    3411             : !> \param ngroups ...
    3412             : !> \param para_env ...
    3413             : !> \param para_env_sub ...
    3414             : !> \param ri_data ...
    3415             : !> \note Main MPI group tensors are deleted within this routine, for memory optimization
    3416             : ! **************************************************************************************************
    3417          84 :    SUBROUTINE get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
    3418          42 :                                      rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
    3419          42 :                                      mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
    3420             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_inv, t_2c_bint, t_2c_metric
    3421             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3422             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work
    3423             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
    3424             :                                                             t_2c_der_metric_sub
    3425             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot, mat_der_pot_sub
    3426             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3427             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3428             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3429             : 
    3430             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_derivs'
    3431             : 
    3432             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, i_xyz, &
    3433             :                                                             iatom, iproc, j, natom, nblks, nimg, &
    3434             :                                                             nspins
    3435             :       INTEGER(int_8)                                     :: nze
    3436             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3437          42 :                                                             dist1, dist2
    3438             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3439          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3440          42 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3441             :       REAL(dp)                                           :: occ
    3442             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3443         126 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3444         546 :       TYPE(dbt_type)                                     :: work, work_sub
    3445             : 
    3446          42 :       CALL timeset(routineN, handle)
    3447             : 
    3448             :       !Note: a fair portion of this routine is copied from the energy version of it
    3449             :       !Create the 2d pgrid
    3450          42 :       pdims_2d = 0
    3451          42 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3452             : 
    3453          42 :       natom = SIZE(ri_data%bsizes_RI)
    3454          42 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3455         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3456         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3457         294 :       DO i_RI = 1, ri_data%ncell_RI
    3458         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3459        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3460             :       END DO
    3461             : 
    3462             :       !nRI x nRI 2c tensors
    3463             :       CALL create_2c_tensor(t_2c_inv(1), dist1, dist2, pgrid_2d, &
    3464             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3465             :                             name="(RI | RI)")
    3466          42 :       DEALLOCATE (dist1, dist2)
    3467             : 
    3468          42 :       CALL dbt_create(t_2c_inv(1), t_2c_bint(1))
    3469          42 :       CALL dbt_create(t_2c_inv(1), t_2c_metric(1))
    3470          84 :       DO iatom = 2, natom
    3471          42 :          CALL dbt_create(t_2c_inv(1), t_2c_inv(iatom))
    3472          42 :          CALL dbt_create(t_2c_inv(1), t_2c_bint(iatom))
    3473          84 :          CALL dbt_create(t_2c_inv(1), t_2c_metric(iatom))
    3474             :       END DO
    3475          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(1))
    3476          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(2))
    3477          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(3))
    3478          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(4))
    3479             : 
    3480             :       CALL create_2c_tensor(t_2c_work(5), dist1, dist2, pgrid_2d, &
    3481             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3482          42 :                             name="(RI | RI)")
    3483          42 :       DEALLOCATE (dist1, dist2)
    3484             : 
    3485             :       !copy the data from the main group.
    3486         126 :       DO iatom = 1, natom
    3487          84 :          CALL copy_2c_to_subgroup(t_2c_inv(iatom), ri_data%t_2c_inv(1, iatom), group_size, ngroups, para_env)
    3488          84 :          CALL copy_2c_to_subgroup(t_2c_bint(iatom), ri_data%t_2c_int(1, iatom), group_size, ngroups, para_env)
    3489         126 :          CALL copy_2c_to_subgroup(t_2c_metric(iatom), ri_data%t_2c_pot(1, iatom), group_size, ngroups, para_env)
    3490             :       END DO
    3491             : 
    3492             :       !This includes the derivatives of the RI metric, for which there is one per atom
    3493         168 :       DO i_xyz = 1, 3
    3494         420 :          DO iatom = 1, natom
    3495         252 :             CALL dbt_create(t_2c_inv(1), t_2c_der_metric_sub(iatom, i_xyz))
    3496             :             CALL copy_2c_to_subgroup(t_2c_der_metric_sub(iatom, i_xyz), t_2c_der_metric(iatom, i_xyz), &
    3497         252 :                                      group_size, ngroups, para_env)
    3498         378 :             CALL dbt_destroy(t_2c_der_metric(iatom, i_xyz))
    3499             :          END DO
    3500             :       END DO
    3501             : 
    3502             :       !AO x AO 2c tensors
    3503             :       CALL create_2c_tensor(rho_ao_t_sub(1, 1), dist1, dist2, pgrid_2d, &
    3504             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3505             :                             name="(AO | AO)")
    3506          42 :       DEALLOCATE (dist1, dist2)
    3507          42 :       nspins = SIZE(rho_ao_t, 1)
    3508          42 :       nimg = SIZE(rho_ao_t, 2)
    3509             : 
    3510        1052 :       DO i_img = 1, nimg
    3511        2228 :          DO i_spin = 1, nspins
    3512        1176 :             IF (.NOT. (i_img == 1 .AND. i_spin == 1)) &
    3513        1134 :                CALL dbt_create(rho_ao_t_sub(1, 1), rho_ao_t_sub(i_spin, i_img))
    3514             :             CALL copy_2c_to_subgroup(rho_ao_t_sub(i_spin, i_img), rho_ao_t(i_spin, i_img), &
    3515        1176 :                                      group_size, ngroups, para_env)
    3516        2186 :             CALL dbt_destroy(rho_ao_t(i_spin, i_img))
    3517             :          END DO
    3518             :       END DO
    3519             : 
    3520             :       !The RIxRI matrices, going through tensors
    3521             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3522             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3523             :                             name="(RI | RI)")
    3524          42 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3525             : 
    3526         168 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3527          42 :       iproc = 0
    3528          84 :       DO i = 0, pdims_2d(1) - 1
    3529         126 :          DO j = 0, pdims_2d(2) - 1
    3530          42 :             dbcsr_pgrid(i, j) = iproc
    3531          84 :             iproc = iproc + 1
    3532             :          END DO
    3533             :       END DO
    3534             : 
    3535             :       !We need to have the same exact 2d block dist as the tensors
    3536         168 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3537         126 :       row_dist(:) = dist1(:)
    3538         126 :       col_dist(:) = dist2(:)
    3539             : 
    3540          84 :       ALLOCATE (RI_blk_size(natom))
    3541         126 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3542             : 
    3543             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3544          42 :                                   row_dist=row_dist, col_dist=col_dist)
    3545             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3546          42 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3547             : 
    3548             :       !The HFX potential
    3549        1052 :       DO i_img = 1, nimg
    3550        1010 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3551        1010 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3552        1010 :          CALL get_tensor_occupancy(work, nze, occ)
    3553        1010 :          IF (nze == 0) CYCLE
    3554             : 
    3555         654 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3556         654 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3557         654 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3558        1706 :          CALL dbt_clear(work_sub)
    3559             :       END DO
    3560             : 
    3561             :       !The derivatives of the HFX potential
    3562         168 :       DO i_xyz = 1, 3
    3563        3198 :          DO i_img = 1, nimg
    3564        3030 :             CALL dbcsr_create(mat_der_pot_sub(i_img, i_xyz), template=mat_2c_pot(1))
    3565        3030 :             CALL dbt_copy_matrix_to_tensor(mat_der_pot(i_img, i_xyz), work)
    3566        3030 :             CALL dbcsr_release(mat_der_pot(i_img, i_xyz))
    3567        3030 :             CALL get_tensor_occupancy(work, nze, occ)
    3568        3030 :             IF (nze == 0) CYCLE
    3569             : 
    3570        1958 :             CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3571        1958 :             CALL dbt_copy_tensor_to_matrix(work_sub, mat_der_pot_sub(i_img, i_xyz))
    3572        1958 :             CALL dbcsr_filter(mat_der_pot_sub(i_img, i_xyz), ri_data%filter_eps)
    3573        5114 :             CALL dbt_clear(work_sub)
    3574             :          END DO
    3575             :       END DO
    3576             : 
    3577          42 :       CALL dbt_destroy(work)
    3578          42 :       CALL dbt_destroy(work_sub)
    3579          42 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3580          42 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3581          42 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3582             : 
    3583          42 :       CALL timestop(handle)
    3584             : 
    3585         336 :    END SUBROUTINE get_subgroup_2c_derivs
    3586             : 
    3587             : ! **************************************************************************************************
    3588             : !> \brief copy all required 3c derivative tensors from the main MPI group to the subgroups
    3589             : !> \param t_3c_work_2 ...
    3590             : !> \param t_3c_work_3 ...
    3591             : !> \param t_3c_der_AO ...
    3592             : !> \param t_3c_der_AO_sub ...
    3593             : !> \param t_3c_der_RI ...
    3594             : !> \param t_3c_der_RI_sub ...
    3595             : !> \param t_3c_apc ...
    3596             : !> \param t_3c_apc_sub ...
    3597             : !> \param t_3c_der_stack ...
    3598             : !> \param group_size ...
    3599             : !> \param ngroups ...
    3600             : !> \param para_env ...
    3601             : !> \param para_env_sub ...
    3602             : !> \param ri_data ...
    3603             : !> \note the tensors containing the derivatives in the main MPI group are destroyed to save memory
    3604             : ! **************************************************************************************************
    SUBROUTINE get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
                                      t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, &
                                      t_3c_der_stack, group_size, ngroups, para_env, para_env_sub, &
                                      ri_data)
       ! Purpose: copy the 3c derivative tensors and t_3c_apc from the main MPI group to the
       !          subgroups, then create the (empty) stacked work tensors used on the subgroups.
       !
       ! t_3c_work_2     : on exit, elements 1..3 are created with the process distribution read
       !                   back from t_3c_apc_sub(1, 1) and a stacked third dimension
       ! t_3c_work_3     : on exit, elements 1..4 are created with the process distribution read
       !                   back from ri_data%kp_t_3c_int(1) and a stacked third dimension
       ! t_3c_der_AO/RI  : main-group derivatives, indexed (i_img, i_xyz); their data is moved out
       !                   and every element is destroyed before returning
       ! t_3c_der_*_sub  : created for every (i_img, i_xyz) from the ri_data%kp_t_3c_int(1)
       !                   template; filled only where the main-group tensor has nze /= 0
       ! t_3c_apc        : main-group tensors, indexed (i_spin, i_img); destroyed before returning
       ! t_3c_apc_sub    : created for every (i_spin, i_img); filled only where the source has
       !                   nze /= 0
       ! t_3c_der_stack  : on exit, elements 1..6 are created from the t_3c_work_3(1) template
       ! group_size, ngroups, para_env : forwarded unchanged to copy_3c_to_subgroup
       ! para_env_sub    : communicator on which the subgroup process grids are created
       ! ri_data         : supplies block sizes, ncell_RI, nimg, kp_stack_size, pgrid_1, pgrid_2,
       !                   kp_t_3c_int(1) and iatom_to_subgroup
       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_work_2, t_3c_work_3
       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_AO, t_3c_der_AO_sub, &
                                                             t_3c_der_RI, t_3c_der_RI_sub, &
                                                             t_3c_apc, t_3c_apc_sub
       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_der_stack
       INTEGER, INTENT(IN)                                :: group_size, ngroups
       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data

       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_derivs'

       INTEGER                                            :: batch_size, handle, i_img, i_RI, i_spin, &
                                                             i_xyz, ib, nblks_AO, nblks_RI, nimg, &
                                                             nspins, pdims(3)
       INTEGER(int_8)                                     :: nze
       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
                                                             bsizes_stack, dist1, dist2, dist3, &
                                                             dist_stack, idx_to_at
       REAL(dp)                                           :: occ
       TYPE(dbt_distribution_type)                        :: t_dist
       TYPE(dbt_pgrid_type)                               :: pgrid
       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub

       CALL timeset(routineN, handle)

       !We use intermediate tensors with larger block size for more optimized communication
       !bsizes_RI_ext is ri_data%bsizes_RI repeated ncell_RI times
       nblks_RI = SIZE(ri_data%bsizes_RI)
       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
       DO i_RI = 1, ri_data%ncell_RI
          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
       END DO

       CALL dbt_get_info(ri_data%kp_t_3c_int(1), pdims=pdims)
       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)

       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
       DEALLOCATE (dist1, dist2, dist3)

       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
                             ri_data%pgrid_2, bsizes_RI_ext, ri_data%bsizes_AO, &
                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
       DEALLOCATE (dist1, dist2, dist3)
       CALL dbt_pgrid_destroy(pgrid)

       !We use the 3c integrals on the subgroup as template for the derivatives
       nimg = ri_data%nimg
       DO i_xyz = 1, 3
          !AO derivatives: main group -> atom-block tensor -> subgroup atom-block tensor -> target
          DO i_img = 1, nimg
             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_AO_sub(i_img, i_xyz))
             CALL get_tensor_occupancy(t_3c_der_AO(i_img, i_xyz), nze, occ)
             IF (nze == 0) CYCLE

             CALL dbt_copy(t_3c_der_AO(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
                                      group_size, ngroups, para_env)
             CALL dbt_copy(work_atom_block_sub, t_3c_der_AO_sub(i_img, i_xyz), move_data=.TRUE.)
          END DO

          !RI derivatives: same route as the AO derivatives
          DO i_img = 1, nimg
             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_RI_sub(i_img, i_xyz))
             CALL get_tensor_occupancy(t_3c_der_RI(i_img, i_xyz), nze, occ)
             IF (nze == 0) CYCLE

             CALL dbt_copy(t_3c_der_RI(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
                                      group_size, ngroups, para_env)
             CALL dbt_copy(work_atom_block_sub, t_3c_der_RI_sub(i_img, i_xyz), move_data=.TRUE.)
          END DO

          !The main-group derivatives are not needed anymore
          DO i_img = 1, nimg
             CALL dbt_destroy(t_3c_der_RI(i_img, i_xyz))
             CALL dbt_destroy(t_3c_der_AO(i_img, i_xyz))
          END DO
       END DO
       CALL dbt_destroy(work_atom_block_sub)
       CALL dbt_destroy(work_atom_block)

       !Deal with t_3c_apc
       !From here on nblks_RI counts the *split* RI blocks
       nblks_RI = SIZE(ri_data%bsizes_RI_split)
       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
       DO i_RI = 1, ri_data%ncell_RI
          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
       END DO

       !batch_size enters tensor_dims just below, so it has to be defined here. It used to be
       !assigned only further down, which made this a read of an undefined variable.
       batch_size = ri_data%kp_stack_size

       !pdims = 0 together with tensor_dims: presumably lets dbt_pgrid_create pick the grid from
       !the tensor shape -- NOTE(review): confirm against the dbt_pgrid_create interface
       pdims = 0
       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), batch_size*SIZE(ri_data%bsizes_AO_split)])

       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
       DEALLOCATE (dist1, dist2, dist3)

       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
       DEALLOCATE (dist1, dist2, dist3)

       !tmp only serves as the template for the t_3c_apc_sub tensors
       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
       DEALLOCATE (dist1, dist2, dist3)

       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
       nspins = SIZE(t_3c_apc, 1)
       DO i_img = 1, nimg
          DO i_spin = 1, nspins
             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
             IF (nze == 0) CYCLE
             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
          END DO
          DO i_spin = 1, nspins
             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
          END DO
       END DO
       CALL dbt_destroy(tmp)
       CALL dbt_destroy(work_atom_block)
       CALL dbt_destroy(work_atom_block_sub)
       CALL dbt_pgrid_destroy(pgrid)

       !t_3c_work_3 based on structure of 3c integrals/derivs
       !bsizes_stack is ri_data%bsizes_AO_split repeated batch_size times
       nblks_AO = SIZE(ri_data%bsizes_AO_split)
       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
       DO ib = 1, batch_size
          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
       END DO

       ALLOCATE (dist1(ri_data%ncell_RI*nblks_RI), dist2(nblks_AO), dist3(nblks_AO))
       CALL dbt_get_info(ri_data%kp_t_3c_int(1), proc_dist_1=dist1, proc_dist_2=dist2, &
                         proc_dist_3=dist3, pdims=pdims)

       !The third-dimension distribution is replicated for each item of the stack
       ALLOCATE (dist_stack(batch_size*nblks_AO))
       DO ib = 1, batch_size
          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
       END DO

       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
       DO ib = 2, 4
          CALL dbt_create(t_3c_work_3(1), t_3c_work_3(ib))
       END DO
       CALL dbt_distribution_destroy(t_dist)
       CALL dbt_pgrid_destroy(pgrid)
       DEALLOCATE (dist1, dist2, dist3, dist_stack)

       !the derivatives are stacked in the same way
       DO ib = 1, 6
          CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(ib))
       END DO

       !t_3c_work_2 based on structure of t_3c_apc
       ALLOCATE (dist1(nblks_AO), dist2(ri_data%ncell_RI*nblks_RI), dist3(nblks_AO))
       CALL dbt_get_info(t_3c_apc_sub(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, &
                         proc_dist_3=dist3, pdims=pdims)

       ALLOCATE (dist_stack(batch_size*nblks_AO))
       DO ib = 1, batch_size
          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
       END DO

       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
       CALL dbt_create(t_3c_work_2(1), "work_3_stack", t_dist, [1], [2, 3], &
                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
       DO ib = 2, 3
          CALL dbt_create(t_3c_work_2(1), t_3c_work_2(ib))
       END DO
       CALL dbt_distribution_destroy(t_dist)
       CALL dbt_pgrid_destroy(pgrid)
       DEALLOCATE (dist1, dist2, dist3, dist_stack)

       CALL timestop(handle)

    END SUBROUTINE get_subgroup_3c_derivs
    3795             : 
    3796             : ! **************************************************************************************************
    3797             : !> \brief A routine that reorders the t_3c_int tensors such that all items which are fully empty
    3798             : !>        are bunched together. This way, we can get much more efficient screening based on NZE
    3799             : !> \param t_3c_ints ...
    3800             : !> \param ri_data ...
    3801             : ! **************************************************************************************************
    SUBROUTINE reorder_3c_ints(t_3c_ints, ri_data)
       ! Permutes t_3c_ints in place: tensors with nze /= 0 are packed at the front (in increasing
       ! image order), tensors with nze == 0 are packed from the back. The permutation is stored
       ! in ri_data%idx_to_img (slot -> image) and ri_data%img_to_idx (image -> slot), both
       ! allocated here, and ri_data%nimg_nze receives the number of non-empty tensors.
       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_ints
       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data

       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_ints'

       INTEGER                                            :: handle, img, n_img, next_empty, &
                                                             next_full, slot
       INTEGER(int_8)                                     :: nze
       REAL(dp)                                           :: occ
       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: stash

       CALL timeset(routineN, handle)

       n_img = ri_data%nimg

       !Park every tensor in a scratch copy so that the original slots can be refilled freely
       ALLOCATE (stash(n_img))
       DO img = 1, n_img
          CALL dbt_create(t_3c_ints(img), stash(img))
          CALL dbt_copy(t_3c_ints(img), stash(img), move_data=.TRUE.)
       END DO

       !next_full grows from the front, next_empty shrinks from the back
       ALLOCATE (ri_data%idx_to_img(n_img))
       next_full = 1
       next_empty = n_img

       DO img = 1, n_img
          CALL get_tensor_occupancy(stash(img), nze, occ)
          IF (nze /= 0) THEN
             slot = next_full
             next_full = next_full + 1
          ELSE
             slot = next_empty
             next_empty = next_empty - 1
          END IF

          CALL dbt_copy(stash(img), t_3c_ints(slot), move_data=.TRUE.)
          ri_data%idx_to_img(slot) = img
          CALL dbt_destroy(stash(img))
       END DO

       !Last slot that received a non-empty tensor
       ri_data%nimg_nze = next_full - 1

       !Inverse mapping
       ALLOCATE (ri_data%img_to_idx(n_img))
       DO slot = 1, n_img
          ri_data%img_to_idx(ri_data%idx_to_img(slot)) = slot
       END DO

       CALL timestop(handle)

    END SUBROUTINE reorder_3c_ints
    3854             : 
    3855             : ! **************************************************************************************************
    3856             : !> \brief A routine that reorders the 3c derivatives, the same way that the integrals are, also to
    3857             : !>        increase efficiency of screening
    3858             : !> \param t_3c_derivs ...
    3859             : !> \param ri_data ...
    3860             : ! **************************************************************************************************
    SUBROUTINE reorder_3c_derivs(t_3c_derivs, ri_data)
       ! Applies the image -> slot permutation ri_data%img_to_idx to t_3c_derivs(:, i_xyz) for
       ! each of the three directions, in place. ri_data%nimg_nze is raised whenever a tensor
       ! with nze > 0 lands in a slot beyond its current value.
       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_derivs
       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data

       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_derivs'

       INTEGER                                            :: dir, handle, img, n_img, slot
       INTEGER(int_8)                                     :: nze
       REAL(dp)                                           :: occ
       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: stash

       CALL timeset(routineN, handle)

       n_img = ri_data%nimg

       !One scratch tensor per image, all sharing the structure of t_3c_derivs(1, 1)
       ALLOCATE (stash(n_img))
       DO img = 1, n_img
          CALL dbt_create(t_3c_derivs(1, 1), stash(img))
       END DO

       DO dir = 1, 3
          !Empty the whole column first, then refill it in the new order
          DO img = 1, n_img
             CALL dbt_copy(t_3c_derivs(img, dir), stash(img), move_data=.TRUE.)
          END DO

          DO img = 1, n_img
             slot = ri_data%img_to_idx(img)
             CALL dbt_copy(stash(img), t_3c_derivs(slot, dir), move_data=.TRUE.)
             CALL get_tensor_occupancy(t_3c_derivs(slot, dir), nze, occ)
             IF (nze > 0) THEN
                IF (slot > ri_data%nimg_nze) ri_data%nimg_nze = slot
             END IF
          END DO
       END DO

       DO img = 1, n_img
          CALL dbt_destroy(stash(img))
       END DO

       CALL timestop(handle)

    END SUBROUTINE reorder_3c_derivs
    3899             : 
    3900             : ! **************************************************************************************************
    3901             : !> \brief Get the sparsity pattern related to the non-symmetric AO basis overlap neighbor list
    3902             : !> \param pattern ...
    3903             : !> \param ri_data ...
    3904             : !> \param qs_env ...
    3905             : ! **************************************************************************************************
    SUBROUTINE get_sparsity_pattern(pattern, ri_data, qs_env)
       ! Builds the (iatom, jatom, image) block pattern from the AO neighbor lists held by the
       ! kpoints object. On exit pattern holds 0 for a block that is kept and -1 otherwise.
       ! Steps: mark blocks from the symmetric list (images that have an opposite image within
       ! 1..nimg) and from the non-symmetric list (images that do not), sum over para_env, drop
       ! the duplicated diagonal blocks, then balance the number of blocks per atom.
       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: pattern
       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
       TYPE(qs_environment_type), POINTER                 :: qs_env

       INTEGER                                            :: iat, img, img_opp, jat, n_at, n_img
       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: load
       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: balanced
       INTEGER, DIMENSION(3)                              :: cell
       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
       LOGICAL                                            :: has_opp
       TYPE(dft_control_type), POINTER                    :: dft_control
       TYPE(kpoint_type), POINTER                         :: kpoints
       TYPE(mp_para_env_type), POINTER                    :: para_env
       TYPE(neighbor_list_iterator_p_type), &
          DIMENSION(:), POINTER                           :: nl_iterator
       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
          POINTER                                         :: nl_2c

       NULLIFY (nl_2c, nl_iterator, kpoints, cell_to_index, dft_control, index_to_cell, para_env)

       CALL get_qs_env(qs_env, kpoints=kpoints, dft_control=dft_control, para_env=para_env, &
                       natom=n_at)
       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell, &
                            sab_nl=nl_2c)

       n_img = ri_data%nimg
       pattern = 0

       !Pass 1: symmetric neighbor list, only for images whose opposite image is also in range
       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
          CALL get_iterator_info(nl_iterator, iatom=iat, jatom=jat, cell=cell)

          img = cell_to_index(cell(1), cell(2), cell(3))
          IF (img < 1 .OR. img > n_img) CYCLE

          img_opp = get_opp_index(img, qs_env)
          has_opp = (img_opp >= 1 .AND. img_opp <= n_img)
          IF (.NOT. has_opp) CYCLE

          IF (ri_data%present_images(img) == 0) CYCLE

          pattern(iat, jat, img) = 1
       END DO
       CALL neighbor_list_iterator_release(nl_iterator)

       !Pass 2: non-symmetric neighbor list, only for images without an opposite image in range
       CALL get_kpoint_info(kpoints, sab_nl_nosym=nl_2c)

       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
          CALL get_iterator_info(nl_iterator, iatom=iat, jatom=jat, cell=cell)

          img = cell_to_index(cell(1), cell(2), cell(3))
          IF (img < 1 .OR. img > n_img) CYCLE

          img_opp = get_opp_index(img, qs_env)
          has_opp = (img_opp >= 1 .AND. img_opp <= n_img)
          IF (has_opp) CYCLE

          IF (ri_data%present_images(img) == 0) CYCLE

          pattern(iat, jat, img) = 1
       END DO
       CALL neighbor_list_iterator_release(nl_iterator)

       !Combine the contributions of all ranks; entries are only tested against zero afterwards
       CALL para_env%sum(pattern)

       !A diagonal block present in image b makes the one in the opposite image -b redundant
       DO img = 2, n_img
          DO iat = 1, n_at
             IF (pattern(iat, iat, img) == 0) CYCLE
             img_opp = get_opp_index(img, qs_env)
             IF (img_opp >= 1 .AND. img_opp <= n_img) pattern(iat, iat, img_opp) = 0
          END DO
       END DO

       !Balance the blocks: load(i) counts the blocks already assigned to atom i as first index
       ALLOCATE (load(n_at))
       load = 0

       ALLOCATE (balanced(n_at, n_at, n_img))
       balanced = 0
       DO img = 1, n_img
          DO jat = 1, n_at
             DO iat = 1, n_at
                IF (pattern(iat, jat, img) == 0) CYCLE
                img_opp = get_opp_index(img, qs_env)
                has_opp = (img_opp >= 1 .AND. img_opp <= n_img)

                !Take the transposed block (j, i, -b) only if it exists and atom i is currently
                !more loaded than atom j; otherwise keep (i, j, b)
                IF (has_opp .AND. load(iat) > load(jat)) THEN
                   load(jat) = load(jat) + 1
                   balanced(jat, iat, img_opp) = 1
                ELSE
                   load(iat) = load(iat) + 1
                   balanced(iat, jat, img) = 1
                END IF
             END DO
          END DO
       END DO

       !Shift 1/0 to the output convention: -1 => unoccupied, 0 => occupied
       pattern = balanced - 1

    END SUBROUTINE get_sparsity_pattern
    4018             : 
    4019             : ! **************************************************************************************************
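    4020             : !> \brief Distribute the iatom, jatom, b_img triplets over the subgroups to spread the load.
    4021             : !>        The group id for each triplet is passed as the value of sparsity_pattern(i, j, b),
    4022             : !>        with -1 being an unoccupied block
    4023             : !> \param sparsity_pattern on entry, -1 marks unoccupied blocks; on exit, occupied blocks hold their 0-based group id
    4024             : !> \param ngroups the number of subgroups over which the triplets are distributed
    4025             : !> \param ri_data iatom_to_subgroup is set up on first call; kp_cost and bsizes_AO are read for the cost estimate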
    4026             : ! **************************************************************************************************
    4027         232 :    SUBROUTINE get_sub_dist(sparsity_pattern, ngroups, ri_data)
    4028             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: sparsity_pattern
    4029             :       INTEGER, INTENT(IN)                                :: ngroups
    4030             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4031             : 
    4032             :       INTEGER                                            :: b_img, ctr, iat, iatom, igroup, jatom, &
    4033             :                                                             natom, nimg, ub
    4034         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: max_at_per_group
    4035             :       REAL(dp)                                           :: cost
    4036         232 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4037             : 
    4038         232 :       natom = SIZE(sparsity_pattern, 2)
    4039         232 :       nimg = SIZE(sparsity_pattern, 3)
    4040             : 
    4041             :       !To avoid unnecessary data replication across the subgroups, we want to have a limited number
    4042             :       !of subgroups with the data of a given iatom. At the minimum, all groups have 1 atom.
    4043             :       !We assume that the cost associated to each iatom is roughly the same. Only done on first call.
    4044         232 :       IF (.NOT. ALLOCATED(ri_data%iatom_to_subgroup)) THEN
    4045         350 :          ALLOCATE (ri_data%iatom_to_subgroup(natom), max_at_per_group(ngroups))
    4046         150 :          DO iatom = 1, natom
    4047         100 :             NULLIFY (ri_data%iatom_to_subgroup(iatom)%array)
    4048         200 :             ALLOCATE (ri_data%iatom_to_subgroup(iatom)%array(ngroups))
    4049         350 :             ri_data%iatom_to_subgroup(iatom)%array(:) = .FALSE.
    4050             :          END DO
    4051             : 
    4052          50 :          ub = natom/ngroups
    4053          50 :          IF (ub*ngroups < natom) ub = ub + 1
    4054         150 :          max_at_per_group(:) = MAX(1, ub)
    4055             : 
    4056             :          !We want each atom to be present the same number of times. Some groups might have more atoms
    4057             :          !than others to achieve this: add one atom per group, round-robin, until the total is a multiple of natom.
    4058             :          ctr = 0
    4059         150 :          DO WHILE (MODULO(SUM(max_at_per_group), natom) .NE. 0)
    4060           0 :             igroup = MODULO(ctr, ngroups) + 1
    4061           0 :             max_at_per_group(igroup) = max_at_per_group(igroup) + 1
    4062          50 :             ctr = ctr + 1
    4063             :          END DO
    4064             : 
    4065             :          ctr = 0
    4066         150 :          DO igroup = 1, ngroups
    4067         250 :             DO iat = 1, max_at_per_group(igroup)
    4068         100 :                iatom = MODULO(ctr, natom) + 1
    4069         100 :                ri_data%iatom_to_subgroup(iatom)%array(igroup) = .TRUE.
    4070         200 :                ctr = ctr + 1
    4071             :             END DO
    4072             :          END DO
    4073             :       END IF
    4074             : 
    4075         696 :       ALLOCATE (bins(ngroups))
    4076         696 :       bins = 0.0_dp
    4077        6046 :       DO b_img = 1, nimg
    4078       17674 :          DO jatom = 1, natom
    4079       40698 :             DO iatom = 1, natom
    4080       23256 :                IF (sparsity_pattern(iatom, jatom, b_img) == -1) CYCLE
    4081       38110 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1 !0-based id of the least loaded group allowed for iatom
    4082             : 
    4083             :                !Use cost information from previous SCF if available (any kp_cost entry above EPSILON), otherwise
    4084      554266 :                IF (ANY(ri_data%kp_cost > EPSILON(0.0_dp))) THEN !NOTE(review): test is independent of the loop indices, could be hoisted
    4085        5404 :                   cost = ri_data%kp_cost(iatom, jatom, b_img)
    4086             :                ELSE
    4087        2218 :                   cost = REAL(ri_data%bsizes_AO(iatom)*ri_data%bsizes_AO(jatom), dp) !estimate: product of the AO block sizes
    4088             :                END IF
    4089        7622 :                bins(igroup + 1) = bins(igroup + 1) + cost
    4090       34884 :                sparsity_pattern(iatom, jatom, b_img) = igroup
    4091             :             END DO
    4092             :          END DO
    4093             :       END DO
    4094             : 
    4095         232 :    END SUBROUTINE get_sub_dist
    4096             : 
    4097             : ! **************************************************************************************************
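    4098             : !> \brief A routine that updates the sparsity pattern for force calculation, where all i,j,b combinations
    4099             : !>        are visited.
    4100             : !> \param force_pattern receives the 0-based group id of each block added for the force calculation
    4101             : !> \param scf_pattern the SCF pattern, -1 marks unoccupied blocks; it is only read in this routine
    4102             : !> \param ngroups the number of subgroups over which the blocks are distributed
    4103             : !> \param ri_data iatom_to_subgroup and kp_cost are read
    4104             : !> \param qs_env passed to get_opp_index to find the opposite image of each b_img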
    4105             : ! **************************************************************************************************
    4106          42 :    SUBROUTINE update_pattern_to_forces(force_pattern, scf_pattern, ngroups, ri_data, qs_env)
    4107             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: force_pattern, scf_pattern
    4108             :       INTEGER, INTENT(IN)                                :: ngroups
    4109             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4110             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4111             : 
    4112             :       INTEGER                                            :: b_img, iatom, igroup, jatom, mb_img, &
    4113             :                                                             natom, nimg
    4114          42 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4115             : 
    4116          42 :       natom = SIZE(scf_pattern, 2)
    4117          42 :       nimg = SIZE(scf_pattern, 3)
    4118             : 
    4119         126 :       ALLOCATE (bins(ngroups))
    4120         126 :       bins = 0.0_dp
    4121             : 
    4122        1052 :       DO b_img = 1, nimg
    4123        1010 :          mb_img = get_opp_index(b_img, qs_env)
    4124        3072 :          DO jatom = 1, natom
    4125        7070 :             DO iatom = 1, natom
    4126             :                !Important: same distribution as KS matrix, because t_3c_apc is reused (least loaded group allowed for iatom)
    4127       20200 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4128             : 
    4129             :                !Check that the block is not already treated in the SCF pattern (value > -1 means occupied)
    4130        4040 :                IF (scf_pattern(iatom, jatom, b_img) > -1) CYCLE
    4131             : 
    4132             :                !If not, take the cost of block j, i, -b (same energy contribution), skipping it if that block is unoccupied too
    4133        4902 :                IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
    4134        2486 :                   IF (scf_pattern(jatom, iatom, mb_img) == -1) CYCLE
    4135        1038 :                   bins(igroup + 1) = bins(igroup + 1) + ri_data%kp_cost(jatom, iatom, mb_img)
    4136        1038 :                   force_pattern(iatom, jatom, b_img) = igroup
    4137             :                END IF
    4138             :             END DO
    4139             :          END DO
    4140             :       END DO
    4141             : 
    4142          42 :    END SUBROUTINE update_pattern_to_forces
    4143             : 
    4144             : ! **************************************************************************************************
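    4145             : !> \brief A routine that determines the extent of the KP RI-HFX periodic images, including for the
    4146             : !>        extension of the RI basis
    4147             : !> \param ri_data receives kp_RI_range, kp_image_range, kp_bump_rad, nimg, present_images, ncell_RI and the img/RI cell maps
    4148             : !> \param qs_env provides the kinds, particles, k-point info, parallel environment and the DFT%XC%HF%RI input section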
    4149             : ! **************************************************************************************************
    4150          70 :    SUBROUTINE get_kp_and_ri_images(ri_data, qs_env)
    4151             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4152             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4153             : 
    4154             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_kp_and_ri_images'
    4155             : 
    4156             :       INTEGER :: cell_j(3), cell_k(3), handle, i_img, iatom, ikind, j_img, jatom, jcell, katom, &
    4157             :          kcell, kp_index_lbounds(3), kp_index_ubounds(3), natom, ngroups, nimg, nkind, pcoord(3), &
    4158             :          pdims(3)
    4159          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist_AO_1, dist_AO_2, dist_RI, &
    4160          70 :                                                             nRI_per_atom, present_img, RI_cells
    4161          70 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    4162             :       REAL(dp)                                           :: bump_fact, dij, dik, image_range, &
    4163             :                                                             RI_range, rij(3), rik(3)
    4164         490 :       TYPE(dbt_type)                                     :: t_dummy
    4165             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4166             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4167             :       TYPE(distribution_3d_type)                         :: dist_3d
    4168             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4169          70 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4170             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4171          70 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4172             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4173             :       TYPE(neighbor_list_3c_iterator_type)               :: nl_3c_iter
    4174             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4175             :       TYPE(neighbor_list_iterator_p_type), &
    4176          70 :          DIMENSION(:), POINTER                           :: nl_iterator
    4177             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4178          70 :          POINTER                                         :: nl_2c
    4179          70 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4180          70 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4181             :       TYPE(section_vals_type), POINTER                   :: hfx_section
    4182             : 
    4183          70 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, nl_iterator, dft_control, &
    4184          70 :                particle_set, kpoints, para_env, cell_to_index, hfx_section)
    4185             : 
    4186          70 :       CALL timeset(routineN, handle)
    4187             : 
    4188             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, &
    4189             :                       dft_control=dft_control, particle_set=particle_set, kpoints=kpoints, &
    4190          70 :                       para_env=para_env, natom=natom)
    4191          70 :       nimg = dft_control%nimages
    4192          70 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index)
    4193         280 :       kp_index_lbounds = LBOUND(cell_to_index)
    4194         280 :       kp_index_ubounds = UBOUND(cell_to_index)
    4195             : 
    4196          70 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
    4197          70 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
    4198             : 
    4199         496 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4200          70 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4201          70 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4202             : 
    4203             :       !In case of shortrange HFX potential, it is important to be consistent with the rest of the KP
    4204             :       !code, and use EPS_SCHWARZ to determine the range (rather than eps_filter_2c in normal RI-HFX)
    4205          70 :       IF (ri_data%hfx_pot%potential_type == do_potential_short) THEN
    4206           0 :          CALL erfc_cutoff(ri_data%eps_schwarz, ri_data%hfx_pot%omega, ri_data%hfx_pot%cutoff_radius)
    4207             :       END IF
    4208             : 
    4209             :       !Determine the range for contributing periodic images, and for the RI basis extension (maximum over all kinds)
    4210          70 :       ri_data%kp_RI_range = 0.0_dp
    4211          70 :       ri_data%kp_image_range = 0.0_dp
    4212         178 :       DO ikind = 1, nkind
    4213             : 
    4214         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4215         108 :          CALL get_gto_basis_set(basis_set_AO(ikind)%gto_basis_set, kind_radius=RI_range)
    4216         108 :          ri_data%kp_RI_range = MAX(RI_range, ri_data%kp_RI_range)
    4217             : 
    4218         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb) !NOTE(review): same call and arguments as 4 lines above, presumably redundant
    4219         108 :          CALL init_interaction_radii_orb_basis(basis_set_RI(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4220         108 :          CALL get_gto_basis_set(basis_set_RI(ikind)%gto_basis_set, kind_radius=image_range)
    4221             : 
    4222         108 :          image_range = 2.0_dp*image_range + cutoff_screen_factor*ri_data%hfx_pot%cutoff_radius
    4223         286 :          ri_data%kp_image_range = MAX(image_range, ri_data%kp_image_range)
    4224             :       END DO
    4225             : 
    4226          70 :       CALL section_vals_val_get(hfx_section, "KP_RI_BUMP_FACTOR", r_val=bump_fact)
    4227          70 :       ri_data%kp_bump_rad = bump_fact*ri_data%kp_RI_range
    4228             : 
    4229             :       !For the extent of the KP RI-HFX images, we are limited by the RI-HFX potential in
    4230             :       !(mu^0 sigma^a|P^0) (P^0|Q^b) (Q^b|nu^b lambda^a+c): if there is no contact between
    4231             :       !any P^0 and Q^b, then image b does not contribute
    4232             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4233          70 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4234             : 
    4235         210 :       ALLOCATE (present_img(nimg))
    4236        3120 :       present_img = 0
    4237          70 :       ri_data%nimg = 0
    4238          70 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    4239        1568 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    4240        1498 :          CALL get_iterator_info(nl_iterator, r=rij, cell=cell_j)
    4241             : 
    4242        5992 :          dij = NORM2(rij)
    4243             : 
    4244        1498 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3)) !NOTE(review): cell_j is not checked against the bounds of cell_to_index here, unlike in the 3c loop below
    4245        1498 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    4246             : 
    4247        1466 :          IF (dij > ri_data%kp_image_range) CYCLE
    4248             : 
    4249        1466 :          ri_data%nimg = MAX(j_img, ri_data%nimg)
    4250        1498 :          present_img(j_img) = 1
    4251             : 
    4252             :       END DO
    4253          70 :       CALL neighbor_list_iterator_release(nl_iterator)
    4254          70 :       CALL release_neighbor_list_sets(nl_2c)
    4255          70 :       CALL para_env%max(ri_data%nimg)
    4256          70 :       IF (ri_data%nimg > nimg) &
    4257           0 :          CPABORT("Make sure the smallest exponent of the RI-HFX basis is larger than that of the ORB basis.")
    4258             : 
    4259             :       !Keep track of which images will not contribute, so that they can be ignored before calculation
    4260          70 :       CALL para_env%sum(present_img)
    4261         210 :       ALLOCATE (ri_data%present_images(ri_data%nimg))
    4262        1786 :       ri_data%present_images = 0
    4263        1786 :       DO i_img = 1, ri_data%nimg
    4264        1786 :          IF (present_img(i_img) > 0) ri_data%present_images(i_img) = 1
    4265             :       END DO
    4266             : 
    4267             :       CALL create_3c_tensor(t_dummy, dist_AO_1, dist_AO_2, dist_RI, &
    4268             :                             ri_data%pgrid, ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4269          70 :                             map1=[1, 2], map2=[3], name="(AO AO | RI)")
    4270             : 
    4271          70 :       CALL dbt_mp_environ_pgrid(ri_data%pgrid, pdims, pcoord)
    4272          70 :       CALL mp_comm_t3c%create(ri_data%pgrid%mp_comm_2d, 3, pdims)
    4273             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4274          70 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4275          70 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4276          70 :       CALL dbt_destroy(t_dummy)
    4277             : 
    4278             :       !For the extension of the RI basis P in (mu^0 sigma^a |P^i), we consider an atom if the distance
    4279             :       !between mu^0 and P^i is smaller than or equal to kp_RI_range (the largest AO kind radius found above)
    4280             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, &
    4281             :                                    ri_data%ri_metric, "HFX_3c_nl", qs_env, op_pos=2, sym_ij=.FALSE., &
    4282          70 :                                    own_dist=.TRUE.)
    4283             : 
    4284         140 :       ALLOCATE (RI_cells(nimg))
    4285        3120 :       RI_cells = 0
    4286             : 
    4287         210 :       ALLOCATE (nRI_per_atom(natom))
    4288         210 :       nRI_per_atom = 0
    4289             : 
    4290          70 :       CALL neighbor_list_3c_iterator_create(nl_3c_iter, nl_3c)
    4291       58508 :       DO WHILE (neighbor_list_3c_iterate(nl_3c_iter) == 0)
    4292             :          CALL get_3c_iterator_info(nl_3c_iter, cell_k=cell_k, rik=rik, cell_j=cell_j, &
    4293       58438 :                                    iatom=iatom, jatom=jatom, katom=katom)
    4294      233752 :          dik = NORM2(rik)
    4295             : 
    4296      409066 :          IF (ANY([cell_j(1), cell_j(2), cell_j(3)] < kp_index_lbounds) .OR. &
    4297             :              ANY([cell_j(1), cell_j(2), cell_j(3)] > kp_index_ubounds)) CYCLE
    4298             : 
    4299       58438 :          jcell = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4300       58438 :          IF (jcell > nimg .OR. jcell < 1) CYCLE
    4301             : 
    4302      385667 :          IF (ANY([cell_k(1), cell_k(2), cell_k(3)] < kp_index_lbounds) .OR. &
    4303             :              ANY([cell_k(1), cell_k(2), cell_k(3)] > kp_index_ubounds)) CYCLE
    4304             : 
    4305       51169 :          kcell = cell_to_index(cell_k(1), cell_k(2), cell_k(3))
    4306       51169 :          IF (kcell > nimg .OR. kcell < 1) CYCLE
    4307             : 
    4308       43523 :          IF (dik > ri_data%kp_RI_range) CYCLE
    4309        5791 :          RI_cells(kcell) = 1
    4310             : 
    4311        5861 :          IF (jcell == 1 .AND. iatom == jatom) nRI_per_atom(iatom) = nRI_per_atom(iatom) + ri_data%bsizes_RI(katom)
    4312             :       END DO
    4313          70 :       CALL neighbor_list_3c_iterator_destroy(nl_3c_iter)
    4314          70 :       CALL neighbor_list_3c_destroy(nl_3c)
    4315          70 :       CALL para_env%sum(RI_cells)
    4316          70 :       CALL para_env%sum(nRI_per_atom)
    4317             : 
    4318         140 :       ALLOCATE (ri_data%img_to_RI_cell(nimg))
    4319          70 :       ri_data%ncell_RI = 0
    4320        3120 :       ri_data%img_to_RI_cell = 0
    4321        3120 :       DO i_img = 1, nimg
    4322        3120 :          IF (RI_cells(i_img) > 0) THEN
    4323         436 :             ri_data%ncell_RI = ri_data%ncell_RI + 1
    4324         436 :             ri_data%img_to_RI_cell(i_img) = ri_data%ncell_RI
    4325             :          END IF
    4326             :       END DO
    4327             : 
    4328         210 :       ALLOCATE (ri_data%RI_cell_to_img(ri_data%ncell_RI))
    4329        3120 :       DO i_img = 1, nimg
    4330        3120 :          IF (ri_data%img_to_RI_cell(i_img) > 0) ri_data%RI_cell_to_img(ri_data%img_to_RI_cell(i_img)) = i_img
    4331             :       END DO
    4332             : 
    4333             :       !Print some info if an output unit is available (unit_nr > 0)
    4334          70 :       IF (ri_data%unit_nr > 0) THEN
    4335             :          WRITE (ri_data%unit_nr, FMT="(/T3,A,I29)") &
    4336          35 :             "KP-HFX_RI_INFO| Number of RI-KP parallel groups:", ngroups
    4337             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F31.3,A)") &
    4338          35 :             "KP-HFX_RI_INFO| RI basis extension radius:", ri_data%kp_RI_range*angstrom, " Ang"
    4339             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F12.3,A, F6.3, A)") &
    4340          35 :             "KP-HFX_RI_INFO| RI basis bump factor and bump radius:", bump_fact, " /", &
    4341          70 :             ri_data%kp_bump_rad*angstrom, " Ang"
    4342             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I16,A)") &
    4343          35 :             "KP-HFX_RI_INFO| The extended RI bases cover up to ", ri_data%ncell_RI, " unit cells"
    4344             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I18)") &
    4345         105 :             "KP-HFX_RI_INFO| Average number of sgf in extended RI bases:", SUM(nRI_per_atom)/natom
    4346             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F13.3,A)") &
    4347          35 :             "KP-HFX_RI_INFO| Consider all image cells within a radius of ", ri_data%kp_image_range*angstrom, " Ang"
    4348             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I27/)") &
    4349          35 :             "KP-HFX_RI_INFO| Number of image cells considered: ", ri_data%nimg
    4350          35 :          CALL m_flush(ri_data%unit_nr)
    4351             :       END IF
    4352             : 
    4353          70 :       CALL timestop(handle)
    4354             : 
    4355         840 :    END SUBROUTINE get_kp_and_ri_images
    4356             : 
    4357             : ! **************************************************************************************************
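    4358             : !> \brief A routine that creates tensor structures for rho_ao and 3c_ints in a stacked format for
    4359             : !>        the efficient contractions of rho_sigma^0,lambda^c * (mu^0 sigma^a | P) => TAS tensors
    4360             : !> \param res_stack two result tensors, created with the same structure as ints_stack(1) and ints_stack(2)
    4361             : !> \param rho_stack two stacked 2c tensors: (1) follows the rho_template distribution, (2) comes from create_2c_tensor
    4362             : !> \param ints_stack two stacked 3c tensors: (1) follows the ints_template distribution, (2) uses a freshly created pgrid
    4363             : !> \param rho_template the 2c tensor whose process distribution is replicated in rho_stack(1)
    4364             : !> \param ints_template the 3c tensor whose distribution and RI block sizes are reused in ints_stack(1)
    4365             : !> \param stack_size the number of times the AO block sizes are repeated along the stacked dimensions
    4366             : !> \param ri_data provides bsizes_AO_split, pgrid_2d and pgrid_1
    4367             : !> \param qs_env provides the parallel environment used to create the pgrid of ints_stack(2)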
    4368             : !> \note The result tensor has the exact same shape and distribution as the integral tensor
    4369             : ! **************************************************************************************************
    4370         232 :    SUBROUTINE get_stack_tensors(res_stack, rho_stack, ints_stack, rho_template, ints_template, &
    4371             :                                 stack_size, ri_data, qs_env)
    4372             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: res_stack, rho_stack, ints_stack
    4373             :       TYPE(dbt_type), INTENT(INOUT)                      :: rho_template, ints_template
    4374             :       INTEGER, INTENT(IN)                                :: stack_size
    4375             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4376             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4377             : 
    4378             :       INTEGER                                            :: is, nblks, nblks_3c(3), pdims_3d(3)
    4379         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_stack, dist1, &
    4380         232 :                                                             dist2, dist3, dist_stack1, &
    4381         232 :                                                             dist_stack2, dist_stack3
    4382        2088 :       TYPE(dbt_distribution_type)                        :: t_dist
    4383         696 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4384             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4385             : 
    4386         232 :       NULLIFY (para_env)
    4387             : 
    4388         232 :       CALL get_qs_env(qs_env, para_env=para_env)
    4389             : 
    4390         232 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4391         696 :       ALLOCATE (bsizes_stack(stack_size*nblks))
    4392        3082 :       DO is = 1, stack_size
    4393       11582 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    4394             :       END DO
    4395             : 
    4396        2088 :       ALLOCATE (dist1(nblks), dist2(nblks), dist_stack1(stack_size*nblks), dist_stack2(stack_size*nblks))
    4397         232 :       CALL dbt_get_info(rho_template, proc_dist_1=dist1, proc_dist_2=dist2)
    4398        3082 :       DO is = 1, stack_size
    4399       11350 :          dist_stack1((is - 1)*nblks + 1:is*nblks) = dist1(:)
    4400       11582 :          dist_stack2((is - 1)*nblks + 1:is*nblks) = dist2(:)
    4401             :       END DO
    4402             : 
    4403             :       !First 2c tensor matches the distribution of rho_template, repeated stack_size times.
    4404             :       !It is stacked in both directions (both dimensions use the stacked block sizes)
    4405         232 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_2d, dist_stack1, dist_stack2)
    4406         232 :       CALL dbt_create(rho_stack(1), "RHO_stack", t_dist, [1], [2], bsizes_stack, bsizes_stack)
    4407         232 :       CALL dbt_distribution_destroy(t_dist)
    4408         232 :       DEALLOCATE (dist1, dist2, dist_stack1, dist_stack2)
    4409             : 
    4410             :       !Second 2c tensor has optimal distribution on the 2d pgrid (as set up by create_2c_tensor)
    4411         232 :       CALL create_2c_tensor(rho_stack(2), dist1, dist2, ri_data%pgrid_2d, bsizes_stack, bsizes_stack, name="RHO_stack")
    4412         232 :       DEALLOCATE (dist1, dist2)
    4413             : 
    4414         232 :       CALL dbt_get_info(ints_template, nblks_total=nblks_3c)
    4415        1624 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)))
    4416        1160 :       ALLOCATE (dist_stack3(stack_size*nblks_3c(3)), bsizes_RI_ext(nblks_3c(2)))
    4417             :       CALL dbt_get_info(ints_template, proc_dist_1=dist1, proc_dist_2=dist2, &
    4418         232 :                         proc_dist_3=dist3, blk_size_2=bsizes_RI_ext)
    4419        3082 :       DO is = 1, stack_size
    4420       11582 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    4421             :       END DO
    4422             : 
    4423             :       !First 3c tensor matches the distribution of ints_template; only its third dimension is stacked
    4424         232 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_1, dist1, dist2, dist_stack3)
    4425             :       CALL dbt_create(ints_stack(1), "ints_stack", t_dist, [1, 2], [3], ri_data%bsizes_AO_split, &
    4426         232 :                       bsizes_RI_ext, bsizes_stack)
    4427         232 :       CALL dbt_distribution_destroy(t_dist)
    4428         232 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    4429             : 
    4430             :       !Second 3c tensor has optimal pgrid, created here for the stacked tensor dimensions and destroyed afterwards
    4431         232 :       pdims_3d = 0
    4432         928 :       CALL dbt_pgrid_create(para_env, pdims_3d, pgrid, tensor_dims=[nblks_3c(1), nblks_3c(2), stack_size*nblks_3c(3)])
    4433             :       CALL create_3c_tensor(ints_stack(2), dist1, dist2, dist3, pgrid, ri_data%bsizes_AO_split, &
    4434         232 :                             bsizes_RI_ext, bsizes_stack, [1, 2], [3], name="ints_stack")
    4435         232 :       DEALLOCATE (dist1, dist2, dist3)
    4436         232 :       CALL dbt_pgrid_destroy(pgrid)
    4437             : 
    4438             :       !The result tensors have the same shape and dist as the corresponding integral tensors
    4439         232 :       CALL dbt_create(ints_stack(1), res_stack(1))
    4440         232 :       CALL dbt_create(ints_stack(2), res_stack(2))
    4441             : 
    4442         464 :    END SUBROUTINE get_stack_tensors
    4443             : 
    4444             : ! **************************************************************************************************
    4445             : !> \brief Fill the stack of 3c tensors according to the order in the images input
    4446             : !> \param t_3c_stack ...
    4447             : !> \param t_3c_in ...
    4448             : !> \param images ...
    4449             : !> \param stack_dim ...
    4450             : !> \param ri_data ...
    4451             : !> \param filter_at ...
    4452             : !> \param filter_dim ...
    4453             : !> \param idx_to_at ...
    4454             : !> \param img_bounds ...
    4455             : ! **************************************************************************************************
    4456       22326 :    SUBROUTINE fill_3c_stack(t_3c_stack, t_3c_in, images, stack_dim, ri_data, filter_at, filter_dim, &
    4457       22326 :                             idx_to_at, img_bounds)
    4458             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_stack
    4459             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_in
    4460             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4461             :       INTEGER, INTENT(IN)                                :: stack_dim
    4462             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4463             :       INTEGER, INTENT(IN), OPTIONAL                      :: filter_at, filter_dim
    4464             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: idx_to_at
    4465             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2)
    4466             : 
    4467             :       INTEGER                                            :: dest(3), i_img, idx, ind(3), lb, nblks, &
    4468             :                                                             nimg, offset, ub
    4469             :       LOGICAL                                            :: do_filter, found
    4470       22326 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4471             :       TYPE(dbt_iterator_type)                            :: iter
    4472             : 
    4473             :       !We loop over the a images from the ac_pairs, then copy the 3c ints to the correct spot in
    4474             :       !in the stack tensor (corresponding to pair index). Distributions match by construction
    4475       22326 :       nimg = ri_data%nimg
    4476       22326 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4477             : 
    4478       22326 :       do_filter = .FALSE.
    4479       21806 :       IF (PRESENT(filter_at) .AND. PRESENT(filter_dim) .AND. PRESENT(idx_to_at)) do_filter = .TRUE.
    4480             : 
    4481       22326 :       lb = 1
    4482       22326 :       ub = nimg
    4483       22326 :       offset = 0
    4484       22326 :       IF (PRESENT(img_bounds)) THEN
    4485       22326 :          lb = img_bounds(1)
    4486       22326 :          ub = img_bounds(2) - 1
    4487       22326 :          offset = lb - 1
    4488             :       END IF
    4489             : 
    4490      449976 :       DO idx = lb, ub
    4491      427650 :          i_img = images(idx)
    4492      427650 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4493             : 
    4494             : !$OMP PARALLEL DEFAULT(NONE) &
    4495             : !$OMP SHARED(idx,i_img,t_3c_in,t_3c_stack,nblks,stack_dim,filter_at,filter_dim,idx_to_at,do_filter,offset) &
    4496      449976 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4497             :          CALL dbt_iterator_start(iter, t_3c_in(i_img))
    4498             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4499             :             CALL dbt_iterator_next_block(iter, ind)
    4500             :             CALL dbt_get_block(t_3c_in(i_img), ind, blk, found)
    4501             :             IF (.NOT. found) CYCLE
    4502             : 
    4503             :             IF (do_filter) THEN
    4504             :                IF (.NOT. idx_to_at(ind(filter_dim)) == filter_at) CYCLE
    4505             :             END IF
    4506             : 
    4507             :             IF (stack_dim == 1) THEN
    4508             :                dest = [(idx - offset - 1)*nblks + ind(1), ind(2), ind(3)]
    4509             :             ELSE IF (stack_dim == 2) THEN
    4510             :                dest = [ind(1), (idx - offset - 1)*nblks + ind(2), ind(3)]
    4511             :             ELSE
    4512             :                dest = [ind(1), ind(2), (idx - offset - 1)*nblks + ind(3)]
    4513             :             END IF
    4514             : 
    4515             :             CALL dbt_put_block(t_3c_stack, dest, SHAPE(blk), blk)
    4516             :             DEALLOCATE (blk)
    4517             :          END DO
    4518             :          CALL dbt_iterator_stop(iter)
    4519             : !$OMP END PARALLEL
    4520             :       END DO !i_img
    4521       22326 :       CALL dbt_finalize(t_3c_stack)
    4522             : 
    4523       44652 :    END SUBROUTINE fill_3c_stack
    4524             : 
    4525             : ! **************************************************************************************************
    4526             : !> \brief Fill the stack of 2c tensors based on the content of images input
    4527             : !> \param t_2c_stack the stacked tensor that receives the copied blocks; finalized on exit
    4528             : !> \param t_2c_in the 2c tensors to stack, indexed by image number
    4529             : !> \param images image number for each stack position; entries that are 0 or > nimg are skipped
    4530             : !> \param stack_dim 1 to stack along the first dimension, any other value stacks along the second
    4531             : !> \param ri_data provides nimg and bsizes_AO_split, whose size is the block offset per position
    4532             : !> \param img_bounds process images(img_bounds(1):img_bounds(2)-1) only, stacked from position 1
    4533             : !> \param shift position along the non-stacked dimension, offset (shift-1)*nblks; defaults to 1
    4534             : ! **************************************************************************************************
    4535       15228 :    SUBROUTINE fill_2c_stack(t_2c_stack, t_2c_in, images, stack_dim, ri_data, img_bounds, shift)
    4536             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_stack
    4537             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_in
    4538             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4539             :       INTEGER, INTENT(IN)                                :: stack_dim
    4540             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4541             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2), shift
    4542             : 
    4543             :       INTEGER                                            :: dest(2), i_img, idx, ind(2), lb, &
    4544             :                                                             my_shift, nblks, nimg, offset, ub
    4545             :       LOGICAL                                            :: found
    4546       15228 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    4547             :       TYPE(dbt_iterator_type)                            :: iter
    4548             : 
    4549             :       !We loop over the requested entries of images, then copy the blocks of the matching 2c tensor
    4550             :       !to the correct spot in the stack tensor (offset by the loop position and by the shift)
    4551       15228 :       nimg = ri_data%nimg
    4552       15228 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4553             : 
    4554       15228 :       lb = 1
    4555       15228 :       ub = nimg
    4556       15228 :       offset = 0
    4557       15228 :       IF (PRESENT(img_bounds)) THEN
    4558       15228 :          lb = img_bounds(1)
    4559       15228 :          ub = img_bounds(2) - 1
    4560       15228 :          offset = lb - 1
    4561             :       END IF
    4562             : 
    4563       15228 :       my_shift = 1
    4564       15228 :       IF (PRESENT(shift)) my_shift = shift
    4565             : 
    4566      202454 :       DO idx = lb, ub
    4567      187226 :          i_img = images(idx)
    4568      187226 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4569             : 
    4570             : !$OMP PARALLEL DEFAULT(NONE) SHARED(idx,i_img,t_2c_in,t_2c_stack,nblks,stack_dim,offset,my_shift) &
    4571      202454 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4572             :          CALL dbt_iterator_start(iter, t_2c_in(i_img))
    4573             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4574             :             CALL dbt_iterator_next_block(iter, ind)
    4575             :             CALL dbt_get_block(t_2c_in(i_img), ind, blk, found)
    4576             :             IF (.NOT. found) CYCLE
    4577             : 
    4578             :             IF (stack_dim == 1) THEN
    4579             :                dest = [(idx - offset - 1)*nblks + ind(1), (my_shift - 1)*nblks + ind(2)]
    4580             :             ELSE
    4581             :                dest = [(my_shift - 1)*nblks + ind(1), (idx - offset - 1)*nblks + ind(2)]
    4582             :             END IF
    4583             : 
    4584             :             CALL dbt_put_block(t_2c_stack, dest, SHAPE(blk), blk)
    4585             :             DEALLOCATE (blk)
    4586             :          END DO
    4587             :          CALL dbt_iterator_stop(iter)
    4588             : !$OMP END PARALLEL
    4589             :       END DO !idx
    4590       15228 :       CALL dbt_finalize(t_2c_stack)
    4591             : 
    4592       30456 :    END SUBROUTINE fill_2c_stack
    4593             : 
    4594             : ! **************************************************************************************************
    4595             : !> \brief Unstacks a stacked 3c tensor containing t_3c_apc
    4596             : !> \param t_3c_apc receives the blocks of slice idx; this routine does not call dbt_finalize on it
    4597             : !> \param t_stacked tensor stacked along its 3rd dimension, in slices of nblks_total(3) of t_3c_apc
    4598             : !> \param idx 1-based index of the slice of t_stacked that is copied into t_3c_apc
    4599             : ! **************************************************************************************************
    4600       17830 :    SUBROUTINE unstack_t_3c_apc(t_3c_apc, t_stacked, idx)
    4601             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_apc, t_stacked
    4602             :       INTEGER, INTENT(IN)                                :: idx
    4603             : 
    4604             :       INTEGER                                            :: current_idx
    4605             :       INTEGER, DIMENSION(3)                              :: ind, nblks_3c
    4606             :       LOGICAL                                            :: found
    4607       17830 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4608             :       TYPE(dbt_iterator_type)                            :: iter
    4609             : 
    4610             :       !Note: t_3c_apc and t_stacked must have the same distribution
    4611       17830 :       CALL dbt_get_info(t_3c_apc, nblks_total=nblks_3c)
    4612             : 
    4613       17830 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_apc,t_stacked,idx,nblks_3c) PRIVATE(iter,ind,blk,found,current_idx)
    4614             :       CALL dbt_iterator_start(iter, t_stacked)
    4615             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4616             :          CALL dbt_iterator_next_block(iter, ind)
    4617             : 
    4618             :          !tensor is stacked along the 3rd dimension: integer division gives the slice of this block
    4619             :          current_idx = (ind(3) - 1)/nblks_3c(3) + 1
    4620             :          IF (.NOT. idx == current_idx) CYCLE
    4621             : 
    4622             :          CALL dbt_get_block(t_stacked, ind, blk, found)
    4623             :          IF (.NOT. found) CYCLE
    4624             : 
    4625             :          CALL dbt_put_block(t_3c_apc, [ind(1), ind(2), ind(3) - (idx - 1)*nblks_3c(3)], SHAPE(blk), blk)
    4626             :          DEALLOCATE (blk)
    4627             :       END DO
    4628             :       CALL dbt_iterator_stop(iter)
    4629             : !$OMP END PARALLEL
    4630             : 
    4631       17830 :    END SUBROUTINE unstack_t_3c_apc
    4632             : 
    4633             : ! **************************************************************************************************
    4634             : !> \brief copies the 3c integrals corresponding to a single atom mu from the general (P^0| mu^0 sigma^a)
    4635             : !> \param t_3c_at receives the selected blocks at unchanged block indices; finalized on exit
    4636             : !> \param t_3c_ints the source tensor whose blocks are iterated over
    4637             : !> \param iatom only blocks with idx_to_at(ind(dim_at)) == iatom are copied
    4638             : !> \param dim_at tensor dimension whose block index is looked up in idx_to_at
    4639             : !> \param idx_to_at map from block index along dim_at to the atom index compared with iatom
    4640             : ! **************************************************************************************************
    4641           0 :    SUBROUTINE get_atom_3c_ints(t_3c_at, t_3c_ints, iatom, dim_at, idx_to_at)
    4642             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_at, t_3c_ints
    4643             :       INTEGER, INTENT(IN)                                :: iatom, dim_at
    4644             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: idx_to_at
    4645             : 
    4646             :       INTEGER, DIMENSION(3)                              :: ind
    4647             :       LOGICAL                                            :: found
    4648           0 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4649             :       TYPE(dbt_iterator_type)                            :: iter
    4650             : 
    4651           0 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_ints,t_3c_at,iatom,idx_to_at,dim_at) PRIVATE(iter,ind,blk,found)
    4652             :       CALL dbt_iterator_start(iter, t_3c_ints)
    4653             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4654             :          CALL dbt_iterator_next_block(iter, ind)
    4655             :          IF (.NOT. idx_to_at(ind(dim_at)) == iatom) CYCLE
    4656             : 
    4657             :          CALL dbt_get_block(t_3c_ints, ind, blk, found)
    4658             :          IF (.NOT. found) CYCLE
    4659             : 
    4660             :          CALL dbt_put_block(t_3c_at, ind, SHAPE(blk), blk)
    4661             :          DEALLOCATE (blk)
    4662             :       END DO
    4663             :       CALL dbt_iterator_stop(iter)
    4664             : !$OMP END PARALLEL
    4665           0 :       CALL dbt_finalize(t_3c_at)
    4666             : 
    4667           0 :    END SUBROUTINE get_atom_3c_ints
    4668             : 
    4669             : ! **************************************************************************************************
    4670             : !> \brief Precalculate the 3c and 2c derivatives tensors
    4671             : !> \param t_3c_der_RI (i_img, i_xyz) tensors created here; hold the 3c derivatives wrt P^i
    4672             : !> \param t_3c_der_AO (i_img, i_xyz) tensors created here; hold the 3c derivatives wrt mu^0
    4673             : !> \param mat_der_pot (i_img, i_xyz) matrices created here; 2c derivatives built with ri_data%hfx_pot
    4674             : !> \param t_2c_der_metric (iatom, i_xyz) tensors created here from the ri_data%ri_metric derivatives
    4675             : !> \param ri_data HFX RI data; supplies nimg, ncell_RI, n_mem, block sizes, potentials, thresholds
    4676             : !> \param qs_env environment queried for kinds, particles, 2d distribution and parallel env
    4677             : ! **************************************************************************************************
    4678          42 :    SUBROUTINE precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
    4679             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_RI, t_3c_der_AO
    4680             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot
    4681             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_2c_der_metric
    4682             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4683             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4684             : 
    4685             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'precalc_derivatives'
    4686             : 
    4687             :       INTEGER                                            :: handle, handle2, i_img, i_mem, i_RI, &
    4688             :                                                             i_xyz, iatom, n_mem, natom, nblks_RI, &
    4689             :                                                             ncell_RI, nimg, nkind, nthreads
    4690             :       INTEGER(int_8)                                     :: nze
    4691          42 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: bsizes_RI_ext, bsizes_RI_ext_split, dist_AO_1, &
    4692          84 :          dist_AO_2, dist_RI, dist_RI_ext, dummy_end, dummy_start, end_blocks, start_blocks
    4693             :       INTEGER, DIMENSION(3)                              :: pcoord, pdims
    4694          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_bsize, row_bsize
    4695             :       REAL(dp)                                           :: occ
    4696             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist
    4697             :       TYPE(dbcsr_type)                                   :: dbcsr_template
    4698          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_metric
    4699         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    4700         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4701         378 :       TYPE(dbt_type)                                     :: t_3c_template
    4702          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :, :)    :: t_3c_der_AO_prv, t_3c_der_RI_prv
    4703             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4704             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4705             :       TYPE(distribution_3d_type)                         :: dist_3d
    4706             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4707          42 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4708          42 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4709             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4710             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4711             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4712          42 :          POINTER                                         :: nl_2c
    4713          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4714          42 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4715             : 
    4716          42 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, particle_set, dft_control, para_env, row_bsize, col_bsize)
    4717             : 
    4718          42 :       CALL timeset(routineN, handle)
    4719             : 
    4720             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, natom=natom, &
    4721          42 :                       particle_set=particle_set, dft_control=dft_control, para_env=para_env)
    4722             : 
    4723          42 :       nimg = ri_data%nimg
    4724          42 :       ncell_RI = ri_data%ncell_RI
    4725             : 
    4726         300 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4727          42 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4728          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_RI)
    4729          42 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4730          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_AO)
    4731             : 
    4732             :       !3c derivatives. NOTE(review): outside a parallel region omp_get_num_threads() returns 1 - confirm intent
    4733          42 :       nthreads = 1
    4734          42 : !$    nthreads = omp_get_num_threads()
    4735          42 :       pdims = 0
    4736         168 :       CALL dbt_pgrid_create(para_env, pdims, pgrid, tensor_dims=[MAX(1, natom/(ri_data%n_mem*nthreads)), natom, natom])
    4737             : 
    4738             :       CALL create_3c_tensor(t_3c_template, dist_AO_1, dist_AO_2, dist_RI, pgrid, &
    4739             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4740          42 :                             map1=[1, 2], map2=[3], name="tmp")
    4741          42 :       CALL dbt_destroy(t_3c_template)
    4742             : 
    4743             :       !We stack the RI basis images. Keep consistent distribution
    4744          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    4745         126 :       ALLOCATE (dist_RI_ext(natom*ncell_RI))
    4746          84 :       ALLOCATE (bsizes_RI_ext(natom*ncell_RI))
    4747         126 :       ALLOCATE (bsizes_RI_ext_split(nblks_RI*ncell_RI))
    4748         294 :       DO i_RI = 1, ncell_RI
    4749         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    4750         756 :          dist_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = dist_RI(:)
    4751        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    4752             :       END DO
    4753             : 
    4754          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist_AO_1, dist_AO_2, dist_RI_ext)
    4755             :       CALL dbt_create(t_3c_template, "KP_3c_der", t_dist, [1, 2], [3], &
    4756          42 :                       ri_data%bsizes_AO, ri_data%bsizes_AO, bsizes_RI_ext)
    4757          42 :       CALL dbt_distribution_destroy(t_dist)
    4758             : 
    4759        7404 :       ALLOCATE (t_3c_der_RI_prv(nimg, 1, 3), t_3c_der_AO_prv(nimg, 1, 3))
    4760         168 :       DO i_xyz = 1, 3
    4761        3198 :          DO i_img = 1, nimg
    4762        3030 :             CALL dbt_create(t_3c_template, t_3c_der_RI_prv(i_img, 1, i_xyz))
    4763        3156 :             CALL dbt_create(t_3c_template, t_3c_der_AO_prv(i_img, 1, i_xyz))
    4764             :          END DO
    4765             :       END DO
    4766          42 :       CALL dbt_destroy(t_3c_template)
    4767             : 
    4768          42 :       CALL dbt_mp_environ_pgrid(pgrid, pdims, pcoord)
    4769          42 :       CALL mp_comm_t3c%create(pgrid%mp_comm_2d, 3, pdims)
    4770             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4771          42 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4772          42 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4773          42 :       CALL dbt_pgrid_destroy(pgrid)
    4774             : 
    4775             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, ri_data%ri_metric, &
    4776          42 :                                    "HFX_3c_nl", qs_env, op_pos=2, sym_jk=.FALSE., own_dist=.TRUE.)
    4777             : 
    4778          42 :       n_mem = ri_data%n_mem
    4779             :       CALL create_tensor_batches(ri_data%bsizes_RI, n_mem, dummy_start, dummy_end, &
    4780             :                                  start_blocks, end_blocks)
    4781          42 :       DEALLOCATE (dummy_start, dummy_end)
    4782             : 
    4783             :       CALL create_3c_tensor(t_3c_template, dist_RI, dist_AO_1, dist_AO_2, ri_data%pgrid_2, &
    4784             :                             bsizes_RI_ext_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    4785          42 :                             map1=[1], map2=[2, 3], name="der (RI | AO AO)")
    4786         168 :       DO i_xyz = 1, 3
    4787        3198 :          DO i_img = 1, nimg
    4788        3030 :             CALL dbt_create(t_3c_template, t_3c_der_RI(i_img, i_xyz))
    4789        3156 :             CALL dbt_create(t_3c_template, t_3c_der_AO(i_img, i_xyz))
    4790             :          END DO
    4791             :       END DO
    4792             : 
    4793         116 :       DO i_mem = 1, n_mem
    4794             :          CALL build_3c_derivatives(t_3c_der_AO_prv, t_3c_der_RI_prv, ri_data%filter_eps, qs_env, &
    4795             :                                    nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, &
    4796             :                                    ri_data%ri_metric, der_eps=ri_data%eps_schwarz_forces, op_pos=2, &
    4797             :                                    do_kpoints=.TRUE., do_hfx_kpoints=.TRUE., &
    4798             :                                    bounds_k=[start_blocks(i_mem), end_blocks(i_mem)], &
    4799         222 :                                    RI_range=ri_data%kp_RI_range, img_to_RI_cell=ri_data%img_to_RI_cell)
    4800             : 
    4801          74 :          CALL timeset(routineN//"_cpy", handle2)
    4802             :          !We go from (mu^0 sigma^i | P^j) to (P^i| sigma^j mu^0) and finally to (P^i| mu^0 sigma^j)
    4803        1994 :          DO i_img = 1, nimg
    4804        7754 :             DO i_xyz = 1, 3
    4805             :                !derivative wrt to mu^0
    4806        5760 :                CALL get_tensor_occupancy(t_3c_der_AO_prv(i_img, 1, i_xyz), nze, occ)
    4807        5760 :                IF (nze > 0) THEN
    4808             :                   CALL dbt_copy(t_3c_der_AO_prv(i_img, 1, i_xyz), t_3c_template, &
    4809        3434 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4810        3434 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4811             :                   CALL dbt_copy(t_3c_template, t_3c_der_AO(i_img, i_xyz), &
    4812        3434 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4813             :                END IF
    4814             : 
    4815             :                !derivative wrt to P^i
    4816        5760 :                CALL get_tensor_occupancy(t_3c_der_RI_prv(i_img, 1, i_xyz), nze, occ)
    4817       13440 :                IF (nze > 0) THEN
    4818             :                   CALL dbt_copy(t_3c_der_RI_prv(i_img, 1, i_xyz), t_3c_template, &
    4819        3424 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4820        3424 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4821             :                   CALL dbt_copy(t_3c_template, t_3c_der_RI(i_img, i_xyz), &
    4822        3424 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4823             :                END IF
    4824             :             END DO
    4825             :          END DO
    4826         190 :          CALL timestop(handle2)
    4827             :       END DO
    4828          42 :       CALL dbt_destroy(t_3c_template)
    4829             : 
    4830          42 :       CALL neighbor_list_3c_destroy(nl_3c)
    4831         168 :       DO i_xyz = 1, 3
    4832        3198 :          DO i_img = 1, nimg
    4833        3030 :             CALL dbt_destroy(t_3c_der_RI_prv(i_img, 1, i_xyz))
    4834        3156 :             CALL dbt_destroy(t_3c_der_AO_prv(i_img, 1, i_xyz))
    4835             :          END DO
    4836             :       END DO
    4837        6102 :       DEALLOCATE (t_3c_der_RI_prv, t_3c_der_AO_prv)
    4838             : 
    4839             :       !Reorder 3c derivatives to be consistent with ints
    4840          42 :       CALL reorder_3c_derivs(t_3c_der_RI, ri_data)
    4841          42 :       CALL reorder_3c_derivs(t_3c_der_AO, ri_data)
    4842             : 
    4843          42 :       CALL timeset(routineN//"_2c", handle2)
    4844             :       !The 2-center derivatives
    4845          42 :       CALL cp_dbcsr_dist2d_to_dist(dist_2d, dbcsr_dist)
    4846         126 :       ALLOCATE (row_bsize(SIZE(ri_data%bsizes_RI)))
    4847          84 :       ALLOCATE (col_bsize(SIZE(ri_data%bsizes_RI)))
    4848         126 :       row_bsize(:) = ri_data%bsizes_RI
    4849         126 :       col_bsize(:) = ri_data%bsizes_RI
    4850             : 
    4851             :       CALL dbcsr_create(dbcsr_template, "2c_der", dbcsr_dist, dbcsr_type_no_symmetry, &
    4852          42 :                         row_bsize, col_bsize)
    4853          42 :       CALL dbcsr_distribution_release(dbcsr_dist)
    4854          42 :       DEALLOCATE (col_bsize, row_bsize)
    4855             : 
    4856        3282 :       ALLOCATE (mat_der_metric(nimg, 3))
    4857         168 :       DO i_xyz = 1, 3
    4858        3198 :          DO i_img = 1, nimg
    4859        3030 :             CALL dbcsr_create(mat_der_pot(i_img, i_xyz), template=dbcsr_template)
    4860        3156 :             CALL dbcsr_create(mat_der_metric(i_img, i_xyz), template=dbcsr_template)
    4861             :          END DO
    4862             :       END DO
    4863          42 :       CALL dbcsr_release(dbcsr_template)
    4864             : 
    4865             :       !HFX potential derivatives
    4866             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4867          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4868             :       CALL build_2c_derivatives(mat_der_pot, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4869          42 :                                 basis_set_RI, basis_set_RI, ri_data%hfx_pot, do_kpoints=.TRUE.)
    4870          42 :       CALL release_neighbor_list_sets(nl_2c)
    4871             : 
    4872             :       !RI metric derivatives (the neighbor list reuses the "HFX_2c_nl_pot" label from above)
    4873             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    4874          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4875             :       CALL build_2c_derivatives(mat_der_metric, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4876          42 :                                 basis_set_RI, basis_set_RI, ri_data%ri_metric, do_kpoints=.TRUE.)
    4877          42 :       CALL release_neighbor_list_sets(nl_2c)
    4878             : 
    4879             :       !Get into extended RI basis and tensor format
    4880         168 :       DO i_xyz = 1, 3
    4881         378 :          DO iatom = 1, natom
    4882         252 :             CALL dbt_create(ri_data%t_2c_inv(1, 1), t_2c_der_metric(iatom, i_xyz))
    4883             :             CALL get_ext_2c_int(t_2c_der_metric(iatom, i_xyz), mat_der_metric(:, i_xyz), &
    4884         378 :                                 iatom, iatom, 1, ri_data, qs_env)
    4885             :          END DO
    4886        3198 :          DO i_img = 1, nimg
    4887        3156 :             CALL dbcsr_release(mat_der_metric(i_img, i_xyz))
    4888             :          END DO
    4889             :       END DO
    4890          42 :       CALL timestop(handle2)
    4891             : 
    4892          42 :       CALL timestop(handle)
    4893             : 
    4894         252 :    END SUBROUTINE precalc_derivatives
    4895             : 
    4896             : ! **************************************************************************************************
    4897             : !> \brief Update the forces due to the derivative of a 2-center product d/dR (Q|R)
    4898             : !> \param force ...
    4899             : !> \param t_2c_contr A precontracted tensor containing sum_abcdPS (ab|P)(P|Q)^-1 (R|S)^-1 (S|cd) P_ac P_bd
    4900             : !> \param t_2c_der the d/dR (Q|R) tensor, in all 3 cartesian directions
    4901             : !> \param atom_of_kind ...
    4902             : !> \param kind_of ...
    4903             : !> \param img in which periodic image the second center of the tensor is
    4904             : !> \param pref ...
    4905             : !> \param ri_data ...
    4906             : !> \param qs_env ...
    4907             : !> \param work_virial ...
    4908             : !> \param cell ...
    4909             : !> \param particle_set ...
    4910             : !> \param diag ...
    4911             : !> \param offdiag ...
    4912             : !> \note IMPORTANT: t_tc_contr and t_2c_der need to have the same distribution. Atomic block sizes are
    4913             : !>                  assumed
    4914             : ! **************************************************************************************************
    4915        2895 :    SUBROUTINE get_2c_der_force(force, t_2c_contr, t_2c_der, atom_of_kind, kind_of, img, pref, &
    4916             :                                ri_data, qs_env, work_virial, cell, particle_set, diag, offdiag)
    4917             : 
    4918             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    4919             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_contr
    4920             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_der
    4921             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    4922             :       INTEGER, INTENT(IN)                                :: img
    4923             :       REAL(dp), INTENT(IN)                               :: pref
    4924             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4925             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4926             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    4927             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    4928             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    4929             :          POINTER                                         :: particle_set
    4930             :       LOGICAL, INTENT(IN), OPTIONAL                      :: diag, offdiag
    4931             : 
    4932             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_2c_der_force'
    4933             : 
    4934             :       INTEGER                                            :: handle, i_img, i_RI, i_xyz, iat, &
    4935             :                                                             iat_of_kind, ikind, j_img, j_RI, &
    4936             :                                                             j_xyz, jat, jat_of_kind, jkind, natom
    4937             :       INTEGER, DIMENSION(2)                              :: ind
    4938        2895 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    4939             :       LOGICAL                                            :: found, my_diag, my_offdiag, use_virial
    4940             :       REAL(dp)                                           :: new_force
    4941        2895 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :), TARGET     :: contr_blk, der_blk
    4942             :       REAL(dp), DIMENSION(3)                             :: scoord
    4943             :       TYPE(dbt_iterator_type)                            :: iter
    4944             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4945             : 
    4946        2895 :       NULLIFY (kpoints, index_to_cell)
    4947             : 
    4948             :       !Loop over the blocks of d/dR (Q|R), contract with the corresponding block of t_2c_contr and
    4949             :       !update the relevant force
    4950             : 
    4951        2895 :       CALL timeset(routineN, handle)
    4952             : 
    4953        2895 :       use_virial = .FALSE.
    4954        2895 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    4955             : 
    4956        2895 :       my_diag = .FALSE.
    4957        2895 :       IF (PRESENT(diag)) my_diag = diag
    4958             : 
    4959        2316 :       my_offdiag = .FALSE.
    4960        2316 :       IF (PRESENT(diag)) my_offdiag = offdiag
    4961             : 
    4962        2895 :       CALL get_qs_env(qs_env, kpoints=kpoints, natom=natom)
    4963        2895 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    4964             : 
    4965             : !$OMP PARALLEL DEFAULT(NONE) &
    4966             : !$OMP SHARED(t_2c_der,t_2c_contr,work_virial,force,use_virial,natom,index_to_cell,ri_data,img) &
    4967             : !$OMP SHARED(pref,atom_of_kind,kind_of,particle_set,cell,my_diag,my_offdiag) &
    4968             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk,contr_blk,found,new_force,i_RI,i_img,j_RI,j_img) &
    4969        2895 : !$OMP PRIVATE(iat,jat,iat_of_kind,jat_of_kind,ikind,jkind,scoord)
    4970             :       DO i_xyz = 1, 3
    4971             :          CALL dbt_iterator_start(iter, t_2c_der(i_xyz))
    4972             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4973             :             CALL dbt_iterator_next_block(iter, ind)
    4974             : 
    4975             :             !Only take forecs due to block diagonal or block off-diagonal, depending on arguments
    4976             :             IF ((my_diag .AND. .NOT. my_offdiag) .OR. (.NOT. my_diag .AND. my_offdiag)) THEN
    4977             :                IF (my_diag .AND. (ind(1) .NE. ind(2))) CYCLE
    4978             :                IF (my_offdiag .AND. (ind(1) == ind(2))) CYCLE
    4979             :             END IF
    4980             : 
    4981             :             CALL dbt_get_block(t_2c_der(i_xyz), ind, der_blk, found)
    4982             :             CPASSERT(found)
    4983             :             CALL dbt_get_block(t_2c_contr, ind, contr_blk, found)
    4984             : 
    4985             :             IF (found) THEN
    4986             : 
    4987             :                !an element of d/dR (Q|R) corresponds to 2 things because of translational invariance
    4988             :                !(Q'| R) = - (Q| R'), once wrt the center on Q, and once on R
    4989             :                new_force = pref*SUM(der_blk(:, :)*contr_blk(:, :))
    4990             : 
    4991             :                i_RI = (ind(1) - 1)/natom + 1
    4992             :                i_img = ri_data%RI_cell_to_img(i_RI)
    4993             :                iat = ind(1) - (i_RI - 1)*natom
    4994             :                iat_of_kind = atom_of_kind(iat)
    4995             :                ikind = kind_of(iat)
    4996             : 
    4997             :                j_RI = (ind(2) - 1)/natom + 1
    4998             :                j_img = ri_data%RI_cell_to_img(j_RI)
    4999             :                jat = ind(2) - (j_RI - 1)*natom
    5000             :                jat_of_kind = atom_of_kind(jat)
    5001             :                jkind = kind_of(jat)
    5002             : 
    5003             :                !Force on iatom (first center)
    5004             : !$OMP ATOMIC
    5005             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5006             :                                                           + new_force
    5007             : 
    5008             :                IF (use_virial) THEN
    5009             : 
    5010             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5011             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_img), dp)
    5012             : 
    5013             :                   DO j_xyz = 1, 3
    5014             : !$OMP ATOMIC
    5015             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5016             :                   END DO
    5017             :                END IF
    5018             : 
    5019             :                !Force on jatom (second center)
    5020             : !$OMP ATOMIC
    5021             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5022             :                                                           - new_force
    5023             : 
    5024             :                IF (use_virial) THEN
    5025             : 
    5026             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5027             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, j_img) + index_to_cell(:, img), dp)
    5028             : 
    5029             :                   DO j_xyz = 1, 3
    5030             : !$OMP ATOMIC
    5031             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) - new_force*scoord(j_xyz)
    5032             :                   END DO
    5033             :                END IF
    5034             : 
    5035             :                DEALLOCATE (contr_blk)
    5036             :             END IF
    5037             : 
    5038             :             DEALLOCATE (der_blk)
    5039             :          END DO !iter
    5040             :          CALL dbt_iterator_stop(iter)
    5041             : 
    5042             :       END DO !i_xyz
    5043             : !$OMP END PARALLEL
    5044        2895 :       CALL timestop(handle)
    5045             : 
    5046        5790 :    END SUBROUTINE get_2c_der_force
    5047             : 
    5048             : ! **************************************************************************************************
    5049             : !> \brief This routine calculates the force contribution from a trace over 3D tensors, i.e.
    5050             : !>        force = sum_ijk A_ijk B_ijk; the B tensor is (P^0| sigma^0 lambda^img), with P in the
    5051             : !>        extended RI basis. Note that all tensors are stacked along the 3rd dimension
    5052             : !> \param force per-kind force container; its fock_4c(xyz, atom_of_kind) component is updated
    5053             : !> \param t_3c_contr tensor whose blocks are iterated over and traced with the derivative blocks
    5054             : !> \param t_3c_der_1 derivative tensor (x,y,z) traced to get the force on the first center (RI basis)
    5055             : !> \param t_3c_der_2 same for the second center (AO basis); the 3rd center uses -(der_1 + der_2)
    5056             : !> \param atom_of_kind for each atom, its index among the atoms of its own kind
    5057             : !> \param kind_of for each atom, the index of its kind
    5058             : !> \param idx_to_at_RI maps an RI block index (within one cell of the extended RI basis) to its atom
    5059             : !> \param idx_to_at_AO maps an AO block index (within one stacked image) to its atom
    5060             : !> \param i_images image indices; i_images(lb_img - 1 + idx) is used for the idx-th stacked slab
    5061             : !> \param lb_img position in i_images corresponding to the first slab stacked along the 3rd dimension
    5062             : !> \param pref prefactor applied to each block trace
    5063             : !> \param ri_data provides bsizes_RI_split and bsizes_AO_split (block counts) and RI_cell_to_img
    5064             : !> \param qs_env used to retrieve the k-point info (index_to_cell)
    5065             : !> \param work_virial virial accumulator, only updated if work_virial, cell and particle_set are all present
    5066             : !> \param cell cell used to wrap and scale the atomic positions (virial only)
    5067             : !> \param particle_set atomic positions (virial only)
    5068             : ! **************************************************************************************************
    5069        1666 :    SUBROUTINE get_force_from_3c_trace(force, t_3c_contr, t_3c_der_1, t_3c_der_2, atom_of_kind, kind_of, &
    5070        3332 :                                       idx_to_at_RI, idx_to_at_AO, i_images, lb_img, pref, &
    5071             :                                       ri_data, qs_env, work_virial, cell, particle_set)
    5072             : 
    5073             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    5074             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_contr
    5075             :       TYPE(dbt_type), DIMENSION(3), INTENT(INOUT)        :: t_3c_der_1, t_3c_der_2
    5076             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of, idx_to_at_RI, &
    5077             :                                                             idx_to_at_AO, i_images
    5078             :       INTEGER, INTENT(IN)                                :: lb_img
    5079             :       REAL(dp), INTENT(IN)                               :: pref
    5080             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    5081             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    5082             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    5083             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    5084             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    5085             :          POINTER                                         :: particle_set
    5086             : 
    5087             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_force_from_3c_trace'
    5088             : 
    5089             :       INTEGER :: handle, i_RI, i_xyz, iat, iat_of_kind, idx, ikind, j_xyz, jat, jat_of_kind, &
    5090             :          jkind, kat, kat_of_kind, kkind, nblks_AO, nblks_RI, RI_img
    5091             :       INTEGER, DIMENSION(3)                              :: ind
    5092        1666 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    5093             :       LOGICAL                                            :: found, found_1, found_2, use_virial
    5094             :       REAL(dp)                                           :: new_force
    5095        1666 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET  :: contr_blk, der_blk_1, der_blk_2, &
    5096        1666 :                                                             der_blk_3
    5097             :       REAL(dp), DIMENSION(3)                             :: scoord
    5098             :       TYPE(dbt_iterator_type)                            :: iter
    5099             :       TYPE(kpoint_type), POINTER                         :: kpoints
    5100             : 
    5101        1666 :       NULLIFY (kpoints, index_to_cell)
    5102             : 
    5103        1666 :       CALL timeset(routineN, handle)
    5104             : 
    5105        1666 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    5106        1666 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    5107             : 
    5108        1666 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    5109        1666 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    5110             : 
    5111        1666 :       use_virial = .FALSE.
    5112        1666 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    5113             : 
    5114             : !$OMP PARALLEL DEFAULT(NONE) &
    5115             : !$OMP SHARED(t_3c_der_1, t_3c_der_2,t_3c_contr,work_virial,force,use_virial,index_to_cell,i_images,lb_img) &
    5116             : !$OMP SHARED(pref,idx_to_at_AO,atom_of_kind,kind_of,particle_set,cell,idx_to_at_RI,ri_data,nblks_RI,nblks_AO) &
    5117             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk_1,contr_blk,found,new_force,iat,iat_of_kind,ikind,scoord) &
    5118        1666 : !$OMP PRIVATE(jat,kat,jat_of_kind,kat_of_kind,jkind,kkind,i_RI,RI_img,der_blk_2,der_blk_3,found_1,found_2,idx)
    5119             :       CALL dbt_iterator_start(iter, t_3c_contr)
    5120             :       DO WHILE (dbt_iterator_blocks_left(iter))
    5121             :          CALL dbt_iterator_next_block(iter, ind)
    5122             : 
    5123             :          CALL dbt_get_block(t_3c_contr, ind, contr_blk, found)
    5124             :          IF (found) THEN
    5125             : 
    5126             :             DO i_xyz = 1, 3
    5127             :                CALL dbt_get_block(t_3c_der_1(i_xyz), ind, der_blk_1, found_1)
    5128             :                IF (.NOT. found_1) THEN
    5129             :                   DEALLOCATE (der_blk_1)
    5130             :                   ALLOCATE (der_blk_1(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5131             :                   der_blk_1(:, :, :) = 0.0_dp
    5132             :                END IF
    5133             :                CALL dbt_get_block(t_3c_der_2(i_xyz), ind, der_blk_2, found_2)
    5134             :                IF (.NOT. found_2) THEN
    5135             :                   DEALLOCATE (der_blk_2)
    5136             :                   ALLOCATE (der_blk_2(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5137             :                   der_blk_2(:, :, :) = 0.0_dp
    5138             :                END IF
    5139             : 
    5140             :                ALLOCATE (der_blk_3(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5141             :                der_blk_3(:, :, :) = -(der_blk_1(:, :, :) + der_blk_2(:, :, :))
    5142             : 
    5143             :                !We assume the tensors are in the format (P^0| sigma^0 mu^a+c-b), with P a member of the
    5144             :                !extended RI basis set
    5145             : 
    5146             :                !Force for the first center (RI extended basis, zero cell)
    5147             :                new_force = pref*SUM(der_blk_1(:, :, :)*contr_blk(:, :, :))
    5148             : 
    5149             :                i_RI = (ind(1) - 1)/nblks_RI + 1
    5150             :                RI_img = ri_data%RI_cell_to_img(i_RI)
    5151             :                iat = idx_to_at_RI(ind(1) - (i_RI - 1)*nblks_RI)
    5152             :                iat_of_kind = atom_of_kind(iat)
    5153             :                ikind = kind_of(iat)
    5154             : 
    5155             : !$OMP ATOMIC
    5156             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5157             :                                                           + new_force
    5158             : 
    5159             :                IF (use_virial) THEN
    5160             : 
    5161             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5162             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, RI_img), dp)
    5163             : 
    5164             :                   DO j_xyz = 1, 3
    5165             : !$OMP ATOMIC
    5166             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5167             :                   END DO
    5168             :                END IF
    5169             : 
    5170             :                !Force with respect to the second center (AO basis, zero cell)
    5171             :                new_force = pref*SUM(der_blk_2(:, :, :)*contr_blk(:, :, :))
    5172             :                jat = idx_to_at_AO(ind(2))
    5173             :                jat_of_kind = atom_of_kind(jat)
    5174             :                jkind = kind_of(jat)
    5175             : 
    5176             : !$OMP ATOMIC
    5177             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5178             :                                                           + new_force
    5179             : 
    5180             :                IF (use_virial) THEN
    5181             : 
    5182             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5183             : 
    5184             :                   DO j_xyz = 1, 3
    5185             : !$OMP ATOMIC
    5186             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5187             :                   END DO
    5188             :                END IF
    5189             : 
    5190             :                !Force with respect to the third center (AO basis, apc_img - b_img)
    5191             :                !Note: tensors are stacked along the 3rd direction
    5192             :                new_force = pref*SUM(der_blk_3(:, :, :)*contr_blk(:, :, :))
    5193             :                idx = (ind(3) - 1)/nblks_AO + 1
    5194             :                kat = idx_to_at_AO(ind(3) - (idx - 1)*nblks_AO)
    5195             :                kat_of_kind = atom_of_kind(kat)
    5196             :                kkind = kind_of(kat)
    5197             : 
    5198             : !$OMP ATOMIC
    5199             :                force(kkind)%fock_4c(i_xyz, kat_of_kind) = force(kkind)%fock_4c(i_xyz, kat_of_kind) &
    5200             :                                                           + new_force
    5201             : 
    5202             :                IF (use_virial) THEN
    5203             :                   CALL real_to_scaled(scoord, pbc(particle_set(kat)%r, cell), cell)
    5204             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_images(lb_img - 1 + idx)), dp)
    5205             : 
    5206             :                   DO j_xyz = 1, 3
    5207             : !$OMP ATOMIC
    5208             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5209             :                   END DO
    5210             :                END IF
    5211             : 
    5212             :                DEALLOCATE (der_blk_1, der_blk_2, der_blk_3)
    5213             :             END DO !i_xyz
    5214             :             DEALLOCATE (contr_blk)
    5215             :          END IF !found
    5216             :       END DO !iter
    5217             :       CALL dbt_iterator_stop(iter)
    5218             : !$OMP END PARALLEL
    5219        1666 :       CALL timestop(handle)
    5220             : 
    5221        3332 :    END SUBROUTINE get_force_from_3c_trace
    5222             : 
    5223             : END MODULE

Generated by: LCOV version 1.15