LCOV - code coverage report
Current view: top level - src - hfx_ri_kp.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b4bd748) Lines: 2161 2191 98.6 %
Date: 2025-03-09 07:56:22 Functions: 38 40 95.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief RI-methods for HFX and K-points.
      10             : !> \author Augustin Bussy (01.2023)
      11             : ! **************************************************************************************************
      12             : 
      13             : MODULE hfx_ri_kp
      14             :    USE admm_types,                      ONLY: get_admm_env
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      16             :                                               get_atomic_kind_set
      17             :    USE basis_set_types,                 ONLY: get_gto_basis_set,&
      18             :                                               gto_basis_set_p_type
      19             :    USE bibliography,                    ONLY: Bussy2024,&
      20             :                                               cite_reference
      21             :    USE cell_types,                      ONLY: cell_type,&
      22             :                                               pbc,&
      23             :                                               real_to_scaled,&
      24             :                                               scaled_to_real
      25             :    USE cp_array_utils,                  ONLY: cp_1d_logical_p_type,&
      26             :                                               cp_2d_r_p_type,&
      27             :                                               cp_3d_r_p_type
      28             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_create,&
      29             :                                               cp_blacs_env_release,&
      30             :                                               cp_blacs_env_type
      31             :    USE cp_control_types,                ONLY: dft_control_type
      32             :    USE cp_dbcsr_api,                    ONLY: &
      33             :         dbcsr_add, dbcsr_clear, dbcsr_copy, dbcsr_create, dbcsr_distribution_get, &
      34             :         dbcsr_distribution_new, dbcsr_distribution_release, dbcsr_distribution_type, dbcsr_filter, &
      35             :         dbcsr_finalize, dbcsr_get_block_p, dbcsr_get_info, dbcsr_iterator_blocks_left, &
      36             :         dbcsr_iterator_next_block, dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, &
      37             :         dbcsr_p_type, dbcsr_put_block, dbcsr_release, dbcsr_type, dbcsr_type_no_symmetry, &
      38             :         dbcsr_type_symmetric
      39             :    USE cp_dbcsr_cholesky,               ONLY: cp_dbcsr_cholesky_decompose,&
      40             :                                               cp_dbcsr_cholesky_invert
      41             :    USE cp_dbcsr_contrib,                ONLY: dbcsr_dot
      42             :    USE cp_dbcsr_cp2k_link,              ONLY: cp_dbcsr_alloc_block_from_nbl
      43             :    USE cp_dbcsr_diag,                   ONLY: cp_dbcsr_power
      44             :    USE cp_dbcsr_operations,             ONLY: cp_dbcsr_dist2d_to_dist
      45             :    USE dbt_api,                         ONLY: &
      46             :         dbt_batched_contract_finalize, dbt_batched_contract_init, dbt_clear, dbt_contract, &
      47             :         dbt_copy, dbt_copy_matrix_to_tensor, dbt_copy_tensor_to_matrix, dbt_create, dbt_destroy, &
      48             :         dbt_distribution_destroy, dbt_distribution_new, dbt_distribution_type, dbt_filter, &
      49             :         dbt_finalize, dbt_get_block, dbt_get_info, dbt_get_stored_coordinates, &
      50             :         dbt_iterator_blocks_left, dbt_iterator_next_block, dbt_iterator_start, dbt_iterator_stop, &
      51             :         dbt_iterator_type, dbt_mp_environ_pgrid, dbt_pgrid_create, dbt_pgrid_destroy, &
      52             :         dbt_pgrid_type, dbt_put_block, dbt_scale, dbt_type
      53             :    USE distribution_2d_types,           ONLY: distribution_2d_release,&
      54             :                                               distribution_2d_type
      55             :    USE hfx_ri,                          ONLY: get_idx_to_atom,&
      56             :                                               hfx_ri_pre_scf_calc_tensors
      57             :    USE hfx_types,                       ONLY: hfx_ri_type
      58             :    USE input_constants,                 ONLY: do_potential_short,&
      59             :                                               hfx_ri_do_2c_cholesky,&
      60             :                                               hfx_ri_do_2c_diag,&
      61             :                                               hfx_ri_do_2c_iter
      62             :    USE input_cp2k_hfx,                  ONLY: ri_pmat
      63             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      64             :                                               section_vals_type,&
      65             :                                               section_vals_val_get,&
      66             :                                               section_vals_val_set
      67             :    USE iterate_matrix,                  ONLY: invert_hotelling
      68             :    USE kinds,                           ONLY: dp,&
      69             :                                               int_8
      70             :    USE kpoint_types,                    ONLY: get_kpoint_info,&
      71             :                                               kpoint_type
      72             :    USE libint_2c_3c,                    ONLY: cutoff_screen_factor
      73             :    USE machine,                         ONLY: m_flush,&
      74             :                                               m_walltime
      75             :    USE mathlib,                         ONLY: erfc_cutoff
      76             :    USE message_passing,                 ONLY: mp_cart_type,&
      77             :                                               mp_para_env_type,&
      78             :                                               mp_request_type,&
      79             :                                               mp_waitall
      80             :    USE particle_methods,                ONLY: get_particle_set
      81             :    USE particle_types,                  ONLY: particle_type
      82             :    USE physcon,                         ONLY: angstrom
      83             :    USE qs_environment_types,            ONLY: get_qs_env,&
      84             :                                               qs_environment_type
      85             :    USE qs_force_types,                  ONLY: qs_force_type
      86             :    USE qs_integral_utils,               ONLY: basis_set_list_setup
      87             :    USE qs_interactions,                 ONLY: init_interaction_radii_orb_basis
      88             :    USE qs_kind_types,                   ONLY: qs_kind_type
      89             :    USE qs_neighbor_list_types,          ONLY: get_iterator_info,&
      90             :                                               neighbor_list_iterate,&
      91             :                                               neighbor_list_iterator_create,&
      92             :                                               neighbor_list_iterator_p_type,&
      93             :                                               neighbor_list_iterator_release,&
      94             :                                               neighbor_list_set_p_type,&
      95             :                                               release_neighbor_list_sets
      96             :    USE qs_scf_types,                    ONLY: qs_scf_env_type
      97             :    USE qs_tensors,                      ONLY: &
      98             :         build_2c_derivatives, build_2c_neighbor_lists, build_3c_derivatives, &
      99             :         build_3c_neighbor_lists, get_3c_iterator_info, get_tensor_occupancy, &
     100             :         neighbor_list_3c_destroy, neighbor_list_3c_iterate, neighbor_list_3c_iterator_create, &
     101             :         neighbor_list_3c_iterator_destroy
     102             :    USE qs_tensors_types,                ONLY: create_2c_tensor,&
     103             :                                               create_3c_tensor,&
     104             :                                               create_tensor_batches,&
     105             :                                               distribution_2d_create,&
     106             :                                               distribution_3d_create,&
     107             :                                               distribution_3d_type,&
     108             :                                               neighbor_list_3c_iterator_type,&
     109             :                                               neighbor_list_3c_type
     110             :    USE util,                            ONLY: get_limit
     111             :    USE virial_types,                    ONLY: virial_type
     112             : #include "./base/base_uses.f90"
     113             : 
     114             : !$ USE OMP_LIB, ONLY: omp_get_num_threads
     115             : 
     116             :    IMPLICIT NONE
     117             :    PRIVATE
     118             : 
     119             :    PUBLIC :: hfx_ri_update_ks_kp, hfx_ri_update_forces_kp
     120             : 
     121             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'hfx_ri_kp'
     122             : CONTAINS
     123             : 
     124             : ! NOTES: for a start, we do not seek performance, but accuracy. So in this first implementation,
     125             : !        we give little consideration to batching, load balance and such.
     126             : !        We also put everything here, even if there is some code replication with the original RI_HFX
     127             : !        We will only work in the RHO flavor
     128             : !        For now, we will also always assume that there is a single para_env, and that there is no
     129             : !        K-point subgroup. This might change in the future
     130             : 
     131             : ! **************************************************************************************************
     132             : !> \brief Initialize the ri_data for K-points. For now, we take the normal, usual existing ri_data
     133             : !>        and we adapt it to our needs
     134             : !> \param dbcsr_template template from which the first KS tensor ri_data%ks_t(1, 1) is created
     135             : !> \param ri_data RI data; its 3c/2c integral tensors, kp_cost, rho_ao_t and ks_t are allocated here
     136             : !> \param qs_env environment queried for dft_control, natom, para_env and nkind
     137             : ! **************************************************************************************************
     138          70 :    SUBROUTINE adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     139             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     140             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     141             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     142             : 
     143             :       INTEGER                                            :: i_img, i_RI, i_spin, iatom, natom, &
     144             :                                                             nblks_RI, nimg, nkind, nspins
     145          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, dist1, dist2, dist3
     146             :       TYPE(dft_control_type), POINTER                    :: dft_control
     147             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     148             : 
     149          70 :       NULLIFY (dft_control, para_env)
     150             : 
     151             :       !The main thing that we need to do is to allocate more space for the integrals, such that there
     152             :       !is room for each periodic image. Note that we only go in 1D, i.e. we store (mu^0 sigma^a|P^0),
     153             :       !and (P^0|Q^a) => the RI basis is always in the main cell.
     154             : 
     155             :       !Get kpoint info
     156          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, para_env=para_env, nkind=nkind) !NOTE(review): para_env and nkind are not used further in this routine
     157          70 :       nimg = ri_data%nimg
     158             : 
     159             :       !Along the RI direction we have basis elements spread across ncell_RI images.
     160          70 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
     161         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     162         506 :       DO i_RI = 1, ri_data%ncell_RI
     163        2344 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
     164             :       END DO
     165             : 
     166        4202 :       ALLOCATE (ri_data%t_3c_int_ctr_1(1, nimg)) !one 3c tensor per image index
     167             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_1(1, 1), dist1, dist2, dist3, &
     168             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     169          70 :                             ri_data%bsizes_AO_split, [1, 2], [3], name="(AO RI | AO)")
     170             : 
     171        1716 :       DO i_img = 2, nimg
     172        1716 :          CALL dbt_create(ri_data%t_3c_int_ctr_1(1, 1), ri_data%t_3c_int_ctr_1(1, i_img))
     173             :       END DO
     174          70 :       DEALLOCATE (dist1, dist2, dist3)
     175             : 
     176         770 :       ALLOCATE (ri_data%t_3c_int_ctr_2(1, 1))
     177             :       CALL create_3c_tensor(ri_data%t_3c_int_ctr_2(1, 1), dist1, dist2, dist3, &
     178             :                             ri_data%pgrid_1, ri_data%bsizes_AO_split, bsizes_RI_ext, &
     179          70 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
     180          70 :       DEALLOCATE (dist1, dist2, dist3)
     181             : 
     182             :       !We use full block sizes for the 2c quantities
     183          70 :       DEALLOCATE (bsizes_RI_ext)
     184          70 :       nblks_RI = SIZE(ri_data%bsizes_RI)
     185         210 :       ALLOCATE (bsizes_RI_ext(nblks_RI*ri_data%ncell_RI))
     186         506 :       DO i_RI = 1, ri_data%ncell_RI
     187        1378 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
     188             :       END DO
     189             : 
     190        3010 :       ALLOCATE (ri_data%t_2c_inv(1, natom), ri_data%t_2c_int(1, natom), ri_data%t_2c_pot(1, natom)) !one 2c tensor of each kind per atom
     191             :       CALL create_2c_tensor(ri_data%t_2c_inv(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     192             :                             bsizes_RI_ext, bsizes_RI_ext, &
     193          70 :                             name="(RI | RI)")
     194          70 :       DEALLOCATE (dist1, dist2)
     195          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, 1))
     196          70 :       CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, 1))
     197         140 :       DO iatom = 2, natom
     198          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_inv(1, iatom))
     199          70 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_int(1, iatom))
     200         140 :          CALL dbt_create(ri_data%t_2c_inv(1, 1), ri_data%t_2c_pot(1, iatom))
     201             :       END DO
     202             : 
     203         350 :       ALLOCATE (ri_data%kp_cost(natom, natom, nimg))
     204       12082 :       ri_data%kp_cost = 0.0_dp
     205             : 
     206             :       !We store the density and KS matrix in tensor format
     207          70 :       nspins = dft_control%nspins
     208        9218 :       ALLOCATE (ri_data%rho_ao_t(nspins, nimg), ri_data%ks_t(nspins, nimg))
     209             :       CALL create_2c_tensor(ri_data%rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     210             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     211          70 :                             name="(AO | AO)")
     212          70 :       DEALLOCATE (dist1, dist2)
     213             : 
     214          70 :       CALL dbt_create(dbcsr_template, ri_data%ks_t(1, 1))
     215             : 
     216          70 :       IF (nspins == 2) THEN
     217          24 :          CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(2, 1))
     218          24 :          CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(2, 1))
     219             :       END IF
     220        1716 :       DO i_img = 2, nimg !tensors of all further images are created from the (1, 1) ones
     221        3710 :          DO i_spin = 1, nspins
     222        1994 :             CALL dbt_create(ri_data%rho_ao_t(1, 1), ri_data%rho_ao_t(i_spin, i_img))
     223        3640 :             CALL dbt_create(ri_data%ks_t(1, 1), ri_data%ks_t(i_spin, i_img))
     224             :          END DO
     225             :       END DO
     226             : 
     227         210 :    END SUBROUTINE adapt_ri_data_to_kp
     228             : 
     229             : ! **************************************************************************************************
     230             : !> \brief The pre-scf steps for RI-HFX k-points calculation. Namely the calculation of the integrals
     231             : !> \param dbcsr_template passed on unchanged to adapt_ri_data_to_kp
     232             : !> \param ri_data RI data; its k-point content is cleaned up first and then rebuilt by this routine
     233             : !> \param qs_env environment queried for dft_control, natom and nkind, and passed on to the helpers
     234             : ! **************************************************************************************************
     235          70 :    SUBROUTINE hfx_ri_pre_scf_kp(dbcsr_template, ri_data, qs_env)
     236             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_template
     237             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     238             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     239             : 
     240             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'hfx_ri_pre_scf_kp'
     241             : 
     242             :       INTEGER                                            :: handle, i_img, iatom, natom, nimg, nkind
     243          70 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: t_2c_op_pot, t_2c_op_RI
     244          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: t_3c_int
     245             :       TYPE(dft_control_type), POINTER                    :: dft_control
     246             : 
     247          70 :       NULLIFY (dft_control)
     248             : 
     249          70 :       CALL timeset(routineN, handle)
     250             : 
     251          70 :       CALL get_qs_env(qs_env, dft_control=dft_control, natom=natom, nkind=nkind)
     252             : 
     253          70 :       CALL cleanup_kp(ri_data) !release any k-point data already held by ri_data
     254             : 
     255             :       !We do all the checks on what we allow in this initial implementation
     256          70 :       IF (ri_data%flavor .NE. ri_pmat) CPABORT("K-points RI-HFX only with RHO flavor")
     257          70 :       IF (ri_data%same_op) ri_data%same_op = .FALSE. !force the full calculation with RI metric
     258          70 :       IF (ABS(ri_data%eps_pgf_orb - dft_control%qs_control%eps_pgf_orb) > 1.0E-16_dp) &
     259           0 :          CPABORT("RI%EPS_PGF_ORB and QS%EPS_PGF_ORB must be identical for RI-HFX k-points")
     260             : 
     261          70 :       CALL get_kp_and_ri_images(ri_data, qs_env)
     262          70 :       nimg = ri_data%nimg
     263             : 
     264             :       !Calculate the integrals
     265        3712 :       ALLOCATE (t_2c_op_pot(nimg), t_2c_op_RI(nimg))
     266        4202 :       ALLOCATE (t_3c_int(1, nimg))
     267          70 :       CALL hfx_ri_pre_scf_calc_tensors(qs_env, ri_data, t_2c_op_RI, t_2c_op_pot, t_3c_int, do_kpoints=.TRUE.)
     268             : 
     269             :       !Make sure the internals have the k-point format
     270          70 :       CALL adapt_ri_data_to_kp(dbcsr_template, ri_data, qs_env)
     271             : 
     272             :       !For each atom i, we calculate the inverse RI metric (P^0 | Q^0)^-1 without external bumping yet
     273             :       !Also store the off-diagonal integrals of the RI metric in case of forces, bumped from the left
     274         210 :       DO iatom = 1, natom
     275             :          CALL get_ext_2c_int(ri_data%t_2c_inv(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     276         140 :                              do_inverse=.TRUE.)
     277             :          !for the forces:
     278             :          !off-diagonal RI metric bumped from the left
     279             :          CALL get_ext_2c_int(ri_data%t_2c_int(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, &
     280         140 :                              qs_env, off_diagonal=.TRUE.)
     281         140 :          CALL apply_bump(ri_data%t_2c_int(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     282             : 
     283             :          !RI metric with bumped off-diagonal blocks (but not inverted), debumped from left and right
     284             :          CALL get_ext_2c_int(ri_data%t_2c_pot(1, iatom), t_2c_op_RI, iatom, iatom, 1, ri_data, qs_env, &
     285         140 :                              do_inverse=.TRUE., skip_inverse=.TRUE.)
     286             :          CALL apply_bump(ri_data%t_2c_pot(1, iatom), iatom, ri_data, qs_env, from_left=.TRUE., &
     287         210 :                          from_right=.TRUE., debump=.TRUE.)
     288             : 
     289             :       END DO
     290             : 
     291        1786 :       DO i_img = 1, nimg !the RI metric matrices are not needed past this point
     292        1786 :          CALL dbcsr_release(t_2c_op_RI(i_img))
     293             :       END DO
     294             : 
     295        3572 :       ALLOCATE (ri_data%kp_mat_2c_pot(1, nimg))
     296        1786 :       DO i_img = 1, nimg !keep a copy of each 2c potential matrix in ri_data, then release the local one
     297        1716 :          CALL dbcsr_create(ri_data%kp_mat_2c_pot(1, i_img), template=t_2c_op_pot(i_img))
     298        1716 :          CALL dbcsr_copy(ri_data%kp_mat_2c_pot(1, i_img), t_2c_op_pot(i_img))
     299        1786 :          CALL dbcsr_release(t_2c_op_pot(i_img))
     300             :       END DO
     301             : 
     302             :       !Pre-contract all 3c integrals with the bumped inverse RI metric (P^0|Q^0)^-1,
     303             :       !and store in ri_data%t_3c_int_ctr_1
     304          70 :       CALL precontract_3c_ints(t_3c_int, ri_data, qs_env)
     305             : 
     306             :       !reorder the 3c integrals such that empty images are bunched up together
     307          70 :       CALL reorder_3c_ints(ri_data%t_3c_int_ctr_1(1, :), ri_data)
     308             : 
     309          70 :       CALL timestop(handle)
     310             : 
     311        1856 :    END SUBROUTINE hfx_ri_pre_scf_kp
     312             : 
     313             : ! **************************************************************************************************
     314             : !> \brief clean-up the KP specific data from ri_data
     315             : !> \param ri_data RI data; each k-point component that is currently allocated is released and deallocated
     316             : ! **************************************************************************************************
     317          70 :    SUBROUTINE cleanup_kp(ri_data)
     318             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     319             : 
     320             :       INTEGER                                            :: i, j
     321             : 
     322          70 :       IF (ALLOCATED(ri_data%kp_cost)) DEALLOCATE (ri_data%kp_cost) !these six components are deallocated directly, no per-element release
     323          70 :       IF (ALLOCATED(ri_data%idx_to_img)) DEALLOCATE (ri_data%idx_to_img)
     324          70 :       IF (ALLOCATED(ri_data%img_to_idx)) DEALLOCATE (ri_data%img_to_idx)
     325          70 :       IF (ALLOCATED(ri_data%present_images)) DEALLOCATE (ri_data%present_images)
     326          70 :       IF (ALLOCATED(ri_data%img_to_RI_cell)) DEALLOCATE (ri_data%img_to_RI_cell)
     327          70 :       IF (ALLOCATED(ri_data%RI_cell_to_img)) DEALLOCATE (ri_data%RI_cell_to_img)
     328             : 
     329          70 :       IF (ALLOCATED(ri_data%kp_mat_2c_pot)) THEN !release every matrix before deallocating the array
     330         540 :          DO j = 1, SIZE(ri_data%kp_mat_2c_pot, 2)
     331        1060 :             DO i = 1, SIZE(ri_data%kp_mat_2c_pot, 1)
     332        1040 :                CALL dbcsr_release(ri_data%kp_mat_2c_pot(i, j))
     333             :             END DO
     334             :          END DO
     335          20 :          DEALLOCATE (ri_data%kp_mat_2c_pot)
     336             :       END IF
     337             : 
     338          70 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN !all tensor arrays below: destroy every element, then deallocate
     339         540 :          DO i = 1, SIZE(ri_data%kp_t_3c_int)
     340         540 :             CALL dbt_destroy(ri_data%kp_t_3c_int(i))
     341             :          END DO
     342         540 :          DEALLOCATE (ri_data%kp_t_3c_int)
     343             :       END IF
     344             : 
     345          70 :       IF (ALLOCATED(ri_data%t_2c_inv)) THEN
     346         160 :          DO j = 1, SIZE(ri_data%t_2c_inv, 2)
     347         250 :             DO i = 1, SIZE(ri_data%t_2c_inv, 1)
     348         180 :                CALL dbt_destroy(ri_data%t_2c_inv(i, j))
     349             :             END DO
     350             :          END DO
     351         160 :          DEALLOCATE (ri_data%t_2c_inv)
     352             :       END IF
     353             : 
     354          70 :       IF (ALLOCATED(ri_data%t_2c_int)) THEN
     355         160 :          DO j = 1, SIZE(ri_data%t_2c_int, 2)
     356         250 :             DO i = 1, SIZE(ri_data%t_2c_int, 1)
     357         180 :                CALL dbt_destroy(ri_data%t_2c_int(i, j))
     358             :             END DO
     359             :          END DO
     360         160 :          DEALLOCATE (ri_data%t_2c_int)
     361             :       END IF
     362             : 
     363          70 :       IF (ALLOCATED(ri_data%t_2c_pot)) THEN
     364         160 :          DO j = 1, SIZE(ri_data%t_2c_pot, 2)
     365         250 :             DO i = 1, SIZE(ri_data%t_2c_pot, 1)
     366         180 :                CALL dbt_destroy(ri_data%t_2c_pot(i, j))
     367             :             END DO
     368             :          END DO
     369         160 :          DEALLOCATE (ri_data%t_2c_pot)
     370             :       END IF
     371             : 
     372          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_1)) THEN
     373         640 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_1, 2)
     374        1210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_1, 1)
     375        1140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_1(i, j))
     376             :             END DO
     377             :          END DO
     378         640 :          DEALLOCATE (ri_data%t_3c_int_ctr_1)
     379             :       END IF
     380             : 
     381          70 :       IF (ALLOCATED(ri_data%t_3c_int_ctr_2)) THEN
     382         140 :          DO j = 1, SIZE(ri_data%t_3c_int_ctr_2, 2)
     383         210 :             DO i = 1, SIZE(ri_data%t_3c_int_ctr_2, 1)
     384         140 :                CALL dbt_destroy(ri_data%t_3c_int_ctr_2(i, j))
     385             :             END DO
     386             :          END DO
     387         140 :          DEALLOCATE (ri_data%t_3c_int_ctr_2)
     388             :       END IF
     389             : 
     390          70 :       IF (ALLOCATED(ri_data%rho_ao_t)) THEN
     391         640 :          DO j = 1, SIZE(ri_data%rho_ao_t, 2)
     392        1428 :             DO i = 1, SIZE(ri_data%rho_ao_t, 1)
     393        1358 :                CALL dbt_destroy(ri_data%rho_ao_t(i, j))
     394             :             END DO
     395             :          END DO
     396         858 :          DEALLOCATE (ri_data%rho_ao_t)
     397             :       END IF
     398             : 
     399          70 :       IF (ALLOCATED(ri_data%ks_t)) THEN
     400         640 :          DO j = 1, SIZE(ri_data%ks_t, 2)
     401        1428 :             DO i = 1, SIZE(ri_data%ks_t, 1)
     402        1358 :                CALL dbt_destroy(ri_data%ks_t(i, j))
     403             :             END DO
     404             :          END DO
     405         858 :          DEALLOCATE (ri_data%ks_t)
     406             :       END IF
     407             : 
     408          70 :    END SUBROUTINE cleanup_kp
     409             : 
     410             : ! **************************************************************************************************
     411             : !> \brief Add the RI-HFX contribution to the KS matrix of each real-space image and compute the HFX energy
     412             : !> \param qs_env QS environment; para_env, natom and the DFT%XC%HF%RI input section are read from it
     413             : !> \param ri_data RI-HFX data; its stored tensors (rho_ao_t, ks_t, kp_t_3c_int) and counters are updated
     414             : !> \param ks_matrix KS matrices (spin, image); the HFX contribution is added to them in place
     415             : !> \param ehfx on exit, the HFX energy: 0.5 * sum over spins and images of dot(KS_HFX, P)
     416             : !> \param rho_ao AO density matrices, passed to get_pmat_images to build ri_data%rho_ao_t
     417             : !> \param geometry_did_change if true, hfx_ri_pre_scf_kp is called before the update
     418             : !> \param nspins number of spins; the prefactor is 0.5*hf_fraction for 1 spin, hf_fraction otherwise
     419             : !> \param hf_fraction scaling factor applied to the HFX contribution
     420             : ! **************************************************************************************************
     421         190 :    SUBROUTINE hfx_ri_update_ks_kp(qs_env, ri_data, ks_matrix, ehfx, rho_ao, &
     422             :                                   geometry_did_change, nspins, hf_fraction)
     423             : 
     424             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     425             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     426             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: ks_matrix
     427             :       REAL(KIND=dp), INTENT(OUT)                         :: ehfx
     428             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     429             :       LOGICAL, INTENT(IN)                                :: geometry_did_change
     430             :       INTEGER, INTENT(IN)                                :: nspins
     431             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     432             : 
     433             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_ks_kp'
     434             : 
     435             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_spin, iatom, &
     436             :          iblk, igroup, jatom, mb_img, n_batch_nze, natom, ngroups, nimg, nimg_nze
     437             :       INTEGER(int_8)                                     :: nflop, nze
     438         190 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_at, batch_ranges_nze, &
     439         190 :                                                             idx_to_at_AO
     440         190 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     441         190 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: sparsity_pattern
     442             :       LOGICAL                                            :: use_delta_p
     443             :       REAL(dp)                                           :: etmp, fac, occ, pfac, pref, t1, t2, t3, &
     444             :                                                             t4
     445             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     446             :       TYPE(dbcsr_type)                                   :: ks_desymm, rho_desymm, tmp
     447         190 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     448             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     449         190 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ks_t_split, t_2c_ao_tmp, t_2c_work, &
     450         190 :                                                             t_3c_int, t_3c_work_2, t_3c_work_3
     451         190 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: ks_t, ks_t_sub, t_3c_apc, t_3c_apc_sub
     452             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     453             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     454             : 
     455         190 :       NULLIFY (para_env, para_env_sub, blacs_env_sub, hfx_section, dbcsr_template)
     456             : 
     457         190 :       CALL cite_reference(Bussy2024)
     458             : 
     459         190 :       CALL timeset(routineN, handle)
     460             : 
     461         190 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natom)
     462             : 
     463         190 :       IF (nspins == 1) THEN
     464         112 :          fac = 0.5_dp*hf_fraction
     465             :       ELSE
     466          78 :          fac = 1.0_dp*hf_fraction
     467             :       END IF
     468             : 
     469         190 :       IF (geometry_did_change) THEN
     470          70 :          CALL hfx_ri_pre_scf_kp(ks_matrix(1, 1)%matrix, ri_data, qs_env)
     471             :       END IF
     472         190 :       nimg = ri_data%nimg
     473         190 :       nimg_nze = ri_data%nimg_nze
     474             : 
     475             :       !We need to calculate the KS matrix for each periodic cell with index b: F_mu^0,nu^b
     476             :       !F_mu^0,nu^b = -0.5 sum_a,c P_sigma^0,lambda^c (mu^0, sigma^a| P^0) V_P^0,Q^b (Q^b| nu^b lambda^a+c)
     477             :       !with V_P^0,Q^b = (P^0|R^0)^-1 * (R^0|S^b) * (S^b|Q^b)^-1
     478             : 
     479             :       !We use a local RI basis set for each atom in the system, which includes RI basis elements for
     480             :       !each neighboring atom standing within the KIND radius (decay of Gaussian with smallest exponent)
     481             : 
     482             :       !We also limit the number of periodic images we consider according to the HFX potential in the
     483             :       !RI basis, because if V_P^0,Q^b is zero everywhere, then image b can be ignored (RI basis less diffuse)
     484             : 
     485             :       !We manage to calculate each KS matrix doing a double loop on images, and a double loop on atoms
     486             :       !First, we pre-contract and store P_sigma^0,lambda^c (mu^0, sigma^a| P^0) (P^0|R^0)^-1 into T_mu^0,lambda^a+c,P^0
     487             :       !Then, we loop over b_img, iatom, jatom to get (R^0|S^b)
     488             :       !Finally, we do an additional loop over a+c images where we do (R^0|S^b) (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c)
     489             :       !and the final contraction with T_mu^0,lambda^a+c,P^0
     490             : 
     491             :       !Note that the 3-center integrals are pre-contracted with the RI metric, and that the same tensor can be used
     492             :       !(mu^0, sigma^a| P^0) (P^0|R^0)  <===> (S^b|Q^b)^-1 (Q^b| nu^b lambda^a+c) by relabelling the images
     493             : 
     494         190 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     495         190 :       CALL section_vals_val_get(hfx_section, "KP_USE_DELTA_P", l_val=use_delta_p)
     496             : 
     497             :       !By default, build the density tensor based on the difference of this SCF P and that of the prev. SCF
     498         190 :       pfac = -1.0_dp
     499         190 :       IF (.NOT. use_delta_p) pfac = 0.0_dp
     500         190 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, pfac, ri_data, qs_env)
     501             : 
     502       13412 :       ALLOCATE (ks_t(nspins, nimg))
     503        4994 :       DO i_img = 1, nimg
     504       11322 :          DO i_spin = 1, nspins
     505       11132 :             CALL dbt_create(ri_data%ks_t(1, 1), ks_t(i_spin, i_img))
     506             :          END DO
     507             :       END DO
     508             : 
     509         570 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     510         190 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     511             : 
     512             :       !First we calculate and store T^1_mu^0,lambda^a+c,P = P_mu^0,lambda^c * (mu_0 sigma^a | P^0) (P^0|R^0)^-1
     513             :       !To avoid doing nimg**2 tiny contractions that do not scale well with a large number of CPUs,
     514             :       !we instead do a single loop over the a+c image index. For each a+c, we get a list of allowed
     515             :       !combination of a,c indices. Then we build TAS tensors P_mu^0,lambda^c with all concerned c's
     516             :       !and (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 with all a's. Then we perform a single contraction with larger tensors,
     517             :       !where the sum over a,c is automatically taken care of
     518       13222 :       ALLOCATE (t_3c_apc(nspins, nimg))
     519        4994 :       DO i_img = 1, nimg
     520       11322 :          DO i_spin = 1, nspins
     521       11132 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     522             :          END DO
     523             :       END DO
     524         190 :       CALL contract_pmat_3c(t_3c_apc, ri_data%rho_ao_t, ri_data, qs_env)
     525             : 
     526         190 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     527         190 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     528         190 :       CALL section_vals_val_get(hfx_section, "KP_STACK_SIZE", i_val=batch_size)
     529         190 :       ri_data%kp_stack_size = batch_size
     530             : 
     531         190 :       IF (MOD(para_env%num_pe, ngroups) .NE. 0) THEN
     532           0 :          CPWARN("KP_NGROUPS must be an integer divisor of the total number of MPI ranks. It was set to 1.")
     533           0 :          ngroups = 1
     534           0 :          CALL section_vals_val_set(hfx_section, "KP_NGROUPS", i_val=ngroups)
     535             :       END IF
     536         190 :       IF ((MOD(ngroups, natom) .NE. 0) .AND. (MOD(natom, ngroups) .NE. 0) .AND. geometry_did_change) THEN
     537           0 :          IF (ngroups > 1) THEN
     538           0 :             CPWARN("Better load balancing is reached if NGROUPS is a multiple/divisor of the number of atoms")
     539             :          END IF
     540             :       END IF
     541         190 :       group_size = para_env%num_pe/ngroups
     542         190 :       igroup = para_env%mepos/group_size
     543             : 
     544         190 :       ALLOCATE (para_env_sub)
     545         190 :       CALL para_env_sub%from_split(para_env, igroup)
     546         190 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     547             : 
     548             :       ! The sparsity pattern of each iatom, jatom pair, on each b_img, and on which subgroup
     549         950 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     550         190 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     551         190 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     552             : 
     553             :       !Get all the required tensors in the subgroups
     554       24106 :       ALLOCATE (mat_2c_pot(nimg), ks_t_sub(nspins, nimg), t_2c_ao_tmp(1), ks_t_split(2), t_2c_work(3))
     555             :       CALL get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
     556         190 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     557             : 
     558       24296 :       ALLOCATE (t_3c_int(nimg), t_3c_apc_sub(nspins, nimg), t_3c_work_2(3), t_3c_work_3(3))
     559             :       CALL get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
     560         190 :                                    group_size, ngroups, para_env, para_env_sub, ri_data)
     561             : 
     562             :       !We go atom by atom, therefore there is an automatic batching along that direction
     563             :       !Also, because we stack the 3c tensors nimg times, we naturally do some batching there too
     564         570 :       ALLOCATE (batch_ranges_at(natom + 1))
     565         190 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     566         190 :       iatom = 0
     567         928 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     568         928 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     569         380 :             iatom = iatom + 1
     570         380 :             batch_ranges_at(iatom) = iblk
     571             :          END IF
     572             :       END DO
     573             : 
     574         190 :       n_batch_nze = nimg_nze/batch_size
     575         190 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
     576         570 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
     577         404 :       DO i_batch = 1, n_batch_nze
     578         404 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     579             :       END DO
     580         190 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
     581             : 
     582         190 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     583         190 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     584         190 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     585         190 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     586             : 
     587         190 :       t1 = m_walltime()
     588       33818 :       ri_data%kp_cost(:, :, :) = 0.0_dp
     589         570 :       ALLOCATE (iapc_pairs(nimg, 2))
     590        4994 :       DO b_img = 1, nimg
     591        4804 :          CALL dbt_batched_contract_init(ks_t_split(1))
     592        4804 :          CALL dbt_batched_contract_init(ks_t_split(2))
     593       14412 :          DO jatom = 1, natom
     594       33628 :             DO iatom = 1, natom
     595       19216 :                IF (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup) CYCLE
     596        3232 :                pref = 1.0_dp
     597        3232 :                IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp !presumably avoids double counting (TODO confirm)
     598             : 
     599             :                !measure the cost of the given i, j, b configuration
     600        3232 :                t3 = m_walltime()
     601             : 
     602             :                !Get the proper HFX potential 2c integrals (R_i^0|S_j^b)
     603        3232 :                CALL timeset(routineN//"_2c", handle2)
     604             :                CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     605             :                                    blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     606        3232 :                                    dbcsr_template=dbcsr_template)
     607        3232 :                CALL dbt_copy(t_2c_work(1), t_2c_work(2), move_data=.TRUE.) !move to split blocks
     608        3232 :                CALL dbt_filter(t_2c_work(2), ri_data%filter_eps)
     609        3232 :                CALL timestop(handle2)
     610             : 
     611        3232 :                CALL dbt_batched_contract_init(t_2c_work(2))
     612        3232 :                CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env)
     613        3232 :                CALL timeset(routineN//"_3c", handle2)
     614             : 
     615             :                !Stack the (S^b|Q^b)^-1 * (Q^b| nu^b lambda^a+c) integrals over a+c and multiply by (R_i^0|S_j^b)
     616        7331 :                DO i_batch = 1, n_batch_nze
     617             :                   CALL fill_3c_stack(t_3c_work_3(3), t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
     618             :                                      filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
     619       12297 :                                      img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     620        4099 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1), move_data=.TRUE.)
     621             : 
     622             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_3c_work_3(1), &
     623             :                                     0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
     624             :                                     contract_1=[2], notcontract_1=[1], &
     625             :                                     contract_2=[1], notcontract_2=[2, 3], &
     626        4099 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     627        4099 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     628        4099 :                   CALL dbt_copy(t_3c_work_3(2), t_3c_work_2(2), order=[2, 1, 3], move_data=.TRUE.)
     629        4099 :                   CALL dbt_copy(t_3c_work_3(3), t_3c_work_3(1)) !source was moved out above: presumably clears (1)
     630             : 
     631             :                   !Stack the P_sigma^a,lambda^a+c * (mu^0 sigma^a | P^0)*(P^0|R^0)^-1 integrals over a+c and contract
     632             :                   !to get the final block of the KS matrix
     633       12586 :                   DO i_spin = 1, nspins
     634             :                      CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
     635             :                                         ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
     636       15765 :                                         img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     637        5255 :                      CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
     638        5255 :                      IF (nze == 0) CYCLE
     639        5235 :                      CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
     640             :                      CALL dbt_contract(-pref*fac, t_3c_work_2(1), t_3c_work_2(2), &
     641             :                                        1.0_dp, ks_t_split(i_spin), map_1=[1], map_2=[2], &
     642             :                                        contract_1=[2, 3], notcontract_1=[1], &
     643             :                                        contract_2=[2, 3], notcontract_2=[1], &
     644             :                                        filter_eps=ri_data%filter_eps, &
     645        5235 :                                        move_data=i_spin == nspins, flop=nflop)
     646       14589 :                      ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
     647             :                   END DO
     648             :                END DO !i_batch
     649        3232 :                CALL timestop(handle2)
     650        3232 :                CALL dbt_batched_contract_finalize(t_2c_work(2))
     651             : 
     652        3232 :                t4 = m_walltime()
     653       35288 :                ri_data%kp_cost(iatom, jatom, b_img) = t4 - t3
     654             :             END DO !iatom
     655             :          END DO !jatom
     656        4804 :          CALL dbt_batched_contract_finalize(ks_t_split(1))
     657        4804 :          CALL dbt_batched_contract_finalize(ks_t_split(2))
     658             : 
     659       11322 :          DO i_spin = 1, nspins
     660        6328 :             CALL dbt_copy(ks_t_split(i_spin), t_2c_ao_tmp(1), move_data=.TRUE.)
     661       11132 :             CALL dbt_copy(t_2c_ao_tmp(1), ks_t_sub(i_spin, b_img), summation=.TRUE.)
     662             :          END DO
     663             :       END DO !b_img
     664         190 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
     665         190 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
     666         190 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
     667         190 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
     668         190 :       CALL para_env%sync()
     669         190 :       CALL para_env%sum(ri_data%dbcsr_nflop)
     670         190 :       CALL para_env%sum(ri_data%kp_cost)
     671         190 :       t2 = m_walltime()
     672         190 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
     673             : 
     674             :       !transfer KS tensor from subgroup to main group
     675         190 :       CALL gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
     676             : 
     677             :       !Keep the 3c integrals on the subgroups to avoid communication at next SCF step
     678        4994 :       DO i_img = 1, nimg
     679        4994 :          CALL dbt_copy(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img), move_data=.TRUE.)
     680             :       END DO
     681             : 
     682             :       !clean-up subgroup tensors
     683         190 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     684         190 :       CALL dbt_destroy(ks_t_split(1))
     685         190 :       CALL dbt_destroy(ks_t_split(2))
     686         190 :       CALL dbt_destroy(t_2c_work(1))
     687         190 :       CALL dbt_destroy(t_2c_work(2))
     688         190 :       CALL dbt_destroy(t_3c_work_2(1))
     689         190 :       CALL dbt_destroy(t_3c_work_2(2))
     690         190 :       CALL dbt_destroy(t_3c_work_2(3))
     691         190 :       CALL dbt_destroy(t_3c_work_3(1))
     692         190 :       CALL dbt_destroy(t_3c_work_3(2))
     693         190 :       CALL dbt_destroy(t_3c_work_3(3))
     694        4994 :       DO i_img = 1, nimg
     695        4804 :          CALL dbt_destroy(t_3c_int(i_img))
     696        4804 :          CALL dbcsr_release(mat_2c_pot(i_img))
     697       11322 :          DO i_spin = 1, nspins
     698        6328 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
     699       11132 :             CALL dbt_destroy(ks_t_sub(i_spin, i_img))
     700             :          END DO
     701             :       END DO
     702         190 :       IF (ASSOCIATED(dbcsr_template)) THEN
     703         190 :          CALL dbcsr_release(dbcsr_template)
     704         190 :          DEALLOCATE (dbcsr_template)
     705             :       END IF
     706             : 
     707             :       !End of subgroup parallelization
     708         190 :       CALL cp_blacs_env_release(blacs_env_sub)
     709         190 :       CALL para_env_sub%free()
     710         190 :       DEALLOCATE (para_env_sub)
     711             : 
     712             :       !Currently, rho_ao_t holds the density difference (wrt the prev. SCF step).
     713             :       !ks_t also holds that diff, while only having half the blocks => need to add to prev ks_t and symmetrize
     714             :       !We need the full thing for the energy, on the next SCF step
     715         190 :       CALL get_pmat_images(ri_data%rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     716         458 :       DO i_spin = 1, nspins
     717        6786 :          DO b_img = 1, nimg
     718        6328 :             CALL dbt_copy(ks_t(i_spin, b_img), ri_data%ks_t(i_spin, b_img), summation=.TRUE.)
     719             : 
     720             :             !desymmetrize
     721        6328 :             mb_img = get_opp_index(b_img, qs_env)
     722        6596 :             IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
     723        5708 :                CALL dbt_copy(ks_t(i_spin, mb_img), ri_data%ks_t(i_spin, b_img), order=[2, 1], summation=.TRUE.)
     724             :             END IF
     725             :          END DO
     726             :       END DO
     727        4994 :       DO b_img = 1, nimg
     728       11322 :          DO i_spin = 1, nspins
     729       11132 :             CALL dbt_destroy(ks_t(i_spin, b_img))
     730             :          END DO
     731             :       END DO
     732             : 
     733             :       !calculate the energy
     734         190 :       CALL dbt_create(ri_data%ks_t(1, 1), t_2c_ao_tmp(1))
     735         190 :       CALL dbcsr_create(tmp, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_symmetric)
     736         190 :       CALL dbcsr_create(ks_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     737         190 :       CALL dbcsr_create(rho_desymm, template=ks_matrix(1, 1)%matrix, matrix_type=dbcsr_type_no_symmetry)
     738         190 :       ehfx = 0.0_dp
     739        4994 :       DO i_img = 1, nimg
     740       11322 :          DO i_spin = 1, nspins
     741        6328 :             CALL dbt_filter(ri_data%ks_t(i_spin, i_img), ri_data%filter_eps)
     742        6328 :             CALL dbt_copy(ri_data%ks_t(i_spin, i_img), t_2c_ao_tmp(1))
     743        6328 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), ks_desymm)
     744        6328 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), tmp)
     745        6328 :             CALL dbcsr_add(ks_matrix(i_spin, i_img)%matrix, tmp, 1.0_dp, 1.0_dp)
     746             : 
     747        6328 :             CALL dbt_copy(ri_data%rho_ao_t(i_spin, i_img), t_2c_ao_tmp(1))
     748        6328 :             CALL dbt_copy_tensor_to_matrix(t_2c_ao_tmp(1), rho_desymm)
     749             : 
     750        6328 :             CALL dbcsr_dot(ks_desymm, rho_desymm, etmp)
     751        6328 :             ehfx = ehfx + 0.5_dp*etmp
     752             : 
     753       11132 :             IF (.NOT. use_delta_p) CALL dbt_clear(ri_data%ks_t(i_spin, i_img)) !full-P build: reset stored ks_t
     754             :          END DO
     755             :       END DO
     756         190 :       CALL dbcsr_release(rho_desymm)
     757         190 :       CALL dbcsr_release(ks_desymm)
     758         190 :       CALL dbcsr_release(tmp)
     759         190 :       CALL dbt_destroy(t_2c_ao_tmp(1))
     760             : 
     761         190 :       CALL timestop(handle)
     762             : 
     763       33346 :    END SUBROUTINE hfx_ri_update_ks_kp
     764             : 
     765             : ! **************************************************************************************************
     766             : !> \brief Update the K-points RI-HFX forces
     767             : !> \param qs_env ...
     768             : !> \param ri_data ...
     769             : !> \param nspins ...
     770             : !> \param hf_fraction ...
     771             : !> \param rho_ao ...
     772             : !> \param use_virial ...
     773             : !> \note Because this routine uses stored quantities calculated in the energy calculation, the two should
     774             : !>       always be called in pairs, and with the same input densities
     775             : ! **************************************************************************************************
     776          42 :    SUBROUTINE hfx_ri_update_forces_kp(qs_env, ri_data, nspins, hf_fraction, rho_ao, use_virial)
     777             : 
     778             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     779             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
     780             :       INTEGER, INTENT(IN)                                :: nspins
     781             :       REAL(KIND=dp), INTENT(IN)                          :: hf_fraction
     782             :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: rho_ao
     783             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_virial
     784             : 
     785             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'hfx_ri_update_forces_kp'
     786             : 
     787             :       INTEGER :: b_img, batch_size, group_size, handle, handle2, i_batch, i_img, i_loop, i_spin, &
     788             :          i_xyz, iatom, iblk, igroup, j_xyz, jatom, k_xyz, n_batch, natom, ngroups, nimg, nimg_nze
     789             :       INTEGER(int_8)                                     :: nflop, nze
     790          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atom_of_kind, batch_ranges_at, &
     791          42 :                                                             batch_ranges_nze, dist1, dist2, &
     792          42 :                                                             i_images, idx_to_at_AO, idx_to_at_RI, &
     793          42 :                                                             kind_of
     794          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: iapc_pairs
     795          42 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: force_pattern, sparsity_pattern
     796             :       INTEGER, DIMENSION(2, 1)                           :: bounds_iat, bounds_jat
     797             :       LOGICAL                                            :: use_virial_prv
     798             :       REAL(dp)                                           :: fac, occ, pref, t1, t2
     799             :       REAL(dp), DIMENSION(3, 3)                          :: work_virial
     800          42 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     801             :       TYPE(cell_type), POINTER                           :: cell
     802             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env_sub
     803          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:)        :: mat_2c_pot
     804          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_pot, mat_der_pot_sub
     805             :       TYPE(dbcsr_type), POINTER                          :: dbcsr_template
     806         714 :       TYPE(dbt_type)                                     :: t_2c_R, t_2c_R_split
     807          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_2c_bint, t_2c_binv, t_2c_der_pot, &
     808          84 :                                                             t_2c_inv, t_2c_metric, t_2c_work, &
     809          42 :                                                             t_3c_der_stack, t_3c_work_2, &
     810          42 :                                                             t_3c_work_3
     811          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :) :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
     812          84 :          t_2c_der_metric_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_AO, t_3c_der_AO_sub, t_3c_der_RI, &
     813          42 :          t_3c_der_RI_sub
     814             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
     815          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     816          42 :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
     817             :       TYPE(section_vals_type), POINTER                   :: hfx_section
     818             :       TYPE(virial_type), POINTER                         :: virial
     819             : 
     820          42 :       NULLIFY (para_env, para_env_sub, hfx_section, blacs_env_sub, dbcsr_template, force, atomic_kind_set, &
     821          42 :                virial, particle_set, cell)
     822             : 
     823          42 :       CALL timeset(routineN, handle)
     824             : 
     825          42 :       use_virial_prv = .FALSE.
     826          42 :       IF (PRESENT(use_virial)) use_virial_prv = use_virial
     827             : 
     828          42 :       IF (nspins == 1) THEN
     829          28 :          fac = 0.5_dp*hf_fraction
     830             :       ELSE
     831          14 :          fac = 1.0_dp*hf_fraction
     832             :       END IF
     833             : 
     834             :       CALL get_qs_env(qs_env, natom=natom, para_env=para_env, force=force, cell=cell, virial=virial, &
     835          42 :                       atomic_kind_set=atomic_kind_set, particle_set=particle_set)
     836          42 :       CALL get_atomic_kind_set(atomic_kind_set, kind_of=kind_of, atom_of_kind=atom_of_kind)
     837             : 
     838         126 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
     839          42 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
     840             : 
     841         126 :       ALLOCATE (idx_to_at_RI(SIZE(ri_data%bsizes_RI_split)))
     842          42 :       CALL get_idx_to_atom(idx_to_at_RI, ri_data%bsizes_RI_split, ri_data%bsizes_RI)
     843             : 
     844          42 :       nimg = ri_data%nimg
     845       11190 :       ALLOCATE (t_3c_der_RI(nimg, 3), t_3c_der_AO(nimg, 3), mat_der_pot(nimg, 3), t_2c_der_metric(natom, 3))
     846             : 
     847             :       !We assume that the integrals are available from the SCF
     848             :       !pre-calculate the derivs. 3c tensors as (P^0| sigma^a mu^0), with t_3c_der_AO holding deriv wrt mu^0
     849          42 :       CALL precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
     850             : 
     851             :       !Calculate the density matrix at each image
     852        2690 :       ALLOCATE (rho_ao_t(nspins, nimg))
     853             :       CALL create_2c_tensor(rho_ao_t(1, 1), dist1, dist2, ri_data%pgrid_2d, &
     854             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
     855          42 :                             name="(AO | AO)")
     856          42 :       DEALLOCATE (dist1, dist2)
     857          42 :       IF (nspins == 2) CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(2, 1))
     858        1010 :       DO i_img = 2, nimg
     859        2130 :          DO i_spin = 1, nspins
     860        2088 :             CALL dbt_create(rho_ao_t(1, 1), rho_ao_t(i_spin, i_img))
     861             :          END DO
     862             :       END DO
     863          42 :       CALL get_pmat_images(rho_ao_t, rho_ao, 0.0_dp, ri_data, qs_env)
     864             : 
     865             :       !Contract integrals with the density matrix
     866        2690 :       ALLOCATE (t_3c_apc(nspins, nimg))
     867        1052 :       DO i_img = 1, nimg
     868        2228 :          DO i_spin = 1, nspins
     869        2186 :             CALL dbt_create(ri_data%t_3c_int_ctr_2(1, 1), t_3c_apc(i_spin, i_img))
     870             :          END DO
     871             :       END DO
     872          42 :       CALL contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
     873             : 
     874             :       !Setup the subgroups
     875          42 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
     876          42 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
     877          42 :       group_size = para_env%num_pe/ngroups
     878          42 :       igroup = para_env%mepos/group_size
     879             : 
     880          42 :       ALLOCATE (para_env_sub)
     881          42 :       CALL para_env_sub%from_split(para_env, igroup)
     882          42 :       CALL cp_blacs_env_create(blacs_env_sub, para_env_sub)
     883             : 
     884             :       !Get the usual sparsity pattern
     885         210 :       ALLOCATE (sparsity_pattern(natom, natom, nimg))
     886          42 :       CALL get_sparsity_pattern(sparsity_pattern, ri_data, qs_env)
     887          42 :       CALL get_sub_dist(sparsity_pattern, ngroups, ri_data)
     888             : 
     889             :       !Get the 2-center quantities in the subgroups (note: main group derivs are deleted within)
     890           0 :       ALLOCATE (t_2c_inv(natom), mat_2c_pot(nimg), rho_ao_t_sub(nspins, nimg), t_2c_work(5), &
     891           0 :                 t_2c_der_metric_sub(natom, 3), mat_der_pot_sub(nimg, 3), t_2c_bint(natom), &
     892       10300 :                 t_2c_metric(natom), t_2c_binv(natom))
     893             :       CALL get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
     894             :                                   rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
     895          42 :                                   mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
     896          42 :       CALL dbt_create(t_2c_work(1), t_2c_R) !nRI x nRI
     897          42 :       CALL dbt_create(t_2c_work(5), t_2c_R_split) !nRI x nRI with split blocks
     898             : 
     899         504 :       ALLOCATE (t_2c_der_pot(3))
     900         168 :       DO i_xyz = 1, 3
     901         168 :          CALL dbt_create(t_2c_R, t_2c_der_pot(i_xyz))
     902             :       END DO
     903             : 
     904             :       !Get the 3-center quantities in the subgroups. The integrals and t_3c_apc already there
     905           0 :       ALLOCATE (t_3c_work_2(3), t_3c_work_3(4), t_3c_der_stack(6), t_3c_der_AO_sub(nimg, 3), &
     906       11312 :                 t_3c_der_RI_sub(nimg, 3), t_3c_apc_sub(nspins, nimg))
     907             :       CALL get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
     908             :                                   t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, t_3c_der_stack, &
     909          42 :                                   group_size, ngroups, para_env, para_env_sub, ri_data)
     910             : 
     911             :       !Set up batched contraction (go atom by atom)
     912         126 :       ALLOCATE (batch_ranges_at(natom + 1))
     913          42 :       batch_ranges_at(natom + 1) = SIZE(ri_data%bsizes_AO_split) + 1
     914          42 :       iatom = 0
     915         212 :       DO iblk = 1, SIZE(ri_data%bsizes_AO_split)
     916         212 :          IF (idx_to_at_AO(iblk) == iatom + 1) THEN
     917          84 :             iatom = iatom + 1
     918          84 :             batch_ranges_at(iatom) = iblk
     919             :          END IF
     920             :       END DO
     921             : 
     922          42 :       CALL dbt_batched_contract_init(t_3c_work_3(1), batch_range_2=batch_ranges_at)
     923          42 :       CALL dbt_batched_contract_init(t_3c_work_3(2), batch_range_2=batch_ranges_at)
     924          42 :       CALL dbt_batched_contract_init(t_3c_work_3(3), batch_range_2=batch_ranges_at)
     925          42 :       CALL dbt_batched_contract_init(t_3c_work_2(1), batch_range_1=batch_ranges_at)
     926          42 :       CALL dbt_batched_contract_init(t_3c_work_2(2), batch_range_1=batch_ranges_at)
     927             : 
     928             :       !Preparing for the stacking of 3c tensors
     929          42 :       nimg_nze = ri_data%nimg_nze
     930          42 :       batch_size = ri_data%kp_stack_size
     931          42 :       n_batch = nimg_nze/batch_size
     932          42 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch = n_batch + 1
     933         126 :       ALLOCATE (batch_ranges_nze(n_batch + 1))
     934          94 :       DO i_batch = 1, n_batch
     935          94 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
     936             :       END DO
     937          42 :       batch_ranges_nze(n_batch + 1) = nimg_nze + 1
     938             : 
     939             :       !Applying the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 from left and right
     940             :       !And keep the bump on LHS only version as well, with B*M^-1 = (M^-1*B)^T
     941         126 :       DO iatom = 1, natom
     942          84 :          CALL dbt_create(t_2c_inv(iatom), t_2c_binv(iatom))
     943          84 :          CALL dbt_copy(t_2c_inv(iatom), t_2c_binv(iatom))
     944          84 :          CALL apply_bump(t_2c_binv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.FALSE.)
     945         126 :          CALL apply_bump(t_2c_inv(iatom), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
     946             :       END DO
     947             : 
     948          42 :       t1 = m_walltime()
     949          42 :       work_virial = 0.0_dp
     950         210 :       ALLOCATE (iapc_pairs(nimg, 2), i_images(nimg))
     951         210 :       ALLOCATE (force_pattern(natom, natom, nimg))
     952        7112 :       force_pattern(:, :, :) = -1
     953             :       !We proceed with 2 loops: one over the sparsity pattern from the SCF, one over the rest
     954             :       !We use the SCF cost model for the first loop, while we calculate the cost of the upcoming loop
     955         126 :       DO i_loop = 1, 2
     956        2104 :          DO b_img = 1, nimg
     957        6144 :             DO jatom = 1, natom
     958       14140 :                DO iatom = 1, natom
     959             : 
     960        8080 :                   pref = -0.5_dp*fac
     961        8080 :                   IF (i_loop == 1 .AND. (.NOT. sparsity_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     962        4619 :                   IF (i_loop == 2 .AND. (.NOT. force_pattern(iatom, jatom, b_img) == igroup)) CYCLE
     963             : 
     964             :                   !Get the proper HFX potential 2c integrals (R_i^0|S_j^b), times (S_j^b|Q_j^b)^-1
     965        1098 :                   CALL timeset(routineN//"_2c_1", handle2)
     966             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
     967             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
     968        1098 :                                       dbcsr_template=dbcsr_template)
     969             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_inv(jatom), &
     970             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
     971             :                                     contract_1=[2], notcontract_1=[1], &
     972             :                                     contract_2=[1], notcontract_2=[2], &
     973        1098 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
     974        1098 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(5), move_data=.TRUE.) !move to split blocks
     975        1098 :                   CALL dbt_filter(t_2c_work(5), ri_data%filter_eps)
     976        1098 :                   CALL timestop(handle2)
     977             : 
     978        1098 :                   CALL timeset(routineN//"_3c", handle2)
     979        5488 :                   bounds_iat(:, 1) = [SUM(ri_data%bsizes_AO(1:iatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:iatom))]
     980        5452 :                   bounds_jat(:, 1) = [SUM(ri_data%bsizes_AO(1:jatom - 1)) + 1, SUM(ri_data%bsizes_AO(1:jatom))]
     981        1098 :                   CALL dbt_clear(t_2c_R_split)
     982             : 
     983        2435 :                   DO i_spin = 1, nspins
     984        2435 :                      CALL dbt_batched_contract_init(rho_ao_t_sub(i_spin, b_img))
     985             :                   END DO
     986             : 
     987        1098 :                   CALL get_iapc_pairs(iapc_pairs, b_img, ri_data, qs_env, i_images) !i = a+c-b
     988        2667 :                   DO i_batch = 1, n_batch
     989             : 
     990             :                      !Stack the 3c derivatives to take the trace later on
     991        6276 :                      DO i_xyz = 1, 3
     992        4707 :                         CALL dbt_clear(t_3c_der_stack(i_xyz))
     993             :                         CALL fill_3c_stack(t_3c_der_stack(i_xyz), t_3c_der_RI_sub(:, i_xyz), &
     994             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
     995             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
     996       14121 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
     997             : 
     998        4707 :                         CALL dbt_clear(t_3c_der_stack(3 + i_xyz))
     999             :                         CALL fill_3c_stack(t_3c_der_stack(3 + i_xyz), t_3c_der_AO_sub(:, i_xyz), &
    1000             :                                            iapc_pairs(:, 1), 3, ri_data, filter_at=jatom, &
    1001             :                                            filter_dim=2, idx_to_at=idx_to_at_AO, &
    1002       15690 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1003             :                      END DO
    1004             : 
    1005        4475 :                      DO i_spin = 1, nspins
    1006             :                         !stack the t_3c_apc tensors
    1007        1808 :                         CALL dbt_clear(t_3c_work_2(3))
    1008             :                         CALL fill_3c_stack(t_3c_work_2(3), t_3c_apc_sub(i_spin, :), iapc_pairs(:, 2), 3, &
    1009             :                                            ri_data, filter_at=iatom, filter_dim=1, idx_to_at=idx_to_at_AO, &
    1010        5424 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1011        1808 :                         CALL get_tensor_occupancy(t_3c_work_2(3), nze, occ)
    1012        1808 :                         IF (nze == 0) CYCLE
    1013        1808 :                         CALL dbt_copy(t_3c_work_2(3), t_3c_work_2(1), move_data=.TRUE.)
    1014             : 
    1015             :                         !Contract with the second density matrix: P_mu^0,nu^b * t_3c_apc,
    1016             :                         !where t_3c_apc = P_sigma^a,lambda^a+c (mu^0 P^0 sigma^a) *(P^0|R^0)^-1 (stacked along a+c)
    1017             :                         CALL dbt_contract(1.0_dp, rho_ao_t_sub(i_spin, b_img), t_3c_work_2(1), &
    1018             :                                           0.0_dp, t_3c_work_2(2), map_1=[1], map_2=[2, 3], &
    1019             :                                           contract_1=[1], notcontract_1=[2], &
    1020             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1021             :                                           bounds_1=bounds_iat, bounds_2=bounds_jat, &
    1022        1808 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1023        1808 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1024             : 
    1025        1808 :                         CALL get_tensor_occupancy(t_3c_work_2(2), nze, occ)
    1026        1808 :                         IF (nze == 0) CYCLE
    1027             : 
    1028             :                         !Contract with V_PQ so that we can take the trace with (Q^b|nu^b lambda^a+c)^(x)
    1029        1666 :                         CALL dbt_copy(t_3c_work_2(2), t_3c_work_3(1), order=[2, 1, 3], move_data=.TRUE.)
    1030        1666 :                         CALL dbt_batched_contract_init(t_2c_work(5))
    1031             :                         CALL dbt_contract(1.0_dp, t_2c_work(5), t_3c_work_3(1), &
    1032             :                                           0.0_dp, t_3c_work_3(2), map_1=[1], map_2=[2, 3], &
    1033             :                                           contract_1=[1], notcontract_1=[2], &
    1034             :                                           contract_2=[1], notcontract_2=[2, 3], &
    1035        1666 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1036        1666 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1037        1666 :                         CALL dbt_batched_contract_finalize(t_2c_work(5))
    1038             : 
    1039             :                         !Contract with the 3c derivatives to get the force/virial
    1040        1666 :                         CALL dbt_copy(t_3c_work_3(2), t_3c_work_3(4), move_data=.TRUE.)
    1041        1666 :                         IF (use_virial_prv) THEN
    1042             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1043             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1044             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1045             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1046         257 :                                                         ri_data, qs_env, work_virial, cell, particle_set)
    1047             :                         ELSE
    1048             :                            CALL get_force_from_3c_trace(force, t_3c_work_3(4), t_3c_der_stack(1:3), &
    1049             :                                                         t_3c_der_stack(4:6), atom_of_kind, kind_of, &
    1050             :                                                         idx_to_at_RI, idx_to_at_AO, i_images, &
    1051             :                                                         batch_ranges_nze(i_batch), 2.0_dp*pref, &
    1052        1409 :                                                         ri_data, qs_env)
    1053             :                         END IF
    1054        1666 :                         CALL dbt_clear(t_3c_work_3(4))
    1055             : 
    1056             :                         !Contract with the 3-center integrals in order to have a matrix R_PQ such that
    1057             :                         !we can take the trace sum_PQ R_PQ (P^0|Q^b)^(x)
    1058        1666 :                         IF (i_loop == 2) CYCLE
    1059             : 
    1060             :                         !Stack the 3c integrals
    1061             :                         CALL fill_3c_stack(t_3c_work_3(4), ri_data%kp_t_3c_int, iapc_pairs(:, 1), 3, ri_data, &
    1062             :                                            filter_at=jatom, filter_dim=2, idx_to_at=idx_to_at_AO, &
    1063        2634 :                                            img_bounds=[batch_ranges_nze(i_batch), batch_ranges_nze(i_batch + 1)])
    1064         878 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(3), move_data=.TRUE.)
    1065             : 
    1066         878 :                         CALL dbt_batched_contract_init(t_2c_R_split)
    1067             :                         CALL dbt_contract(1.0_dp, t_3c_work_3(1), t_3c_work_3(3), &
    1068             :                                           1.0_dp, t_2c_R_split, map_1=[1], map_2=[2], &
    1069             :                                           contract_1=[2, 3], notcontract_1=[1], &
    1070             :                                           contract_2=[2, 3], notcontract_2=[1], &
    1071         878 :                                           filter_eps=ri_data%filter_eps, flop=nflop)
    1072         878 :                         ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1073         878 :                         CALL dbt_batched_contract_finalize(t_2c_R_split)
    1074        6063 :                         CALL dbt_copy(t_3c_work_3(4), t_3c_work_3(1))
    1075             :                      END DO
    1076             :                   END DO
    1077        2435 :                   DO i_spin = 1, nspins
    1078        2435 :                      CALL dbt_batched_contract_finalize(rho_ao_t_sub(i_spin, b_img))
    1079             :                   END DO
    1080        1098 :                   CALL timestop(handle2)
    1081             : 
    1082        1098 :                   IF (i_loop == 2) CYCLE
    1083         579 :                   pref = 2.0_dp*pref
    1084         579 :                   IF (iatom == jatom .AND. b_img == 1) pref = 0.5_dp*pref
    1085             : 
    1086         579 :                   CALL timeset(routineN//"_2c_2", handle2)
    1087             :                   !Note that the derivatives are in atomic block format (not split)
    1088         579 :                   CALL dbt_copy(t_2c_R_split, t_2c_R, move_data=.TRUE.)
    1089             : 
    1090             :                   CALL get_ext_2c_int(t_2c_work(1), mat_2c_pot, iatom, jatom, b_img, ri_data, qs_env, &
    1091             :                                       blacs_env_ext=blacs_env_sub, para_env_ext=para_env_sub, &
    1092         579 :                                       dbcsr_template=dbcsr_template)
    1093             : 
    1094             :                   !We have to calculate: S^-1(iat) * R_PQ * S^-1(jat)    to trace with HFX pot der
    1095             :                   !                      + R_PQ * S^-1(jat) * pot^T      to trace with S^(x) (iat)
    1096             :                   !                      + pot^T * S^-1(iat) *R_PQ       to trace with S^(x) (jat)
    1097             : 
    1098             :                   !Because 3c tensors are all precontracted with the inverse RI metric,
    1099             :                   !t_2c_R is currently implicitly multiplied by S^-1(iat) from the left
    1100             :                   !and S^-1(jat) from the right, directly in the proper format for the trace
    1101             :                   !with the HFX potential derivative
    1102             : 
    1103             :                   !Trace with HFX pot deriv, that we need to build first
    1104        2316 :                   DO i_xyz = 1, 3
    1105             :                      CALL get_ext_2c_int(t_2c_der_pot(i_xyz), mat_der_pot_sub(:, i_xyz), iatom, jatom, &
    1106             :                                          b_img, ri_data, qs_env, blacs_env_ext=blacs_env_sub, &
    1107        2316 :                                          para_env_ext=para_env_sub, dbcsr_template=dbcsr_template)
    1108             :                   END DO
    1109             : 
    1110         579 :                   IF (use_virial_prv) THEN
    1111             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1112         113 :                                            b_img, pref, ri_data, qs_env, work_virial, cell, particle_set)
    1113             :                   ELSE
    1114             :                      CALL get_2c_der_force(force, t_2c_R, t_2c_der_pot, atom_of_kind, kind_of, &
    1115         466 :                                            b_img, pref, ri_data, qs_env)
    1116             :                   END IF
    1117             : 
    1118        2316 :                   DO i_xyz = 1, 3
    1119        2316 :                      CALL dbt_clear(t_2c_der_pot(i_xyz))
    1120             :                   END DO
    1121             : 
    1122             :                   !R_PQ * S^-1(jat) * pot^T  (=A)
    1123             :                   CALL dbt_contract(1.0_dp, t_2c_metric(iatom), t_2c_R, & !get rid of implicit S^-1(iat)
    1124             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1125             :                                     contract_1=[2], notcontract_1=[1], &
    1126             :                                     contract_2=[1], notcontract_2=[2], &
    1127         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1128         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1129             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_work(1), &
    1130             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1131             :                                     contract_1=[2], notcontract_1=[1], &
    1132             :                                     contract_2=[2], notcontract_2=[1], &
    1133         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1134         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1135             : 
    1136             :                   !With the RI bump function, things get more complex. M = (S|P)_D + B*(S|P)_OD*B
    1137             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1138             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(iatom), &
    1139             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1140             :                                     contract_1=[2], notcontract_1=[1], &
    1141             :                                     contract_2=[1], notcontract_2=[2], &
    1142         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1143         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1144             : 
    1145             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1146             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1147             :                                     contract_1=[1], notcontract_1=[2], &
    1148             :                                     contract_2=[1], notcontract_2=[2], &
    1149         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1150         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1151             : 
    1152         579 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1153             :                   CALL get_2c_bump_forces(force, t_2c_work(4), iatom, atom_of_kind, kind_of, pref, &
    1154         579 :                                           ri_data, qs_env, work_virial)
    1155             : 
    1156             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1157             :                   CALL dbt_contract(1.0_dp, t_2c_binv(iatom), t_2c_work(2), &
    1158             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1159             :                                     contract_1=[1], notcontract_1=[2], &
    1160             :                                     contract_2=[1], notcontract_2=[2], &
    1161         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1162         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1163             : 
    1164         579 :                   IF (use_virial_prv) THEN
    1165             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1166             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1167         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1168             :                   ELSE
    1169             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1170         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1171             :                   END IF
    1172             : 
    1173             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1174         579 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1175         579 :                   CALL apply_bump(t_2c_work(2), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1176             : 
    1177         579 :                   IF (use_virial_prv) THEN
    1178             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1179             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1180         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1181             :                   ELSE
    1182             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(iatom, :), atom_of_kind, &
    1183         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1184             :                   END IF
    1185             : 
    1186             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1187             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1188             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(iatom), &
    1189             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1190             :                                     contract_1=[2], notcontract_1=[1], &
    1191             :                                     contract_2=[1], notcontract_2=[2], &
    1192         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1193         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1194             : 
    1195             :                   CALL dbt_contract(1.0_dp, t_2c_bint(iatom), t_2c_work(4), &
    1196             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1197             :                                     contract_1=[1], notcontract_1=[2], &
    1198             :                                     contract_2=[1], notcontract_2=[2], &
    1199         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1200         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1201             : 
    1202             :                   CALL get_2c_bump_forces(force, t_2c_work(2), iatom, atom_of_kind, kind_of, -pref, &
    1203         579 :                                           ri_data, qs_env, work_virial)
    1204             : 
    1205             :                   ! pot^T * S^-1(iat) * R_PQ (=A)
    1206             :                   CALL dbt_contract(1.0_dp, t_2c_work(1), t_2c_R, &
    1207             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1208             :                                     contract_1=[1], notcontract_1=[2], &
    1209             :                                     contract_2=[1], notcontract_2=[2], &
    1210         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1211         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1212             : 
    1213             :                   CALL dbt_contract(1.0_dp, t_2c_work(2), t_2c_metric(jatom), & !get rid of implicit S^-1(jat)
    1214             :                                     0.0_dp, t_2c_work(3), map_1=[1], map_2=[2], &
    1215             :                                     contract_1=[2], notcontract_1=[1], &
    1216             :                                     contract_2=[1], notcontract_2=[2], &
    1217         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1218         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1219             : 
    1220             :                   !Do the same shenanigans with the S^(x) (jatom)
    1221             :                   !Calculate M^-1*B*A + A*B*M^-1 to contract with B^x. A is in t_2c_work(3)
    1222             :                   CALL dbt_contract(1.0_dp, t_2c_work(3), t_2c_binv(jatom), &
    1223             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1224             :                                     contract_1=[2], notcontract_1=[1], &
    1225             :                                     contract_2=[1], notcontract_2=[2], &
    1226         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1227         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1228             : 
    1229             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(3), & !use transpose of B*M^-1 = M^-1*B
    1230             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1231             :                                     contract_1=[1], notcontract_1=[2], &
    1232             :                                     contract_2=[1], notcontract_2=[2], &
    1233         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1234         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1235             : 
    1236         579 :                   CALL dbt_copy(t_2c_work(2), t_2c_work(4), summation=.TRUE.)
    1237             :                   CALL get_2c_bump_forces(force, t_2c_work(4), jatom, atom_of_kind, kind_of, pref, &
    1238         579 :                                           ri_data, qs_env, work_virial)
    1239             : 
    1240             :                   !Calculate -M^-1*B*A*B*M^-1 to contract with diagonal RI metric deriv. t_2c_work(2) holds A*B*M^-1
    1241             :                   CALL dbt_contract(1.0_dp, t_2c_binv(jatom), t_2c_work(2), &
    1242             :                                     0.0_dp, t_2c_work(4), map_1=[1], map_2=[2], &
    1243             :                                     contract_1=[1], notcontract_1=[2], &
    1244             :                                     contract_2=[1], notcontract_2=[2], &
    1245         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1246         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1247             : 
    1248         579 :                   IF (use_virial_prv) THEN
    1249             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1250             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1251         113 :                                            diag=.TRUE., offdiag=.FALSE.)
    1252             :                   ELSE
    1253             :                      CALL get_2c_der_force(force, t_2c_work(4), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1254         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.TRUE., offdiag=.FALSE.)
    1255             :                   END IF
    1256             : 
    1257             :                   !Calculate -B*M^-1*B*A*B*M^-1*B to contract with off-diagonal RI metric derivs
    1258         579 :                   CALL dbt_copy(t_2c_work(4), t_2c_work(2))
    1259         579 :                   CALL apply_bump(t_2c_work(2), jatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    1260             : 
    1261         579 :                   IF (use_virial_prv) THEN
    1262             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1263             :                                            kind_of, 1, -pref, ri_data, qs_env, work_virial, cell, particle_set, &
    1264         113 :                                            diag=.FALSE., offdiag=.TRUE.)
    1265             :                   ELSE
    1266             :                      CALL get_2c_der_force(force, t_2c_work(2), t_2c_der_metric_sub(jatom, :), atom_of_kind, &
    1267         466 :                                            kind_of, 1, -pref, ri_data, qs_env, diag=.FALSE., offdiag=.TRUE.)
    1268             :                   END IF
    1269             : 
    1270             :                   !Calculate -O*B*M^-1*B*A*B*M^-1 - M^-1*B*A*B*M^-1*B*O, where O is off-diagonal integrals
    1271             :                   !t_2c_work(4) holds M^-1*B*A*B*M^-1, and exploit transpose of B*O (stored in t_2c_bint)
    1272             :                   CALL dbt_contract(1.0_dp, t_2c_work(4), t_2c_bint(jatom), &
    1273             :                                     0.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1274             :                                     contract_1=[2], notcontract_1=[1], &
    1275             :                                     contract_2=[1], notcontract_2=[2], &
    1276         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1277         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1278             : 
    1279             :                   CALL dbt_contract(1.0_dp, t_2c_bint(jatom), t_2c_work(4), &
    1280             :                                     1.0_dp, t_2c_work(2), map_1=[1], map_2=[2], &
    1281             :                                     contract_1=[1], notcontract_1=[2], &
    1282             :                                     contract_2=[1], notcontract_2=[2], &
    1283         579 :                                     filter_eps=ri_data%filter_eps, flop=nflop)
    1284         579 :                   ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    1285             : 
    1286             :                   CALL get_2c_bump_forces(force, t_2c_work(2), jatom, atom_of_kind, kind_of, -pref, &
    1287         579 :                                           ri_data, qs_env, work_virial)
    1288             : 
    1289       14376 :                   CALL timestop(handle2)
    1290             :                END DO !iatom
    1291             :             END DO !jatom
    1292             :          END DO !b_img
    1293             : 
    1294         126 :          IF (i_loop == 1) THEN
    1295          42 :             CALL update_pattern_to_forces(force_pattern, sparsity_pattern, ngroups, ri_data, qs_env)
    1296             :          END IF
    1297             :       END DO !i_loop
    1298             : 
    1299          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(1))
    1300          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(2))
    1301          42 :       CALL dbt_batched_contract_finalize(t_3c_work_3(3))
    1302          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(1))
    1303          42 :       CALL dbt_batched_contract_finalize(t_3c_work_2(2))
    1304             : 
    1305          42 :       IF (use_virial_prv) THEN
    1306          32 :          DO k_xyz = 1, 3
    1307         104 :             DO j_xyz = 1, 3
    1308         312 :                DO i_xyz = 1, 3
    1309             :                   virial%pv_fock_4c(i_xyz, j_xyz) = virial%pv_fock_4c(i_xyz, j_xyz) &
    1310         288 :                                                     + work_virial(i_xyz, k_xyz)*cell%hmat(j_xyz, k_xyz)
    1311             :                END DO
    1312             :             END DO
    1313             :          END DO
    1314             :       END IF
    1315             : 
    1316             :       !End of subgroup parallelization
    1317          42 :       CALL cp_blacs_env_release(blacs_env_sub)
    1318          42 :       CALL para_env_sub%free()
    1319          42 :       DEALLOCATE (para_env_sub)
    1320             : 
    1321          42 :       CALL para_env%sync()
    1322          42 :       t2 = m_walltime()
    1323          42 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    1324             : 
    1325             :       !clean-up
    1326          42 :       IF (ASSOCIATED(dbcsr_template)) THEN
    1327          42 :          CALL dbcsr_release(dbcsr_template)
    1328          42 :          DEALLOCATE (dbcsr_template)
    1329             :       END IF
    1330          42 :       CALL dbt_destroy(t_2c_R)
    1331          42 :       CALL dbt_destroy(t_2c_R_split)
    1332          42 :       CALL dbt_destroy(t_2c_work(1))
    1333          42 :       CALL dbt_destroy(t_2c_work(2))
    1334          42 :       CALL dbt_destroy(t_2c_work(3))
    1335          42 :       CALL dbt_destroy(t_2c_work(4))
    1336          42 :       CALL dbt_destroy(t_2c_work(5))
    1337          42 :       CALL dbt_destroy(t_3c_work_2(1))
    1338          42 :       CALL dbt_destroy(t_3c_work_2(2))
    1339          42 :       CALL dbt_destroy(t_3c_work_2(3))
    1340          42 :       CALL dbt_destroy(t_3c_work_3(1))
    1341          42 :       CALL dbt_destroy(t_3c_work_3(2))
    1342          42 :       CALL dbt_destroy(t_3c_work_3(3))
    1343          42 :       CALL dbt_destroy(t_3c_work_3(4))
    1344          42 :       CALL dbt_destroy(t_3c_der_stack(1))
    1345          42 :       CALL dbt_destroy(t_3c_der_stack(2))
    1346          42 :       CALL dbt_destroy(t_3c_der_stack(3))
    1347          42 :       CALL dbt_destroy(t_3c_der_stack(4))
    1348          42 :       CALL dbt_destroy(t_3c_der_stack(5))
    1349          42 :       CALL dbt_destroy(t_3c_der_stack(6))
    1350         168 :       DO i_xyz = 1, 3
    1351         168 :          CALL dbt_destroy(t_2c_der_pot(i_xyz))
    1352             :       END DO
    1353         126 :       DO iatom = 1, natom
    1354          84 :          CALL dbt_destroy(t_2c_inv(iatom))
    1355          84 :          CALL dbt_destroy(t_2c_binv(iatom))
    1356          84 :          CALL dbt_destroy(t_2c_bint(iatom))
    1357          84 :          CALL dbt_destroy(t_2c_metric(iatom))
    1358         378 :          DO i_xyz = 1, 3
    1359         336 :             CALL dbt_destroy(t_2c_der_metric_sub(iatom, i_xyz))
    1360             :          END DO
    1361             :       END DO
    1362        1052 :       DO i_img = 1, nimg
    1363        1010 :          CALL dbcsr_release(mat_2c_pot(i_img))
    1364        2228 :          DO i_spin = 1, nspins
    1365        1176 :             CALL dbt_destroy(rho_ao_t_sub(i_spin, i_img))
    1366        2186 :             CALL dbt_destroy(t_3c_apc_sub(i_spin, i_img))
    1367             :          END DO
    1368             :       END DO
    1369         168 :       DO i_xyz = 1, 3
    1370        3198 :          DO i_img = 1, nimg
    1371        3030 :             CALL dbt_destroy(t_3c_der_RI_sub(i_img, i_xyz))
    1372        3030 :             CALL dbt_destroy(t_3c_der_AO_sub(i_img, i_xyz))
    1373        3156 :             CALL dbcsr_release(mat_der_pot_sub(i_img, i_xyz))
    1374             :          END DO
    1375             :       END DO
    1376             : 
    1377          42 :       CALL timestop(handle)
    1378             : 
    1379       18756 :    END SUBROUTINE hfx_ri_update_forces_kp
    1380             : 
    1381             : ! **************************************************************************************************
    1382             : !> \brief A routine that applies the RI bump matrix from the left and/or the right, given an input
    1383             : !>        matrix and the central RI atom. We assume atomic block sizes
    1384             : !> \param t_2c_inout 2-center tensor whose blocks are scaled in place and then filtered
    1385             : !> \param atom_i central RI atom; block distances are measured from its pbc-wrapped position
    1386             : !> \param ri_data provides kp_bump_rad, kp_RI_range, ncell_RI, RI_cell_to_img and filter_eps
    1387             : !> \param qs_env provides natom, cell, kpoints and particle_set
    1388             : !> \param from_left scale with the bump value of the first block index (default .FALSE.)
    1389             : !> \param from_right scale with the bump value of the second block index (default .FALSE.)
    1390             : !> \param debump divide by the bump value instead of multiplying by it (default .FALSE.)
    1391             : ! **************************************************************************************************
    1392        1746 :    SUBROUTINE apply_bump(t_2c_inout, atom_i, ri_data, qs_env, from_left, from_right, debump)
    1393             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_inout
    1394             :       INTEGER, INTENT(IN)                                :: atom_i
    1395             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1396             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1397             :       LOGICAL, INTENT(IN), OPTIONAL                      :: from_left, from_right, debump
    1398             : 
    1399             :       INTEGER                                            :: i_img, i_RI, iatom, ind(2), j_img, j_RI, &
    1400             :                                                             jatom, natom, nblks(2), nimg, nkind
    1401        1746 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1402        1746 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1403             :       LOGICAL                                            :: found, my_debump, my_left, my_right
    1404             :       REAL(dp)                                           :: bval, r0, r1, ri(3), rj(3), rref(3), &
    1405             :                                                             scoord(3)
    1406        1746 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1407             :       TYPE(cell_type), POINTER                           :: cell
    1408             :       TYPE(dbt_iterator_type)                            :: iter
    1409             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1410        1746 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1411        1746 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1412             : 
    1413        1746 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1414             : 
    1415             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1416        1746 :                       kpoints=kpoints, particle_set=particle_set)
    1417        1746 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1418             : 
    1419        1746 :       my_debump = .FALSE.
    1420        1746 :       IF (PRESENT(debump)) my_debump = debump
    1421             : 
    1422        1746 :       my_left = .FALSE.
    1423        1746 :       IF (PRESENT(from_left)) my_left = from_left
    1424             : 
    1425        1746 :       my_right = .FALSE.
    1426        1746 :       IF (PRESENT(from_right)) my_right = from_right
    1427        1746 :       CPASSERT(my_left .OR. my_right)
    1428             : 
    1429        1746 :       CALL dbt_get_info(t_2c_inout, nblks_total=nblks)
    1430        1746 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1431        1746 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1432             : 
    1433        1746 :       nimg = ri_data%nimg
    1434             : 
    1435             :       !Loop over the blocks (RI cell and atom per index) and scale each by the bump value(s)
    1436        1746 :       r1 = ri_data%kp_RI_range
    1437        1746 :       r0 = ri_data%kp_bump_rad
    1438        1746 :       rref = pbc(particle_set(atom_i)%r, cell)
    1439             : 
    1440             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_inout,natom,ri_data,cell,particle_set,index_to_cell,my_left, &
    1441             : !$OMP                               my_right,r0,r1,rref,my_debump) &
    1442        1746 : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,bval)
    1443             :       CALL dbt_iterator_start(iter, t_2c_inout)
    1444             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1445             :          CALL dbt_iterator_next_block(iter, ind)
    1446             :          CALL dbt_get_block(t_2c_inout, ind, blk, found)
    1447             :          IF (.NOT. found) CYCLE
    1448             : 
    1449             :          i_RI = (ind(1) - 1)/natom + 1
    1450             :          i_img = ri_data%RI_cell_to_img(i_RI)
    1451             :          iatom = ind(1) - (i_RI - 1)*natom
    1452             : 
    1453             :          CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    1454             :          CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    1455             : 
    1456             :          j_RI = (ind(2) - 1)/natom + 1
    1457             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1458             :          jatom = ind(2) - (j_RI - 1)*natom
    1459             : 
    1460             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1461             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1462             : 
    1463             :          IF (.NOT. my_debump) THEN
    1464             :             IF (my_left) blk(:, :) = blk(:, :)*bump(NORM2(ri - rref), r0, r1)
    1465             :             IF (my_right) blk(:, :) = blk(:, :)*bump(NORM2(rj - rref), r0, r1)
    1466             :          ELSE
    1467             :             !Note: by construction, the bump function is never quite zero, as its range is the same
    1468             :             !      as that of the extended RI basis; values <= EPSILON are skipped anyway to be safe
    1469             :             bval = bump(NORM2(ri - rref), r0, r1)
    1470             :             IF (my_left .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1471             :             bval = bump(NORM2(rj - rref), r0, r1)
    1472             :             IF (my_right .AND. bval > EPSILON(1.0_dp)) blk(:, :) = blk(:, :)/bval
    1473             :          END IF
    1474             : 
    1475             :          CALL dbt_put_block(t_2c_inout, ind, SHAPE(blk), blk)
    1476             : 
    1477             :          DEALLOCATE (blk)
    1478             :       END DO
    1479             :       CALL dbt_iterator_stop(iter)
    1480             : !$OMP END PARALLEL
    1481        1746 :       CALL dbt_filter(t_2c_inout, ri_data%filter_eps)
    1482             : 
    1483        3492 :    END SUBROUTINE apply_bump
    1484             : 
    1485             : ! **************************************************************************************************
    1486             : !> \brief A routine that calculates the forces due to the derivative of the bump function
    1487             : !> \param force force array; its fock_4c component is updated for atom_i and the block atoms
    1488             : !> \param t_2c_in 2-center tensor; only the traces of its diagonal blocks are used
    1489             : !> \param atom_i reference atom (Rref) from which the bump distance x is measured
    1490             : !> \param atom_of_kind for each atom, its index within its kind (second index of fock_4c)
    1491             : !> \param kind_of for each atom, its kind (index into force)
    1492             : !> \param pref prefactor multiplying the block trace times dbump
    1493             : !> \param ri_data provides kp_bump_rad, kp_RI_range, ncell_RI and RI_cell_to_img
    1494             : !> \param qs_env provides natom, cell, kpoints and particle_set
    1495             : !> \param work_virial 3x3 accumulator updated with force times scaled coordinate
    1496             : ! **************************************************************************************************
    1497        2316 :    SUBROUTINE get_2c_bump_forces(force, t_2c_in, atom_i, atom_of_kind, kind_of, pref, ri_data, &
    1498             :                                  qs_env, work_virial)
    1499             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    1500             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_in
    1501             :       INTEGER, INTENT(IN)                                :: atom_i
    1502             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    1503             :       REAL(dp), INTENT(IN)                               :: pref
    1504             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1505             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1506             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT)           :: work_virial
    1507             : 
    1508             :       INTEGER :: i, i_img, i_RI, i_xyz, iat_of_kind, iatom, ikind, ind(2), j_img, j_RI, j_xyz, &
    1509             :          jat_of_kind, jatom, jkind, natom, nblks(2), nimg, nkind
    1510        2316 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1511        2316 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1512             :       LOGICAL                                            :: found
    1513             :       REAL(dp)                                           :: new_force, r0, r1, ri(3), rj(3), &
    1514             :                                                             rref(3), scoord(3), x
    1515        2316 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    1516             :       TYPE(cell_type), POINTER                           :: cell
    1517             :       TYPE(dbt_iterator_type)                            :: iter
    1518             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1519        2316 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1520        2316 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    1521             : 
    1522        2316 :       NULLIFY (qs_kind_set, particle_set, kpoints, index_to_cell, cell_to_index, cell)
    1523             : 
    1524             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    1525        2316 :                       kpoints=kpoints, particle_set=particle_set)
    1526        2316 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1527             : 
    1528        2316 :       CALL dbt_get_info(t_2c_in, nblks_total=nblks)
    1529        2316 :       CPASSERT(nblks(1) == ri_data%ncell_RI*natom)
    1530        2316 :       CPASSERT(nblks(2) == ri_data%ncell_RI*natom)
    1531             : 
    1532        2316 :       nimg = ri_data%nimg
    1533             : 
    1534             :       !Bump radii and reference position; only blocks with r0 <= x <= r1 contribute to the forces
    1535        2316 :       r1 = ri_data%kp_RI_range
    1536        2316 :       r0 = ri_data%kp_bump_rad
    1537        2316 :       rref = pbc(particle_set(atom_i)%r, cell)
    1538             : 
    1539        2316 :       iat_of_kind = atom_of_kind(atom_i)
    1540        2316 :       ikind = kind_of(atom_i)
    1541             : 
    1542             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_2c_in,natom,ri_data,cell,particle_set,index_to_cell,pref, &
    1543             : !$OMP force,r0,r1,rref,atom_of_kind,kind_of,iat_of_kind,ikind,work_virial) &
    1544             : !$OMP PRIVATE(iter,ind,blk,found,i_RI,i_img,iatom,j_RI,j_img,jatom,scoord,ri,rj,jkind,jat_of_kind, &
    1545        2316 : !$OMP         new_force,i_xyz,i,x,j_xyz)
    1546             :       CALL dbt_iterator_start(iter, t_2c_in)
    1547             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1548             :          CALL dbt_iterator_next_block(iter, ind)
    1549             :          IF (ind(1) .NE. ind(2)) CYCLE !bump matrix is diagonal
    1550             : 
    1551             :          CALL dbt_get_block(t_2c_in, ind, blk, found)
    1552             :          IF (.NOT. found) CYCLE
    1553             : 
    1554             :          !bump is a function of x = SQRT((R - Rref)^2). We refer to R as jatom, and Rref as atom_i
    1555             :          j_RI = (ind(2) - 1)/natom + 1
    1556             :          j_img = ri_data%RI_cell_to_img(j_RI)
    1557             :          jatom = ind(2) - (j_RI - 1)*natom
    1558             :          jat_of_kind = atom_of_kind(jatom)
    1559             :          jkind = kind_of(jatom)
    1560             : 
    1561             :          CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    1562             :          CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    1563             :          x = NORM2(rj - rref)
    1564             :          IF (x < r0 .OR. x > r1) CYCLE
    1565             : 
    1566             :          new_force = 0.0_dp
    1567             :          DO i = 1, SIZE(blk, 1)
    1568             :             new_force = new_force + blk(i, i)
    1569             :          END DO
    1570             :          new_force = pref*new_force*dbump(x, r0, r1)
    1571             : 
    1572             :          !x = SQRT((R - Rref)^2), so we multiply by dx/dR and dx/dRref
    1573             :          DO i_xyz = 1, 3
    1574             :             !Force acting on second atom
    1575             : !$OMP ATOMIC
    1576             :             force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) + &
    1577             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1578             : 
    1579             :             !virial acting on second atom
    1580             :             CALL real_to_scaled(scoord, rj, cell)
    1581             :             DO j_xyz = 1, 3
    1582             : !$OMP ATOMIC
    1583             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1584             :                                            + new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1585             :             END DO
    1586             : 
    1587             :             !Force acting on reference atom, defining the RI basis
    1588             : !$OMP ATOMIC
    1589             :             force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) - &
    1590             :                                                        new_force*(rj(i_xyz) - rref(i_xyz))/x
    1591             : 
    1592             :             !virial of ref atom
    1593             :             CALL real_to_scaled(scoord, rref, cell)
    1594             :             DO j_xyz = 1, 3
    1595             : !$OMP ATOMIC
    1596             :                work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) &
    1597             :                                            - new_force*scoord(j_xyz)*(rj(i_xyz) - rref(i_xyz))/x
    1598             :             END DO
    1599             :          END DO !i_xyz
    1600             : 
    1601             :          DEALLOCATE (blk)
    1602             :       END DO
    1603             :       CALL dbt_iterator_stop(iter)
    1604             : !$OMP END PARALLEL
    1605             : 
    1606        4632 :    END SUBROUTINE get_2c_bump_forces
    1607             : 
    1608             : ! **************************************************************************************************
    1609             : !> \brief The bump function as defined by Juerg: 1 for x <= r0, 0 for x >= r1, quintic in between
    1610             : !> \param x point at which the function is evaluated
    1611             : !> \param r0 radius up to which the function is 1
    1612             : !> \param r1 radius from which the function is 0
    1613             : !> \return value of the bump function at x
    1614             : ! **************************************************************************************************
    1615       25181 :    FUNCTION bump(x, r0, r1) RESULT(b)
    1616             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1617             :       REAL(dp)                                           :: b
    1618             : 
    1619             :       REAL(dp)                                           :: r
    1620             : 
    1621             :       !Head-Gordon variant, kept for reference only (commented out)
    1622             :       !b = 1.0_dp/(1.0_dp+EXP((r1-r0)/(r1-x)-(r1-r0)/(x-r0)))
    1623             :       !Juerg: polynomial in the reduced coordinate r, overridden outside of (r0, r1)
    1624       25181 :       r = (x - r0)/(r1 - r0)
    1625       25181 :       b = -6.0_dp*r**5 + 15.0_dp*r**4 - 10.0_dp*r**3 + 1.0_dp
    1626       25181 :       IF (x .GE. r1) b = 0.0_dp
    1627       25181 :       IF (x .LE. r0) b = 1.0_dp
    1628             : 
    1629       25181 :    END FUNCTION bump
    1630             : 
    1631             : ! **************************************************************************************************
    1632             : !> \brief The derivative of the bump function with respect to x; 0 for x <= r0 and for x >= r1
    1633             : !> \param x point at which the derivative is evaluated
    1634             : !> \param r0 radius below which (and at which) the derivative is 0
    1635             : !> \param r1 radius above which (and at which) the derivative is 0
    1636             : !> \return value of the derivative at x
    1637             : ! **************************************************************************************************
    1638         525 :    FUNCTION dbump(x, r0, r1) RESULT(b)
    1639             :       REAL(dp), INTENT(IN)                               :: x, r0, r1
    1640             :       REAL(dp)                                           :: b
    1641             : 
    1642             :       REAL(dp)                                           :: r
    1643             : 
    1644         525 :       r = (x - r0)/(r1 - r0)
    1645         525 :       b = (-30.0_dp*r**4 + 60.0_dp*r**3 - 30.0_dp*r**2)/(r1 - r0)
    1646         525 :       IF (x .GE. r1) b = 0.0_dp
    1647         525 :       IF (x .LE. r0) b = 0.0_dp
    1648             : 
    1649         525 :    END FUNCTION dbump
    1650             : 
    1651             : ! **************************************************************************************************
    1652             : !> \brief return the cell index a+c corresponding to given cell index i and b, with i = a+c-b
    1653             : !> \param i_index index of cell i (column of index_to_cell)
    1654             : !> \param b_index index of cell b (column of index_to_cell)
    1655             : !> \param qs_env provides the kpoints, from which index_to_cell and cell_to_index are taken
    1656             : !> \return index of cell a+c, or 0 if that cell is outside the bounds of cell_to_index
    1657             : ! **************************************************************************************************
    1658      151439 :    FUNCTION get_apc_index_from_ib(i_index, b_index, qs_env) RESULT(apc_index)
    1659             :       INTEGER, INTENT(IN)                                :: i_index, b_index
    1660             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1661             :       INTEGER                                            :: apc_index
    1662             : 
    1663             :       INTEGER, DIMENSION(3)                              :: cell_apc
    1664      151439 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1665      151439 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1666             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1667             : 
    1668      151439 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1669      151439 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1670             : 
    1671             :       !i = a+c-b => a+c = i+b
    1672      605756 :       cell_apc(:) = index_to_cell(:, i_index) + index_to_cell(:, b_index)
    1673             : 
    1674     1037869 :       IF (ANY([cell_apc(1), cell_apc(2), cell_apc(3)] < LBOUND(cell_to_index)) .OR. &
    1675             :           ANY([cell_apc(1), cell_apc(2), cell_apc(3)] > UBOUND(cell_to_index))) THEN
    1676             : 
    1677             :          apc_index = 0
    1678             :       ELSE
    1679      132550 :          apc_index = cell_to_index(cell_apc(1), cell_apc(2), cell_apc(3))
    1680             :       END IF
    1681             : 
    1682      151439 :    END FUNCTION get_apc_index_from_ib
    1683             : 
    1684             : ! **************************************************************************************************
    1685             : !> \brief return the cell index i corresponding to the sum of cell_a and cell_c
    1686             : !> \param a_index index of cell a (column of index_to_cell)
    1687             : !> \param c_index index of cell c (column of index_to_cell)
    1688             : !> \param qs_env provides the kpoints, from which index_to_cell and cell_to_index are taken
    1689             : !> \return index of cell a+c, or 0 if that cell is outside the bounds of cell_to_index
    1690             : ! **************************************************************************************************
    1691           0 :    FUNCTION get_apc_index(a_index, c_index, qs_env) RESULT(i_index)
    1692             :       INTEGER, INTENT(IN)                                :: a_index, c_index
    1693             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1694             :       INTEGER                                            :: i_index
    1695             : 
    1696             :       INTEGER, DIMENSION(3)                              :: cell_i
    1697           0 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1698           0 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1699             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1700             : 
    1701           0 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1702           0 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1703             : 
    1704           0 :       cell_i(:) = index_to_cell(:, a_index) + index_to_cell(:, c_index)
    1705             : 
    1706           0 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1707             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1708             : 
    1709             :          i_index = 0
    1710             :       ELSE
    1711           0 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1712             :       END IF
    1713             : 
    1714           0 :    END FUNCTION get_apc_index
    1715             : 
    1716             : ! **************************************************************************************************
    1717             : !> \brief return the cell index i corresponding to the sum of cell_a + cell_c - cell_b
    1718             : !> \param apc_index image index of the cell a+c, used as column of index_to_cell
    1719             : !> \param b_index image index of the cell b that is subtracted, used as column of index_to_cell
    1720             : !> \param qs_env environment providing the kpoints cell_to_index and index_to_cell maps
    1721             : !> \return the cell_to_index entry of cell a+c-b, or 0 if that cell is out of bounds
    1722             : ! **************************************************************************************************
    1723      510764 :    FUNCTION get_i_index(apc_index, b_index, qs_env) RESULT(i_index)
    1724             :       INTEGER, INTENT(IN)                                :: apc_index, b_index
    1725             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1726             :       INTEGER                                            :: i_index
    1727             : 
    1728             :       INTEGER, DIMENSION(3)                              :: cell_i
    1729      510764 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1730      510764 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1731             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1732             : 
    1733      510764 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1734      510764 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1735             :       !cell vector of the difference (a+c) - b
    1736     2043056 :       cell_i(:) = index_to_cell(:, apc_index) - index_to_cell(:, b_index)
    1737             :       !return 0 instead of indexing cell_to_index outside of its bounds
    1738     3491940 :       IF (ANY([cell_i(1), cell_i(2), cell_i(3)] < LBOUND(cell_to_index)) .OR. &
    1739             :           ANY([cell_i(1), cell_i(2), cell_i(3)] > UBOUND(cell_to_index))) THEN
    1740             : 
    1741             :          i_index = 0
    1742             :       ELSE
    1743      438348 :          i_index = cell_to_index(cell_i(1), cell_i(2), cell_i(3))
    1744             :       END IF
    1745             : 
    1746      510764 :    END FUNCTION get_i_index
    1747             : 
    1748             : ! **************************************************************************************************
    1749             : !> \brief A routine that returns all allowed a,c pairs such that the a+c image corresponds to the
    1750             : !>        value of the apc_index input. Takes into account that image a corresponds to 3c
    1751             : !>        integrals, which are ordered in their own way
    1752             : !> \param ac_pairs on exit, column 1 holds a_index and column 2 the image index of c (0 if none)
    1753             : !> \param apc_index image index of the cell a+c
    1754             : !> \param ri_data provides idx_to_img, mapping a_index to the actual image index of a
    1755             : !> \param qs_env environment passed on to get_i_index
    1756             : ! **************************************************************************************************
    1757       15228 :    SUBROUTINE get_ac_pairs(ac_pairs, apc_index, ri_data, qs_env)
    1758             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: ac_pairs
    1759             :       INTEGER, INTENT(IN)                                :: apc_index
    1760             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1761             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1762             : 
    1763             :       INTEGER                                            :: a_index, actual_img, c_index, nimg
    1764             :       !the number of images is taken from the first dimension of ac_pairs
    1765       15228 :       nimg = SIZE(ac_pairs, 1)
    1766             : 
    1767     1067212 :       ac_pairs(:, :) = 0
    1768             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(ac_pairs,nimg,ri_data,qs_env,apc_index) &
    1769       15228 : !$OMP PRIVATE(a_index,actual_img,c_index)
    1770             :       DO a_index = 1, nimg
    1771             :          actual_img = ri_data%idx_to_img(a_index)
    1772             :          !c = (a+c) - a, c_index is zero if that cell is out of bounds
    1773             :          c_index = get_i_index(apc_index, actual_img, qs_env)
    1774             :          ac_pairs(a_index, 1) = a_index
    1775             :          ac_pairs(a_index, 2) = c_index
    1776             :       END DO
    1777             : !$OMP END PARALLEL DO
    1778             : 
    1779       15228 :    END SUBROUTINE get_ac_pairs
    1780             : 
    1781             : ! **************************************************************************************************
    1782             : !> \brief A routine that returns all allowed i,a+c pairs such that, for the given value of b, we have
    1783             : !>        i = a+c-b. Note that image i corresponds to the 3c ints, which are ordered in
    1784             : !>        their own way
    1785             : !> \param iapc_pairs on exit, (i_index, apc_index) per row; rows without valid a+c are left at 0
    1786             : !> \param b_index image index of cell b
    1787             : !> \param ri_data provides idx_to_img, mapping i_index to the actual image index of i
    1788             : !> \param qs_env environment passed on to get_apc_index_from_ib
    1789             : !> \param actual_i_img if present, actual image index of i for each valid row, zero otherwise
    1790             : ! **************************************************************************************************
    1791        4330 :    SUBROUTINE get_iapc_pairs(iapc_pairs, b_index, ri_data, qs_env, actual_i_img)
    1792             :       INTEGER, DIMENSION(:, :), INTENT(INOUT)            :: iapc_pairs
    1793             :       INTEGER, INTENT(IN)                                :: b_index
    1794             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1795             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1796             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: actual_i_img
    1797             : 
    1798             :       INTEGER                                            :: actual_img, apc_index, i_index, nimg
    1799             :       !the number of images is taken from the first dimension of iapc_pairs
    1800        4330 :       nimg = SIZE(iapc_pairs, 1)
    1801       43396 :       IF (PRESENT(actual_i_img)) actual_i_img(:) = 0
    1802             : 
    1803      315868 :       iapc_pairs(:, :) = 0
    1804             : !$OMP PARALLEL DO DEFAULT(NONE) SHARED(iapc_pairs,nimg,ri_data,qs_env,b_index,actual_i_img) &
    1805        4330 : !$OMP PRIVATE(i_index,actual_img,apc_index)
    1806             :       DO i_index = 1, nimg
    1807             :          actual_img = ri_data%idx_to_img(i_index)
    1808             :          apc_index = get_apc_index_from_ib(actual_img, b_index, qs_env)
    1809             :          IF (apc_index == 0) CYCLE
    1810             :          iapc_pairs(i_index, 1) = i_index
    1811             :          iapc_pairs(i_index, 2) = apc_index
    1812             :          IF (PRESENT(actual_i_img)) actual_i_img(i_index) = actual_img
    1813             :       END DO
    1814             :       !the parallel loop ends here, the closing END PARALLEL DO directive is optional in OpenMP
    1815        4330 :    END SUBROUTINE get_iapc_pairs
    1816             : 
    1817             : ! **************************************************************************************************
    1818             : !> \brief A function that, given a cell index a, returns the index corresponding to -a, and zero
    1819             : !>        if out of bounds
    1820             : !> \param a_index image index of cell a, used as column of index_to_cell
    1821             : !> \param qs_env environment providing the kpoints cell_to_index and index_to_cell maps
    1822             : !> \return the cell_to_index entry of cell -a, or 0 if -a is out of bounds
    1823             : ! **************************************************************************************************
    1824       61137 :    FUNCTION get_opp_index(a_index, qs_env) RESULT(opp_index)
    1825             :       INTEGER, INTENT(IN)                                :: a_index
    1826             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1827             :       INTEGER                                            :: opp_index
    1828             : 
    1829             :       INTEGER, DIMENSION(3)                              :: opp_cell
    1830       61137 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    1831       61137 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1832             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1833             : 
    1834       61137 :       NULLIFY (kpoints, cell_to_index, index_to_cell)
    1835             : 
    1836       61137 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    1837       61137 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    1838             :       !cell vector of the opposite image
    1839      244548 :       opp_cell(:) = -index_to_cell(:, a_index)
    1840             :       !return 0 instead of indexing cell_to_index outside of its bounds
    1841      427959 :       IF (ANY([opp_cell(1), opp_cell(2), opp_cell(3)] < LBOUND(cell_to_index)) .OR. &
    1842             :           ANY([opp_cell(1), opp_cell(2), opp_cell(3)] > UBOUND(cell_to_index))) THEN
    1843             : 
    1844             :          opp_index = 0
    1845             :       ELSE
    1846       61137 :          opp_index = cell_to_index(opp_cell(1), opp_cell(2), opp_cell(3))
    1847             :       END IF
    1848             : 
    1849       61137 :    END FUNCTION get_opp_index
    1850             : 
    1851             : ! **************************************************************************************************
    1852             : !> \brief A routine that returns the actual non-symmetric density matrix for each image, built from
    1853             : !>        the symmetric-typed real-space matrices in rho_ao and accumulated into rho_ao_t
    1854             : !> \param rho_ao_t tensors per (spin, image): scaled by scale_prev_p, then the new density is added
    1855             : !> \param rho_ao symmetric-typed real-space density matrices per (spin, image)
    1856             : !> \param scale_prev_p factor applied to the previous content of rho_ao_t before accumulation
    1857             : !> \param ri_data provides nimg and the filter_eps used to filter the resulting tensors
    1858             : !> \param qs_env environment providing kpoints, neighbor lists and the KS matrices used as templates
    1859             : ! **************************************************************************************************
    1860         422 :    SUBROUTINE get_pmat_images(rho_ao_t, rho_ao, scale_prev_p, ri_data, qs_env)
    1861             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t
    1862             :       TYPE(dbcsr_p_type), DIMENSION(:, :), INTENT(INOUT) :: rho_ao
    1863             :       REAL(dp), INTENT(IN)                               :: scale_prev_p
    1864             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1865             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1866             : 
    1867             :       INTEGER                                            :: cell_j(3), i_img, i_spin, iatom, icol, &
    1868             :                                                             irow, j_img, jatom, mi_img, mj_img, &
    1869             :                                                             nimg, nspins
    1870         422 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    1871             :       LOGICAL                                            :: found
    1872             :       REAL(dp)                                           :: fac
    1873         422 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock, pblock_desymm
    1874         422 :       TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_ks, rho_desymm
    1875        3798 :       TYPE(dbt_type)                                     :: tmp
    1876             :       TYPE(dft_control_type), POINTER                    :: dft_control
    1877             :       TYPE(kpoint_type), POINTER                         :: kpoints
    1878             :       TYPE(neighbor_list_iterator_p_type), &
    1879         422 :          DIMENSION(:), POINTER                           :: nl_iterator
    1880             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    1881         422 :          POINTER                                         :: sab_nl, sab_nl_nosym
    1882             :       TYPE(qs_scf_env_type), POINTER                     :: scf_env
    1883             : 
    1884         422 :       NULLIFY (rho_desymm, kpoints, sab_nl_nosym, scf_env, matrix_ks, dft_control, &
    1885         422 :                sab_nl, nl_iterator, cell_to_index, pblock, pblock_desymm)
    1886             : 
    1887         422 :       CALL get_qs_env(qs_env, kpoints=kpoints, scf_env=scf_env, matrix_ks_kp=matrix_ks, dft_control=dft_control)
    1888         422 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=sab_nl_nosym, cell_to_index=cell_to_index, sab_nl=sab_nl)
    1889             :       !with ADMM, the auxiliary-fit KS matrices are used as templates instead
    1890         422 :       IF (dft_control%do_admm) THEN
    1891         204 :          CALL get_admm_env(qs_env%admm_env, matrix_ks_aux_fit_kp=matrix_ks)
    1892             :       END IF
    1893             : 
    1894         422 :       nspins = SIZE(matrix_ks, 1)
    1895         422 :       nimg = ri_data%nimg
    1896             :       !one non-symmetric matrix per spin and image, with blocks allocated from sab_nl_nosym
    1897       26138 :       ALLOCATE (rho_desymm(nspins, nimg))
    1898       11040 :       DO i_img = 1, nimg
    1899       24872 :          DO i_spin = 1, nspins
    1900       13832 :             ALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1901             :             CALL dbcsr_create(rho_desymm(i_spin, i_img)%matrix, template=matrix_ks(i_spin, i_img)%matrix, &
    1902       13832 :                               matrix_type=dbcsr_type_no_symmetry)
    1903       24450 :             CALL cp_dbcsr_alloc_block_from_nbl(rho_desymm(i_spin, i_img)%matrix, sab_nl_nosym)
    1904             :          END DO
    1905             :       END DO
    1906         422 :       CALL dbt_create(rho_desymm(1, 1)%matrix, tmp)
    1907             : 
    1908             :       !We transform the symmetric typed (but not actually symmetric: P_ab^i = P_ba^-i) real-space density
    1909             :       !matrices into proper non-symmetric ones (using the same nl for consistency)
    1910         422 :       CALL neighbor_list_iterator_create(nl_iterator, sab_nl)
    1911       18275 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    1912       17853 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    1913       17853 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    1914       17853 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    1915             : 
    1916       12737 :          fac = 1.0_dp
    1917       12737 :          IF (iatom == jatom) fac = 0.5_dp
    1918       12737 :          mj_img = get_opp_index(j_img, qs_env)
    1919             :          !if no opposite image, then no sum of P^j + P^-j => need full diag
    1920       12737 :          IF (mj_img == 0) fac = 1.0_dp
    1921             : 
    1922       12737 :          irow = iatom
    1923       12737 :          icol = jatom
    1924       12737 :          IF (iatom > jatom) THEN
    1925             :             !because symmetric nl. Value for atom pair i,j is actually stored in j,i if i > j
    1926        3997 :             irow = jatom
    1927        3997 :             icol = iatom
    1928             :          END IF
    1929             : 
    1930       30064 :          DO i_spin = 1, nspins
    1931       16905 :             CALL dbcsr_get_block_p(rho_ao(i_spin, j_img)%matrix, irow, icol, pblock, found)
    1932       16905 :             IF (.NOT. found) CYCLE
    1933             : 
    1934             :             !distribution of symm and non-symm matrix match in that way
    1935       16905 :             CALL dbcsr_get_block_p(rho_desymm(i_spin, j_img)%matrix, iatom, jatom, pblock_desymm, found)
    1936       16905 :             IF (.NOT. found) CYCLE
    1937             :             !the block was read from the swapped position if iatom > jatom, hence the transpose
    1938       68568 :             IF (iatom > jatom) THEN
    1939      656678 :                pblock_desymm(:, :) = fac*TRANSPOSE(pblock(:, :))
    1940             :             ELSE
    1941     1568460 :                pblock_desymm(:, :) = fac*pblock(:, :)
    1942             :             END IF
    1943             :          END DO
    1944             :       END DO
    1945         422 :       CALL neighbor_list_iterator_release(nl_iterator)
    1946             :       !accumulate into rho_ao_t: scale the previous content, add P^i and the transpose of P^-i
    1947       11040 :       DO i_img = 1, nimg
    1948       24872 :          DO i_spin = 1, nspins
    1949       13832 :             CALL dbt_scale(rho_ao_t(i_spin, i_img), scale_prev_p)
    1950             : 
    1951       13832 :             CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, i_img)%matrix, tmp)
    1952       13832 :             CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), summation=.TRUE., move_data=.TRUE.)
    1953             : 
    1954             :             !symmetrize by adding the transpose of the opposite image
    1955       13832 :             mi_img = get_opp_index(i_img, qs_env)
    1956       13832 :             IF (mi_img > 0 .AND. mi_img .LE. nimg) THEN
    1957       12464 :                CALL dbt_copy_matrix_to_tensor(rho_desymm(i_spin, mi_img)%matrix, tmp)
    1958       12464 :                CALL dbt_copy(tmp, rho_ao_t(i_spin, i_img), order=[2, 1], summation=.TRUE., move_data=.TRUE.)
    1959             :             END IF
    1960       24450 :             CALL dbt_filter(rho_ao_t(i_spin, i_img), ri_data%filter_eps)
    1961             :          END DO
    1962             :       END DO
    1963             :       !release the temporary non-symmetric matrices and the work tensor
    1964       11040 :       DO i_img = 1, nimg
    1965       24872 :          DO i_spin = 1, nspins
    1966       13832 :             CALL dbcsr_release(rho_desymm(i_spin, i_img)%matrix)
    1967       24450 :             DEALLOCATE (rho_desymm(i_spin, i_img)%matrix)
    1968             :          END DO
    1969             :       END DO
    1970             : 
    1971         422 :       CALL dbt_destroy(tmp)
    1972         422 :       DEALLOCATE (rho_desymm)
    1973             : 
    1974         844 :    END SUBROUTINE get_pmat_images
    1975             : 
    1976             : ! **************************************************************************************************
    1977             : !> \brief A routine that, given a cell index b and atom indices ij, returns a 2c tensor with the HFX
    1978             : !>        potential (P_i^0|Q_j^b), within the extended RI basis
    1979             : !> \param t_2c_pot ...
    1980             : !> \param mat_orig ...
    1981             : !> \param atom_i ...
    1982             : !> \param atom_j ...
    1983             : !> \param img_b ...
    1984             : !> \param ri_data ...
    1985             : !> \param qs_env ...
    1986             : !> \param do_inverse ...
    1987             : !> \param para_env_ext ...
    1988             : !> \param blacs_env_ext ...
    1989             : !> \param dbcsr_template ...
    1990             : !> \param off_diagonal ...
    1991             : !> \param skip_inverse ...
    1992             : ! **************************************************************************************************
    1993        7318 :    SUBROUTINE get_ext_2c_int(t_2c_pot, mat_orig, atom_i, atom_j, img_b, ri_data, qs_env, do_inverse, &
    1994             :                              para_env_ext, blacs_env_ext, dbcsr_template, off_diagonal, skip_inverse)
    1995             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_pot
    1996             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_orig
    1997             :       INTEGER, INTENT(IN)                                :: atom_i, atom_j, img_b
    1998             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    1999             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2000             :       LOGICAL, INTENT(IN), OPTIONAL                      :: do_inverse
    2001             :       TYPE(mp_para_env_type), OPTIONAL, POINTER          :: para_env_ext
    2002             :       TYPE(cp_blacs_env_type), OPTIONAL, POINTER         :: blacs_env_ext
    2003             :       TYPE(dbcsr_type), OPTIONAL, POINTER                :: dbcsr_template
    2004             :       LOGICAL, INTENT(IN), OPTIONAL                      :: off_diagonal, skip_inverse
    2005             : 
    2006             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_ext_2c_int'
    2007             : 
    2008             :       INTEGER :: group, handle, handle2, i_img, i_RI, iatom, iblk, ikind, img_tot, j_img, j_RI, &
    2009             :          jatom, jblk, jkind, n_dependent, natom, nblks_RI, nimg, nkind
    2010        7318 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist1, dist2
    2011        7318 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: present_atoms_i, present_atoms_j
    2012             :       INTEGER, DIMENSION(3)                              :: cell_b, cell_i, cell_j, cell_tot
    2013        7318 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, col_dist_ext, ri_blk_size_ext, &
    2014        7318 :                                                             row_dist, row_dist_ext
    2015        7318 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell, pgrid
    2016        7318 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    2017             :       LOGICAL                                            :: do_inverse_prv, found, my_offd, &
    2018             :                                                             skip_inverse_prv, use_template
    2019             :       REAL(dp)                                           :: bfac, dij, r0, r1, threshold
    2020             :       REAL(dp), DIMENSION(3)                             :: ri, rij, rj, rref, scoord
    2021        7318 :       REAL(dp), DIMENSION(:, :), POINTER                 :: pblock
    2022             :       TYPE(cell_type), POINTER                           :: cell
    2023             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
    2024             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist, dbcsr_dist_ext
    2025             :       TYPE(dbcsr_iterator_type)                          :: dbcsr_iter
    2026             :       TYPE(dbcsr_type)                                   :: work, work_tight, work_tight_inv
    2027       51226 :       TYPE(dbt_type)                                     :: t_2c_tmp
    2028             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    2029             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    2030        7318 :          DIMENSION(:), TARGET                            :: basis_set_RI
    2031             :       TYPE(kpoint_type), POINTER                         :: kpoints
    2032             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2033             :       TYPE(neighbor_list_iterator_p_type), &
    2034        7318 :          DIMENSION(:), POINTER                           :: nl_iterator
    2035             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    2036        7318 :          POINTER                                         :: nl_2c
    2037        7318 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    2038        7318 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    2039             : 
    2040        7318 :       NULLIFY (qs_kind_set, nl_2c, nl_iterator, cell, kpoints, cell_to_index, index_to_cell, dist_2d, &
    2041        7318 :                para_env, pblock, blacs_env, particle_set, col_dist, row_dist, pgrid, &
    2042        7318 :                col_dist_ext, row_dist_ext)
    2043             : 
    2044        7318 :       CALL timeset(routineN, handle)
    2045             : 
    2046             :       !Idea: run over the neighbor list once for i and once for j, and record in which cell the MIC
    2047             :       !      atoms are. Then loop over the atoms and only take the pairs the we need
    2048             : 
    2049             :       CALL get_qs_env(qs_env, natom=natom, nkind=nkind, qs_kind_set=qs_kind_set, cell=cell, &
    2050        7318 :                       kpoints=kpoints, para_env=para_env, blacs_env=blacs_env, particle_set=particle_set)
    2051        7318 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell)
    2052             : 
    2053        7318 :       do_inverse_prv = .FALSE.
    2054        7318 :       IF (PRESENT(do_inverse)) do_inverse_prv = do_inverse
    2055         280 :       IF (do_inverse_prv) THEN
    2056         280 :          CPASSERT(atom_i == atom_j)
    2057             :       END IF
    2058             : 
    2059        7318 :       skip_inverse_prv = .FALSE.
    2060        7318 :       IF (PRESENT(skip_inverse)) skip_inverse_prv = skip_inverse
    2061             : 
    2062        7318 :       my_offd = .FALSE.
    2063        7318 :       IF (PRESENT(off_diagonal)) my_offd = off_diagonal
    2064             : 
    2065        7318 :       IF (PRESENT(para_env_ext)) para_env => para_env_ext
    2066        7318 :       IF (PRESENT(blacs_env_ext)) blacs_env => blacs_env_ext
    2067             : 
    2068        7318 :       nimg = SIZE(mat_orig)
    2069             : 
    2070        7318 :       CALL timeset(routineN//"_nl_iter", handle2)
    2071             : 
    2072             :       !create our own dist_2d in the subgroup
    2073       29272 :       ALLOCATE (dist1(natom), dist2(natom))
    2074       21954 :       DO iatom = 1, natom
    2075       14636 :          dist1(iatom) = MOD(iatom, blacs_env%num_pe(1))
    2076       21954 :          dist2(iatom) = MOD(iatom, blacs_env%num_pe(2))
    2077             :       END DO
    2078        7318 :       CALL distribution_2d_create(dist_2d, dist1, dist2, nkind, particle_set, blacs_env_ext=blacs_env)
    2079             : 
    2080       33055 :       ALLOCATE (basis_set_RI(nkind))
    2081        7318 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    2082             : 
    2083             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    2084        7318 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    2085             : 
    2086       43908 :       ALLOCATE (present_atoms_i(natom, nimg), present_atoms_j(natom, nimg))
    2087      746851 :       present_atoms_i = 0
    2088      746851 :       present_atoms_j = 0
    2089             : 
    2090        7318 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    2091      291602 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    2092             :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, r=rij, cell=cell_j, &
    2093      284284 :                                 ikind=ikind, jkind=jkind)
    2094             : 
    2095     1137136 :          dij = NORM2(rij)
    2096             : 
    2097      284284 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    2098      284284 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    2099             : 
    2100      282070 :          IF (iatom == atom_i .AND. dij .LE. ri_data%kp_RI_range) present_atoms_i(jatom, j_img) = 1
    2101      289388 :          IF (iatom == atom_j .AND. dij .LE. ri_data%kp_RI_range) present_atoms_j(jatom, j_img) = 1
    2102             :       END DO
    2103        7318 :       CALL neighbor_list_iterator_release(nl_iterator)
    2104        7318 :       CALL release_neighbor_list_sets(nl_2c)
    2105        7318 :       CALL distribution_2d_release(dist_2d)
    2106        7318 :       CALL timestop(handle2)
    2107             : 
    2108        7318 :       CALL para_env%sum(present_atoms_i)
    2109        7318 :       CALL para_env%sum(present_atoms_j)
    2110             : 
    2111             :       !Need to build a work matrix with matching distribution to mat_orig
    2112             :       !If template is provided, use it. If not, we create it.
    2113        7318 :       use_template = .FALSE.
    2114        7318 :       IF (PRESENT(dbcsr_template)) THEN
    2115        6646 :          IF (ASSOCIATED(dbcsr_template)) use_template = .TRUE.
    2116             :       END IF
    2117             : 
    2118             :       IF (use_template) THEN
    2119        6414 :          CALL dbcsr_create(work, template=dbcsr_template)
    2120             :       ELSE
    2121         904 :          CALL dbcsr_get_info(mat_orig(1), distribution=dbcsr_dist)
    2122         904 :          CALL dbcsr_distribution_get(dbcsr_dist, row_dist=row_dist, col_dist=col_dist, group=group, pgrid=pgrid)
    2123        3616 :          ALLOCATE (row_dist_ext(ri_data%ncell_RI*natom), col_dist_ext(ri_data%ncell_RI*natom))
    2124        1808 :          ALLOCATE (ri_blk_size_ext(ri_data%ncell_RI*natom))
    2125        6428 :          DO i_RI = 1, ri_data%ncell_RI
    2126       27620 :             row_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = row_dist(:)
    2127       27620 :             col_dist_ext((i_RI - 1)*natom + 1:i_RI*natom) = col_dist(:)
    2128       17476 :             RI_blk_size_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2129             :          END DO
    2130             : 
    2131             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2132         904 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2133             :          CALL dbcsr_create(work, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2134         904 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2135         904 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2136         904 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2137             : 
    2138        2712 :          IF (PRESENT(dbcsr_template)) THEN
    2139         232 :             ALLOCATE (dbcsr_template)
    2140         232 :             CALL dbcsr_create(dbcsr_template, template=work)
    2141             :          END IF
    2142             :       END IF !use_template
    2143             : 
    2144       29272 :       cell_b(:) = index_to_cell(:, img_b)
    2145      253829 :       DO i_img = 1, nimg
    2146      246511 :          i_RI = ri_data%img_to_RI_cell(i_img)
    2147      246511 :          IF (i_RI == 0) CYCLE
    2148      195364 :          cell_i(:) = index_to_cell(:, i_img)
    2149     2062970 :          DO j_img = 1, nimg
    2150     2006811 :             j_RI = ri_data%img_to_RI_cell(j_img)
    2151     2006811 :             IF (j_RI == 0) CYCLE
    2152     1693732 :             cell_j(:) = index_to_cell(:, j_img)
    2153     1693732 :             cell_tot = cell_j - cell_i + cell_b
    2154             : 
    2155     2924091 :             IF (ANY([cell_tot(1), cell_tot(2), cell_tot(3)] < LBOUND(cell_to_index)) .OR. &
    2156             :                 ANY([cell_tot(1), cell_tot(2), cell_tot(3)] > UBOUND(cell_to_index))) CYCLE
    2157      387929 :             img_tot = cell_to_index(cell_tot(1), cell_tot(2), cell_tot(3))
    2158      387929 :             IF (img_tot > nimg .OR. img_tot < 1) CYCLE
    2159             : 
    2160      263923 :             CALL dbcsr_iterator_start(dbcsr_iter, mat_orig(img_tot))
    2161      754095 :             DO WHILE (dbcsr_iterator_blocks_left(dbcsr_iter))
    2162      490172 :                CALL dbcsr_iterator_next_block(dbcsr_iter, row=iatom, column=jatom)
    2163      490172 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2164      184998 :                IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2165       80543 :                IF (my_offd .AND. (i_RI - 1)*natom + iatom == (j_RI - 1)*natom + jatom) CYCLE
    2166             : 
    2167       80254 :                CALL dbcsr_get_block_p(mat_orig(img_tot), iatom, jatom, pblock, found)
    2168       80254 :                IF (.NOT. found) CYCLE
    2169             : 
    2170      754095 :                CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2171             : 
    2172             :             END DO
    2173     2481741 :             CALL dbcsr_iterator_stop(dbcsr_iter)
    2174             : 
    2175             :          END DO !j_img
    2176             :       END DO !i_img
    2177        7318 :       CALL dbcsr_finalize(work)
    2178             : 
    2179        7318 :       IF (do_inverse_prv) THEN
    2180             : 
    2181         280 :          r1 = ri_data%kp_RI_range
    2182         280 :          r0 = ri_data%kp_bump_rad
    2183             : 
    2184             :          !Because there are a lot of empty rows/cols in work, we need to get rid of them for inversion
    2185       20872 :          nblks_RI = SUM(present_atoms_i)
    2186        1400 :          ALLOCATE (col_dist_ext(nblks_RI), row_dist_ext(nblks_RI), RI_blk_size_ext(nblks_RI))
    2187         280 :          iblk = 0
    2188        7144 :          DO i_img = 1, nimg
    2189        6864 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2190        6864 :             IF (i_RI == 0) CYCLE
    2191        5512 :             DO iatom = 1, natom
    2192        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2193        1156 :                iblk = iblk + 1
    2194        1156 :                col_dist_ext(iblk) = col_dist(iatom)
    2195        1156 :                row_dist_ext(iblk) = row_dist(iatom)
    2196       10352 :                RI_blk_size_ext(iblk) = ri_data%bsizes_RI(iatom)
    2197             :             END DO
    2198             :          END DO
    2199             : 
    2200             :          CALL dbcsr_distribution_new(dbcsr_dist_ext, group=group, pgrid=pgrid, &
    2201         280 :                                      row_dist=row_dist_ext, col_dist=col_dist_ext)
    2202             :          CALL dbcsr_create(work_tight, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2203         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2204             :          CALL dbcsr_create(work_tight_inv, dist=dbcsr_dist_ext, name="RI_ext", matrix_type=dbcsr_type_no_symmetry, &
    2205         280 :                            row_blk_size=RI_blk_size_ext, col_blk_size=RI_blk_size_ext)
    2206         280 :          CALL dbcsr_distribution_release(dbcsr_dist_ext)
    2207         280 :          DEALLOCATE (col_dist_ext, row_dist_ext, RI_blk_size_ext)
    2208             : 
    2209             :          !We apply a bump function to the RI metric inverse for smooth RI basis extension:
    2210             :          ! S^-1 = B * ((P|Q)_D + B*(P|Q)_OD*B)^-1 * B, with D block-diagonal blocks and OD off-diagonal
    2211         280 :          rref = pbc(particle_set(atom_i)%r, cell)
    2212             : 
    2213         280 :          iblk = 0
    2214        7144 :          DO i_img = 1, nimg
    2215        6864 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2216        6864 :             IF (i_RI == 0) CYCLE
    2217        5512 :             DO iatom = 1, natom
    2218        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2219        1156 :                iblk = iblk + 1
    2220             : 
    2221        1156 :                CALL real_to_scaled(scoord, pbc(particle_set(iatom)%r, cell), cell)
    2222        4624 :                CALL scaled_to_real(ri, scoord(:) + index_to_cell(:, i_img), cell)
    2223             : 
    2224        1156 :                jblk = 0
    2225       42764 :                DO j_img = 1, nimg
    2226       34744 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2227       34744 :                   IF (j_RI == 0) CYCLE
    2228       29504 :                   DO jatom = 1, natom
    2229       17344 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2230        5580 :                      jblk = jblk + 1
    2231             : 
    2232        5580 :                      CALL real_to_scaled(scoord, pbc(particle_set(jatom)%r, cell), cell)
    2233       22320 :                      CALL scaled_to_real(rj, scoord(:) + index_to_cell(:, j_img), cell)
    2234             : 
    2235        5580 :                      CALL dbcsr_get_block_p(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock, found)
    2236        5580 :                      IF (.NOT. found) CYCLE
    2237             : 
    2238        2508 :                      bfac = 1.0_dp
    2239       14088 :                      IF (iblk .NE. jblk) bfac = bump(NORM2(ri - rref), r0, r1)*bump(NORM2(rj - rref), r0, r1)
    2240     5075024 :                      CALL dbcsr_put_block(work_tight, iblk, jblk, bfac*pblock(:, :))
    2241             :                   END DO
    2242             :                END DO
    2243             :             END DO
    2244             :          END DO
    2245         280 :          CALL dbcsr_finalize(work_tight)
    2246         280 :          CALL dbcsr_clear(work)
    2247             : 
    2248         280 :          IF (.NOT. skip_inverse_prv) THEN
    2249         140 :             SELECT CASE (ri_data%t2c_method)
    2250             :             CASE (hfx_ri_do_2c_iter)
    2251           0 :                threshold = MAX(ri_data%filter_eps, 1.0e-12_dp)
    2252           0 :                CALL invert_hotelling(work_tight_inv, work_tight, threshold=threshold, silent=.FALSE.)
    2253             :             CASE (hfx_ri_do_2c_cholesky)
    2254         140 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2255         140 :                CALL cp_dbcsr_cholesky_decompose(work_tight_inv, para_env=para_env, blacs_env=blacs_env)
    2256             :                CALL cp_dbcsr_cholesky_invert(work_tight_inv, para_env=para_env, blacs_env=blacs_env, &
    2257         140 :                                              uplo_to_full=.TRUE.)
    2258             :             CASE (hfx_ri_do_2c_diag)
    2259           0 :                CALL dbcsr_copy(work_tight_inv, work_tight)
    2260             :                CALL cp_dbcsr_power(work_tight_inv, -1.0_dp, ri_data%eps_eigval, n_dependent, &
    2261         140 :                                    para_env, blacs_env, verbose=ri_data%unit_nr_dbcsr > 0)
    2262             :             END SELECT
    2263             :          ELSE
    2264         140 :             CALL dbcsr_copy(work_tight_inv, work_tight)
    2265             :          END IF
    2266             : 
    2267             :          !move back data to standard extended RI pattern
    2268             :          !Note: we apply the external bump to ((P|Q)_D + B*(P|Q)_OD*B)^-1 later, because this matrix
    2269             :          !      is required for forces
    2270         280 :          iblk = 0
    2271        7144 :          DO i_img = 1, nimg
    2272        6864 :             i_RI = ri_data%img_to_RI_cell(i_img)
    2273        6864 :             IF (i_RI == 0) CYCLE
    2274        5512 :             DO iatom = 1, natom
    2275        3488 :                IF (present_atoms_i(iatom, i_img) == 0) CYCLE
    2276        1156 :                iblk = iblk + 1
    2277             : 
    2278        1156 :                jblk = 0
    2279       42764 :                DO j_img = 1, nimg
    2280       34744 :                   j_RI = ri_data%img_to_RI_cell(j_img)
    2281       34744 :                   IF (j_RI == 0) CYCLE
    2282       29504 :                   DO jatom = 1, natom
    2283       17344 :                      IF (present_atoms_j(jatom, j_img) == 0) CYCLE
    2284        5580 :                      jblk = jblk + 1
    2285             : 
    2286        5580 :                      CALL dbcsr_get_block_p(work_tight_inv, iblk, jblk, pblock, found)
    2287        5580 :                      IF (.NOT. found) CYCLE
    2288             : 
    2289       54737 :                      CALL dbcsr_put_block(work, (i_RI - 1)*natom + iatom, (j_RI - 1)*natom + jatom, pblock)
    2290             :                   END DO
    2291             :                END DO
    2292             :             END DO
    2293             :          END DO
    2294         280 :          CALL dbcsr_finalize(work)
    2295             : 
    2296         280 :          CALL dbcsr_release(work_tight)
    2297         560 :          CALL dbcsr_release(work_tight_inv)
    2298             :       END IF
    2299             : 
    2300        7318 :       CALL dbt_create(work, t_2c_tmp)
    2301        7318 :       CALL dbt_copy_matrix_to_tensor(work, t_2c_tmp)
    2302        7318 :       CALL dbt_copy(t_2c_tmp, t_2c_pot, move_data=.TRUE.)
    2303        7318 :       CALL dbt_filter(t_2c_pot, ri_data%filter_eps)
    2304             : 
    2305        7318 :       CALL dbt_destroy(t_2c_tmp)
    2306        7318 :       CALL dbcsr_release(work)
    2307             : 
    2308        7318 :       CALL timestop(handle)
    2309             : 
    2310       29272 :    END SUBROUTINE get_ext_2c_int
    2311             : 
    2312             : ! **************************************************************************************************
    2313             : !> \brief Pre-contract the density matrices with the 3-center integrals:
    2314             : !>        P_sigma^a,lambda^a+c (mu^0 sigma^a| P^0)
    2315             : !> \param t_3c_apc result tensors indexed (spin, image); contraction results are summed into them
    2316             : !> \param rho_ao_t density matrix tensors indexed (spin, image), stacked before the contraction
    2317             : !> \param ri_data provides t_3c_int_ctr_1 and batching info; flop count and timing are updated
    2318             : !> \param qs_env used for the number of spins and forwarded to the helper routines
    2319             : ! **************************************************************************************************
    2320         232 :    SUBROUTINE contract_pmat_3c(t_3c_apc, rho_ao_t, ri_data, qs_env)
    2321             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, rho_ao_t
    2322             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2323             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2324             : 
    2325             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'contract_pmat_3c'
    2326             : 
    2327             :       INTEGER                                            :: apc_img, batch_size, handle, i_batch, &
    2328             :                                                             i_img, i_spin, j_batch, n_batch_img, &
    2329             :                                                             n_batch_nze, nimg, nimg_nze, nspins
    2330             :       INTEGER(int_8)                                     :: nflop, nze
    2331         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: batch_ranges_img, batch_ranges_nze, &
    2332         232 :                                                             int_indices
    2333         232 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: ac_pairs
    2334             :       REAL(dp)                                           :: occ, t1, t2
    2335        2088 :       TYPE(dbt_type)                                     :: t_3c_tmp
    2336         232 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ints_stack, res_stack, rho_stack
    2337             :       TYPE(dft_control_type), POINTER                    :: dft_control
    2338             : 
    2339         232 :       CALL timeset(routineN, handle)
    2340             : 
    2341         232 :       CALL get_qs_env(qs_env, dft_control=dft_control)
    2342             : 
    2343         232 :       nimg = ri_data%nimg
    2344         232 :       nimg_nze = ri_data%nimg_nze
    2345         232 :       nspins = dft_control%nspins
    2346             : 
    2347         232 :       CALL dbt_create(t_3c_apc(1, 1), t_3c_tmp)
    2348             : 
    2349         232 :       batch_size = nimg/ri_data%n_mem !NOTE(review): 0 if n_mem > nimg (division below) - confirm cap
    2350             : 
    2351             :       !batching over all images
    2352         232 :       n_batch_img = nimg/batch_size
    2353         232 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch_img = n_batch_img + 1
    2354         696 :       ALLOCATE (batch_ranges_img(n_batch_img + 1))
    2355         810 :       DO i_batch = 1, n_batch_img
    2356         810 :          batch_ranges_img(i_batch) = (i_batch - 1)*batch_size + 1
    2357             :       END DO
    2358         232 :       batch_ranges_img(n_batch_img + 1) = nimg + 1
    2359             : 
    2360             :       !batching over images with non-zero 3c integrals
    2361         232 :       n_batch_nze = nimg_nze/batch_size
    2362         232 :       IF (MODULO(nimg_nze, batch_size) .NE. 0) n_batch_nze = n_batch_nze + 1
    2363         696 :       ALLOCATE (batch_ranges_nze(n_batch_nze + 1))
    2364         752 :       DO i_batch = 1, n_batch_nze
    2365         752 :          batch_ranges_nze(i_batch) = (i_batch - 1)*batch_size + 1
    2366             :       END DO
    2367         232 :       batch_ranges_nze(n_batch_nze + 1) = nimg_nze + 1
    2368             : 
    2369             :       !Create the stack tensors in the appropriate distribution
    2370        7192 :       ALLOCATE (rho_stack(2), ints_stack(2), res_stack(2))
    2371             :       CALL get_stack_tensors(res_stack, rho_stack, ints_stack, rho_ao_t(1, 1), &
    2372         232 :                              ri_data%t_3c_int_ctr_1(1, 1), batch_size, ri_data, qs_env)
    2373             : 
    2374        1160 :       ALLOCATE (ac_pairs(nimg, 2), int_indices(nimg_nze))
    2375        4554 :       DO i_img = 1, nimg_nze
    2376        4554 :          int_indices(i_img) = i_img
    2377             :       END DO
    2378             : 
    2379         232 :       t1 = m_walltime()
    2380         752 :       DO j_batch = 1, n_batch_nze
    2381             :          !First batch is over the integrals. They are always in the same order, consistent with get_ac_pairs
    2382             :          CALL fill_3c_stack(ints_stack(1), ri_data%t_3c_int_ctr_1(1, :), int_indices, 3, ri_data, &
    2383        1560 :                             img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)])
    2384         520 :          CALL dbt_copy(ints_stack(1), ints_stack(2), move_data=.TRUE.)
    2385             : 
    2386        1492 :          DO i_spin = 1, nspins
    2387        3152 :             DO i_batch = 1, n_batch_img
    2388             :                !Second batch is over the P matrix. Here we fill the stacked rho tensors col by col
    2389       17120 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2390       15228 :                   CALL get_ac_pairs(ac_pairs, apc_img, ri_data, qs_env)
    2391             :                   CALL fill_2c_stack(rho_stack(1), rho_ao_t(i_spin, :), ac_pairs(:, 2), 1, ri_data, &
    2392             :                                      img_bounds=[batch_ranges_nze(j_batch), batch_ranges_nze(j_batch + 1)], &
    2393       47576 :                                      shift=apc_img - batch_ranges_img(i_batch) + 1)
    2394             : 
    2395             :                END DO !apc_img
    2396        1892 :                CALL get_tensor_occupancy(rho_stack(1), nze, occ)
    2397        1892 :                IF (nze == 0) CYCLE
    2398        1650 :                CALL dbt_copy(rho_stack(1), rho_stack(2), move_data=.TRUE.)
    2399             : 
    2400             :                !The actual contraction
    2401        1650 :                CALL dbt_batched_contract_init(rho_stack(2))
    2402             :                CALL dbt_contract(1.0_dp, ints_stack(2), rho_stack(2), &
    2403             :                                  0.0_dp, res_stack(2), map_1=[1, 2], map_2=[3], &
    2404             :                                  contract_1=[3], notcontract_1=[1, 2], &
    2405             :                                  contract_2=[1], notcontract_2=[2], &
    2406        1650 :                                  filter_eps=ri_data%filter_eps, flop=nflop)
    2407        1650 :                ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    2408        1650 :                CALL dbt_batched_contract_finalize(rho_stack(2))
    2409        1650 :                CALL dbt_copy(res_stack(2), res_stack(1), move_data=.TRUE.)
    2410             : 
    2411       18680 :                DO apc_img = batch_ranges_img(i_batch), batch_ranges_img(i_batch + 1) - 1
    2412             :                   !Destack the resulting tensor and put it in t_3c_apc with correct apc_img
    2413       14398 :                   CALL unstack_t_3c_apc(t_3c_tmp, res_stack(1), apc_img - batch_ranges_img(i_batch) + 1)
    2414       16290 :                   CALL dbt_copy(t_3c_tmp, t_3c_apc(i_spin, apc_img), summation=.TRUE., move_data=.TRUE.)
    2415             :                END DO
    2416             : 
    2417             :             END DO !i_batch
    2418             :          END DO !i_spin
    2419             :       END DO !j_batch
    2420         232 :       DEALLOCATE (batch_ranges_img)
    2421         232 :       DEALLOCATE (batch_ranges_nze)
    2422         232 :       t2 = m_walltime()
    2423         232 :       ri_data%dbcsr_time = ri_data%dbcsr_time + t2 - t1
    2424             : 
    2425         232 :       CALL dbt_destroy(rho_stack(1))
    2426         232 :       CALL dbt_destroy(rho_stack(2))
    2427         232 :       CALL dbt_destroy(ints_stack(1))
    2428         232 :       CALL dbt_destroy(ints_stack(2))
    2429         232 :       CALL dbt_destroy(res_stack(1))
    2430         232 :       CALL dbt_destroy(res_stack(2))
    2431         232 :       CALL dbt_destroy(t_3c_tmp)
    2432             : 
    2433         232 :       CALL timestop(handle)
    2434             : 
    2435        2320 :    END SUBROUTINE contract_pmat_3c
    2436             : 
    2437             : ! **************************************************************************************************
    2438             : !> \brief Pre-contract 3-center integrals with the bumped inverse RI metric, for each atom
    2439             : !> \param t_3c_int input 3-center integrals t_3c_int(1, :); each image tensor is destroyed on exit
    2440             : !> \param ri_data results are summed into ri_data%t_3c_int_ctr_1(1, :); flop count is updated
    2441             : !> \param qs_env used for the number of atoms and forwarded to apply_bump
    2442             : ! **************************************************************************************************
    2443          70 :    SUBROUTINE precontract_3c_ints(t_3c_int, ri_data, qs_env)
    2444             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_int
    2445             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2446             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    2447             : 
    2448             :       CHARACTER(len=*), PARAMETER :: routineN = 'precontract_3c_ints'
    2449             : 
    2450             :       INTEGER                                            :: batch_size, handle, i_batch, i_img, &
    2451             :                                                             i_RI, iatom, is, n_batch, natom, &
    2452             :                                                             nblks, nblks_3c(3), nimg
    2453             :       INTEGER(int_8)                                     :: nflop
    2454          70 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: batch_ranges, bsizes_RI_ext, bsizes_RI_ext_split, &
    2455          70 :          bsizes_stack, dist1, dist2, dist3, dist_stack3, idx_to_at_AO, int_indices
    2456         630 :       TYPE(dbt_distribution_type)                        :: t_dist
    2457       14630 :       TYPE(dbt_type)                                     :: t_2c_RI_tmp(2), t_3c_tmp(3)
    2458             : 
    2459          70 :       CALL timeset(routineN, handle)
    2460             : 
    2461          70 :       CALL get_qs_env(qs_env, natom=natom)
    2462             : 
    2463          70 :       nimg = ri_data%nimg
    2464         210 :       ALLOCATE (int_indices(nimg))
    2465        1786 :       DO i_img = 1, nimg
    2466        1786 :          int_indices(i_img) = i_img
    2467             :       END DO
    2468             : 
    2469         210 :       ALLOCATE (idx_to_at_AO(SIZE(ri_data%bsizes_AO_split)))
    2470          70 :       CALL get_idx_to_atom(idx_to_at_AO, ri_data%bsizes_AO_split, ri_data%bsizes_AO)
    2471             : 
    2472          70 :       nblks = SIZE(ri_data%bsizes_RI_split)
    2473         210 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    2474         210 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    2475         506 :       DO i_RI = 1, ri_data%ncell_RI
    2476        1308 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    2477        2344 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    2478             :       END DO
    2479             :       CALL create_2c_tensor(t_2c_RI_tmp(1), dist1, dist2, ri_data%pgrid_2d, &
    2480             :                             bsizes_RI_ext, bsizes_RI_ext, &
    2481             :                             name="(RI | RI)")
    2482          70 :       DEALLOCATE (dist1, dist2)
    2483             :       CALL create_2c_tensor(t_2c_RI_tmp(2), dist1, dist2, ri_data%pgrid_2d, &
    2484             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    2485             :                             name="(RI | RI)")
    2486          70 :       DEALLOCATE (dist1, dist2)
    2487             : 
    2488             :       !For more efficiency, we stack multiple images of the 3-center integrals into a single tensor
    2489          70 :       batch_size = nimg/ri_data%n_mem !NOTE(review): 0 if n_mem > nimg (division below) - confirm cap
    2490          70 :       n_batch = nimg/batch_size
    2491          70 :       IF (MODULO(nimg, batch_size) .NE. 0) n_batch = n_batch + 1
    2492         210 :       ALLOCATE (batch_ranges(n_batch + 1))
    2493         246 :       DO i_batch = 1, n_batch
    2494         246 :          batch_ranges(i_batch) = (i_batch - 1)*batch_size + 1
    2495             :       END DO
    2496          70 :       batch_ranges(n_batch + 1) = nimg + 1
    2497             : 
    2498          70 :       nblks = SIZE(ri_data%bsizes_AO_split)
    2499         210 :       ALLOCATE (bsizes_stack(batch_size*nblks))
    2500         910 :       DO is = 1, batch_size
    2501        3502 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    2502             :       END DO
    2503             : 
    2504          70 :       CALL dbt_get_info(t_3c_int(1, 1), nblks_total=nblks_3c)
    2505         630 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)), dist_stack3(batch_size*nblks_3c(3)))
    2506          70 :       CALL dbt_get_info(t_3c_int(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, proc_dist_3=dist3)
    2507         910 :       DO is = 1, batch_size
    2508        3502 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    2509             :       END DO
    2510             : 
    2511          70 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid, dist1, dist2, dist_stack3)
    2512             :       CALL dbt_create(t_3c_tmp(1), "ints_stack", t_dist, [1], [2, 3], bsizes_RI_ext_split, &
    2513          70 :                       ri_data%bsizes_AO_split, bsizes_stack)
    2514          70 :       CALL dbt_distribution_destroy(t_dist)
    2515          70 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    2516             : 
    2517          70 :       CALL dbt_create(t_3c_tmp(1), t_3c_tmp(2))
    2518          70 :       CALL dbt_create(t_3c_int(1, 1), t_3c_tmp(3))
    2519             : 
    2520         210 :       DO iatom = 1, natom
    2521         140 :          CALL dbt_copy(ri_data%t_2c_inv(1, iatom), t_2c_RI_tmp(1))
    2522         140 :          CALL apply_bump(t_2c_RI_tmp(1), iatom, ri_data, qs_env, from_left=.TRUE., from_right=.TRUE.)
    2523         140 :          CALL dbt_copy(t_2c_RI_tmp(1), t_2c_RI_tmp(2), move_data=.TRUE.)
    2524             : 
    2525         140 :          CALL dbt_batched_contract_init(t_2c_RI_tmp(2))
    2526         492 :          DO i_batch = 1, n_batch
    2527             : 
    2528             :             CALL fill_3c_stack(t_3c_tmp(1), t_3c_int(1, :), int_indices, 3, ri_data, &
    2529             :                                img_bounds=[batch_ranges(i_batch), batch_ranges(i_batch + 1)], &
    2530        1056 :                                filter_at=iatom, filter_dim=2, idx_to_at=idx_to_at_AO)
    2531             : 
    2532             :             CALL dbt_contract(1.0_dp, t_2c_RI_tmp(2), t_3c_tmp(1), &
    2533             :                               0.0_dp, t_3c_tmp(2), map_1=[1], map_2=[2, 3], &
    2534             :                               contract_1=[2], notcontract_1=[1], &
    2535             :                               contract_2=[1], notcontract_2=[2, 3], &
    2536         352 :                               filter_eps=ri_data%filter_eps, flop=nflop)
    2537         352 :             ri_data%dbcsr_nflop = ri_data%dbcsr_nflop + nflop
    2538             : 
    2539        3784 :             DO i_img = batch_ranges(i_batch), batch_ranges(i_batch + 1) - 1
    2540        3432 :                CALL unstack_t_3c_apc(t_3c_tmp(3), t_3c_tmp(2), i_img - batch_ranges(i_batch) + 1)
    2541             :                CALL dbt_copy(t_3c_tmp(3), ri_data%t_3c_int_ctr_1(1, i_img), summation=.TRUE., &
    2542        3784 :                              order=[2, 1, 3], move_data=.TRUE.)
    2543             :             END DO
    2544         492 :             CALL dbt_clear(t_3c_tmp(1))
    2545             :          END DO
    2546         210 :          CALL dbt_batched_contract_finalize(t_2c_RI_tmp(2))
    2547             : 
    2548             :       END DO
    2549          70 :       CALL dbt_destroy(t_2c_RI_tmp(1))
    2550          70 :       CALL dbt_destroy(t_2c_RI_tmp(2))
    2551          70 :       CALL dbt_destroy(t_3c_tmp(1))
    2552          70 :       CALL dbt_destroy(t_3c_tmp(2))
    2553          70 :       CALL dbt_destroy(t_3c_tmp(3))
    2554             : 
    2555        1786 :       DO i_img = 1, nimg
    2556        1786 :          CALL dbt_destroy(t_3c_int(1, i_img))
    2557             :       END DO
    2558             : 
    2559          70 :       CALL timestop(handle)
    2560             : 
    2561         350 :    END SUBROUTINE precontract_3c_ints
    2562             : 
    2563             : ! **************************************************************************************************
    2564             : !> \brief Copy the data of a 2D tensor living in the main MPI group to a sub-group, given the proc
    2565             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2566             : !> \param t2c_sub ...
    2567             : !> \param t2c_main ...
    2568             : !> \param group_size ...
    2569             : !> \param ngroups ...
    2570             : !> \param para_env ...
    2571             : ! **************************************************************************************************
    2572        8000 :    SUBROUTINE copy_2c_to_subgroup(t2c_sub, t2c_main, group_size, ngroups, para_env)
    2573             :       TYPE(dbt_type), INTENT(INOUT)                      :: t2c_sub, t2c_main
    2574             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2575             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2576             : 
    2577             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iblk, &
    2578             :                                                             igroup, iproc, ir, is, jblk, n_batch, &
    2579             :                                                             nocc, tag
    2580        8000 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2
    2581        8000 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: block_dest, block_source
    2582        8000 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: current_dest
    2583             :       INTEGER, DIMENSION(2)                              :: ind, nblks
    2584             :       LOGICAL                                            :: found
    2585        8000 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2586        8000 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2587             :       TYPE(dbt_iterator_type)                            :: iter
    2588        8000 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2589             : 
    2590             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2591             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2592             :       !         to have blocks pre-reserved though
    2593             : 
    2594        8000 :       CALL dbt_get_info(t2c_main, nblks_total=nblks)
    2595             : 
    2596             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2597       32000 :       ALLOCATE (block_source(nblks(1), nblks(2)))
    2598      166468 :       block_source = -1
    2599        8000 :       nocc = 0
    2600        8000 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t2c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2601             :       CALL dbt_iterator_start(iter, t2c_main)
    2602             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2603             :          CALL dbt_iterator_next_block(iter, ind)
    2604             :          CALL dbt_get_block(t2c_main, ind, blk, found)
    2605             :          IF (.NOT. found) CYCLE
    2606             : 
    2607             :          block_source(ind(1), ind(2)) = para_env%mepos
    2608             : !$OMP ATOMIC
    2609             :          nocc = nocc + 1
    2610             :          DEALLOCATE (blk)
    2611             :       END DO
    2612             :       CALL dbt_iterator_stop(iter)
    2613             : !$OMP END PARALLEL
    2614             : 
    2615        8000 :       CALL para_env%sum(nocc)
    2616        8000 :       CALL para_env%sum(block_source)
    2617      166468 :       block_source = block_source + para_env%num_pe - 1
    2618        8000 :       IF (nocc == 0) RETURN
    2619             : 
    2620             :       !Loop over the sub tensor, get the block destination
    2621        7820 :       igroup = para_env%mepos/group_size
    2622       23460 :       ALLOCATE (block_dest(nblks(1), nblks(2)))
    2623      165208 :       block_dest = -1
    2624       29840 :       DO jblk = 1, nblks(2)
    2625      165208 :          DO iblk = 1, nblks(1)
    2626      135368 :             IF (block_source(iblk, jblk) == -1) CYCLE
    2627             : 
    2628       98868 :             CALL dbt_get_stored_coordinates(t2c_sub, [iblk, jblk], iproc)
    2629      157388 :             block_dest(iblk, jblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2630             :          END DO
    2631             :       END DO
    2632             : 
    2633       39100 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)))
    2634        7820 :       CALL dbt_get_info(t2c_main, blk_size_1=bsizes1, blk_size_2=bsizes2)
    2635             : 
    2636       39100 :       ALLOCATE (current_dest(nblks(1), nblks(2), 0:ngroups - 1))
    2637       23460 :       DO igroup = 0, ngroups - 1
    2638             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2639      330416 :          current_dest(:, :, igroup) = block_dest(:, :)
    2640       23460 :          CALL para_env%bcast(current_dest(:, :, igroup), source=igroup*group_size) !bcast from first proc in sub-group
    2641             :       END DO
    2642             : 
    2643             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2644        7820 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2645        7820 :       n_batch = (nocc*ngroups)/batch_size
    2646        7820 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2647             : 
    2648       15640 :       DO i_batch = 1, n_batch
    2649             :          !Loop over groups, blocks and send/receive
    2650      163104 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2651      163104 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2652             :          ir = 0
    2653             :          is = 0
    2654             :          i_msg = 0
    2655       29840 :          DO jblk = 1, nblks(2)
    2656      165208 :             DO iblk = 1, nblks(1)
    2657      428124 :                DO igroup = 0, ngroups - 1
    2658      270736 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2659             : 
    2660       65912 :                   i_msg = i_msg + 1
    2661       65912 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2662             : 
    2663             :                   !a unique tag per block, within this batch
    2664       65912 :                   tag = i_msg - (i_batch - 1)*batch_size
    2665             : 
    2666       65912 :                   found = .FALSE.
    2667       65912 :                   IF (para_env%mepos == block_source(iblk, jblk)) THEN
    2668       98868 :                      CALL dbt_get_block(t2c_main, [iblk, jblk], blk, found)
    2669             :                   END IF
    2670             : 
    2671             :                   !If blocks live on same proc, simply copy. Else MPI send/recv
    2672       65912 :                   IF (block_source(iblk, jblk) == current_dest(iblk, jblk, igroup)) THEN
    2673       98868 :                      IF (found) CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2674             :                   ELSE
    2675       32956 :                      IF (para_env%mepos == block_source(iblk, jblk) .AND. found) THEN
    2676       65912 :                         ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2677    21062150 :                         send_buff(tag)%array(:, :) = blk(:, :)
    2678       16478 :                         is = is + 1
    2679             :                         CALL para_env%isend(msgin=send_buff(tag)%array, dest=current_dest(iblk, jblk, igroup), &
    2680       16478 :                                             request=send_req(is), tag=tag)
    2681             :                      END IF
    2682             : 
    2683       32956 :                      IF (para_env%mepos == current_dest(iblk, jblk, igroup)) THEN
    2684       65912 :                         ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk)))
    2685       16478 :                         ir = ir + 1
    2686             :                         CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk), &
    2687       16478 :                                             request=recv_req(ir), tag=tag)
    2688             :                      END IF
    2689             :                   END IF
    2690             : 
    2691      201280 :                   IF (found) DEALLOCATE (blk)
    2692             :                END DO
    2693             :             END DO
    2694             :          END DO
    2695             : 
    2696        7820 :          CALL mp_waitall(send_req(1:is))
    2697        7820 :          CALL mp_waitall(recv_req(1:ir))
    2698             :          !clean-up
    2699       73732 :          DO i = 1, batch_size
    2700       73732 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2701             :          END DO
    2702             : 
    2703             :          !Finally copy the data from the buffer to the sub-tensor
    2704             :          i_msg = 0
    2705       29840 :          DO jblk = 1, nblks(2)
    2706      165208 :             DO iblk = 1, nblks(1)
    2707      428124 :                DO igroup = 0, ngroups - 1
    2708      270736 :                   IF (block_source(iblk, jblk) == -1) CYCLE
    2709             : 
    2710       65912 :                   i_msg = i_msg + 1
    2711       65912 :                   IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2712             : 
    2713             :                   !a unique tag per block, within this batch
    2714       65912 :                   tag = i_msg - (i_batch - 1)*batch_size
    2715             : 
    2716       65912 :                   IF (para_env%mepos == current_dest(iblk, jblk, igroup) .AND. &
    2717      135368 :                       block_source(iblk, jblk) .NE. current_dest(iblk, jblk, igroup)) THEN
    2718             : 
    2719       65912 :                      ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk)))
    2720    21062150 :                      blk(:, :) = recv_buff(tag)%array(:, :)
    2721       82390 :                      CALL dbt_put_block(t2c_sub, [iblk, jblk], SHAPE(blk), blk)
    2722       16478 :                      DEALLOCATE (blk)
    2723             :                   END IF
    2724             :                END DO
    2725             :             END DO
    2726             :          END DO
    2727             : 
    2728             :          !clean-up
    2729       73732 :          DO i = 1, batch_size
    2730       73732 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2731             :          END DO
    2732       15640 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2733             :       END DO !i_batch
    2734        7820 :       CALL dbt_finalize(t2c_sub)
    2735             : 
    2736       16000 :    END SUBROUTINE copy_2c_to_subgroup
    2737             : 
    2738             : ! **************************************************************************************************
    2739             : !> \brief Copy the data of a 3D tensor living in the main MPI group to a sub-group, given the proc
    2740             : !>        mapping from one to the other (e.g. for a proc idx in the subgroup, we get the idx in the main)
    2741             : !> \param t3c_sub destination tensor living in the subgroup, receives the copied blocks
    2742             : !> \param t3c_main source tensor living in the main group, its stored blocks are the ones copied
    2743             : !> \param group_size number of procs per subgroup (main group idx = igroup*group_size + subgroup idx)
    2744             : !> \param ngroups number of subgroups
    2745             : !> \param para_env parallel environment of the main group, used for all the communication
    2746             : !> \param iatom_to_subgroup optional, iatom_to_subgroup(iatom)%array(igroup + 1) tells whether the
    2747             : !>        blocks of atom iatom are copied to subgroup igroup
    2748             : !> \param dim_at optional, the tensor dimension whose block index is mapped to an atom via idx_to_at
    2749             : !> \param idx_to_at optional, atom index of each block along dim_at. Filtering needs all 3 optionals
    2750             : ! **************************************************************************************************
    2750       11690 :    SUBROUTINE copy_3c_to_subgroup(t3c_sub, t3c_main, group_size, ngroups, para_env, iatom_to_subgroup, &
    2751       11690 :                                   dim_at, idx_to_at)
    2752             :       TYPE(dbt_type), INTENT(INOUT)                      :: t3c_sub, t3c_main
    2753             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    2754             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2755             :       TYPE(cp_1d_logical_p_type), DIMENSION(:), &
    2756             :          INTENT(INOUT), OPTIONAL                         :: iatom_to_subgroup
    2757             :       INTEGER, INTENT(IN), OPTIONAL                      :: dim_at
    2758             :       INTEGER, DIMENSION(:), OPTIONAL                    :: idx_to_at
    2759             : 
    2760             :       INTEGER                                            :: batch_size, i, i_batch, i_msg, iatom, &
    2761             :                                                             iblk, igroup, iproc, ir, is, jblk, &
    2762             :                                                             kblk, n_batch, nocc, tag
    2763       11690 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes1, bsizes2, bsizes3
    2764       11690 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: block_dest, block_source
    2765       11690 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :)        :: current_dest
    2766             :       INTEGER, DIMENSION(3)                              :: ind, nblks
    2767             :       LOGICAL                                            :: filter_at, found
    2768       11690 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    2769       11690 :       TYPE(cp_3d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2770             :       TYPE(dbt_iterator_type)                            :: iter
    2771       11690 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2772             : 
    2773             :       !Strategy: we loop over the main tensor, and send all the data. Then we loop over the sub tensor
    2774             :       !         and receive it. We do all of it with async MPI communication. The sub tensor needs
    2775             :       !         to have blocks pre-reserved though
    2776             : 
    2777       11690 :       CALL dbt_get_info(t3c_main, nblks_total=nblks)
    2778             : 
    2779             :       !in some cases, only copy a fraction of the 3c tensor to a given subgroup (corresponding to some atoms)
    2780       11690 :       filter_at = .FALSE.
    2781       11690 :       IF (PRESENT(iatom_to_subgroup) .AND. PRESENT(dim_at) .AND. PRESENT(idx_to_at)) THEN
    2782        6614 :          filter_at = .TRUE.
    2783        6614 :          CPASSERT(nblks(dim_at) == SIZE(idx_to_at))
    2784             :       END IF
    2785             : 
    2786             :       !Loop over the main tensor, count how many blocks are there, which ones, and on which proc
    2787       58450 :       ALLOCATE (block_source(nblks(1), nblks(2), nblks(3)))
    2788      640562 :       block_source = -1
    2789       11690 :       nocc = 0
    2790       11690 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t3c_main,para_env,nocc,block_source) PRIVATE(iter,ind,blk,found)
    2791             :       CALL dbt_iterator_start(iter, t3c_main)
    2792             :       DO WHILE (dbt_iterator_blocks_left(iter))
    2793             :          CALL dbt_iterator_next_block(iter, ind)
    2794             :          CALL dbt_get_block(t3c_main, ind, blk, found)
    2795             :          IF (.NOT. found) CYCLE
    2796             : 
    2797             :          block_source(ind(1), ind(2), ind(3)) = para_env%mepos
    2798             : !$OMP ATOMIC
    2799             :          nocc = nocc + 1
    2800             :          DEALLOCATE (blk)
    2801             :       END DO
    2802             :       CALL dbt_iterator_stop(iter)
    2803             : !$OMP END PARALLEL
    2804             : 
    2805       11690 :       CALL para_env%sum(nocc)
    2806       11690 :       CALL para_env%sum(block_source)
    2807      640562 :       block_source = block_source + para_env%num_pe - 1 !owner idx, or -1 if absent (assumes a single owner)
    2808       11690 :       IF (nocc == 0) RETURN !nothing to copy. NOTE(review): t3c_sub is not finalized on this path
    2809             : 
    2810             :       !Loop over the sub tensor, get the block destination
    2811       11690 :       igroup = para_env%mepos/group_size
    2812       46760 :       ALLOCATE (block_dest(nblks(1), nblks(2), nblks(3)))
    2813      640562 :       block_dest = -1
    2814       35070 :       DO kblk = 1, nblks(3)
    2815      163314 :          DO jblk = 1, nblks(2)
    2816      628872 :             DO iblk = 1, nblks(1)
    2817      477248 :                IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2818             : 
    2819      491240 :                CALL dbt_get_stored_coordinates(t3c_sub, [iblk, jblk, kblk], iproc)
    2820      605492 :                block_dest(iblk, jblk, kblk) = igroup*group_size + iproc !mapping of iproc in subgroup to main group idx
    2821             :             END DO
    2822             :          END DO
    2823             :       END DO
    2824             : 
    2825       81830 :       ALLOCATE (bsizes1(nblks(1)), bsizes2(nblks(2)), bsizes3(nblks(3)))
    2826       11690 :       CALL dbt_get_info(t3c_main, blk_size_1=bsizes1, blk_size_2=bsizes2, blk_size_3=bsizes3)
    2827             : 
    2828       70140 :       ALLOCATE (current_dest(nblks(1), nblks(2), nblks(3), 0:ngroups - 1))
    2829       35070 :       DO igroup = 0, ngroups - 1
    2830             :          !for a given subgroup, need to make the destination available to everyone in the main group
    2831     1281124 :          current_dest(:, :, :, igroup) = block_dest(:, :, :)
    2832       35070 :          CALL para_env%bcast(current_dest(:, :, :, igroup), source=igroup*group_size) !bcast from first proc in subgroup
    2833             :       END DO
    2834             : 
    2835             :       !We go by batches, which cannot be larger than the maximum MPI tag value
    2836       11690 :       batch_size = MIN(para_env%get_tag_ub(), 128000, nocc*ngroups)
    2837       11690 :       n_batch = (nocc*ngroups)/batch_size
    2838       11690 :       IF (MODULO(nocc*ngroups, batch_size) .NE. 0) n_batch = n_batch + 1
    2839             : 
    2840       23380 :       DO i_batch = 1, n_batch
    2841             :          !Loop over groups, blocks and send/receive
    2842      538000 :          ALLOCATE (send_buff(batch_size), recv_buff(batch_size))
    2843      538000 :          ALLOCATE (send_req(batch_size), recv_req(batch_size))
    2844             :          ir = 0
    2845             :          is = 0
    2846             :          i_msg = 0
    2847       35070 :          DO kblk = 1, nblks(3)
    2848      163314 :             DO jblk = 1, nblks(2)
    2849      628872 :                DO iblk = 1, nblks(1)
    2850     1559988 :                   DO igroup = 0, ngroups - 1
    2851      954496 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2852             : 
    2853      245620 :                      i_msg = i_msg + 1
    2854      245620 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2855             : 
    2856             :                      !a unique tag per (block, subgroup) message, within this batch
    2857      245620 :                      tag = i_msg - (i_batch - 1)*batch_size
    2858             : 
    2859      245620 :                      IF (filter_at) THEN
    2860      693136 :                         ind(:) = [iblk, jblk, kblk]
    2861      173284 :                         iatom = idx_to_at(ind(dim_at))
    2862      173284 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE
    2863             :                      END IF
    2864             : 
    2865      158978 :                      found = .FALSE.
    2866      158978 :                      IF (para_env%mepos == block_source(iblk, jblk, kblk)) THEN
    2867      317956 :                         CALL dbt_get_block(t3c_main, [iblk, jblk, kblk], blk, found)
    2868             :                      END IF
    2869             : 
    2870             :                      !If blocks live on same proc, simply copy. Else MPI send/recv
    2871      158978 :                      IF (block_source(iblk, jblk, kblk) == current_dest(iblk, jblk, kblk, igroup)) THEN
    2872      329744 :                         IF (found) CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2873             :                      ELSE
    2874       76542 :                         IF (para_env%mepos == block_source(iblk, jblk, kblk) .AND. found) THEN
    2875      191355 :                            ALLOCATE (send_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2876   187303822 :                            send_buff(tag)%array(:, :, :) = blk(:, :, :)
    2877       38271 :                            is = is + 1
    2878             :                            CALL para_env%isend(msgin=send_buff(tag)%array, &
    2879             :                                                dest=current_dest(iblk, jblk, kblk, igroup), &
    2880       38271 :                                                request=send_req(is), tag=tag)
    2881             :                         END IF
    2882             : 
    2883       76542 :                         IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup)) THEN
    2884      191355 :                            ALLOCATE (recv_buff(tag)%array(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2885       38271 :                            ir = ir + 1
    2886             :                            CALL para_env%irecv(msgout=recv_buff(tag)%array, source=block_source(iblk, jblk, kblk), &
    2887       38271 :                                                request=recv_req(ir), tag=tag)
    2888             :                         END IF
    2889             :                      END IF
    2890             : 
    2891      636226 :                      IF (found) DEALLOCATE (blk)
    2892             :                   END DO
    2893             :                END DO
    2894             :             END DO
    2895             :          END DO
    2896             : 
    2897       11690 :          CALL mp_waitall(send_req(1:is))
    2898       11690 :          CALL mp_waitall(recv_req(1:ir))
    2899             :          !clean-up of the send buffers
    2900      257310 :          DO i = 1, batch_size
    2901      257310 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    2902             :          END DO
    2903             : 
    2904             :          !Finally copy the data from the buffer to the sub-tensor. Same loop order and filters as above,
    2905             :          i_msg = 0
    2906       35070 :          DO kblk = 1, nblks(3)
    2907      163314 :             DO jblk = 1, nblks(2)
    2908      628872 :                DO iblk = 1, nblks(1)
    2909     1559988 :                   DO igroup = 0, ngroups - 1
    2910      954496 :                      IF (block_source(iblk, jblk, kblk) == -1) CYCLE
    2911             : 
    2912      245620 :                      i_msg = i_msg + 1
    2913      245620 :                      IF (i_msg < (i_batch - 1)*batch_size + 1 .OR. i_msg > i_batch*batch_size) CYCLE
    2914             : 
    2915             :                      !such that tag addresses the same recv_buff entry as when the receive was posted
    2916      245620 :                      tag = i_msg - (i_batch - 1)*batch_size
    2917             : 
    2918      245620 :                      IF (filter_at) THEN
    2919      693136 :                         ind(:) = [iblk, jblk, kblk]
    2920      173284 :                         iatom = idx_to_at(ind(dim_at))
    2921      173284 :                         IF (.NOT. iatom_to_subgroup(iatom)%array(igroup + 1)) CYCLE
    2922             :                      END IF
    2923             : 
    2924      158978 :                      IF (para_env%mepos == current_dest(iblk, jblk, kblk, igroup) .AND. &
    2925      477248 :                          block_source(iblk, jblk, kblk) .NE. current_dest(iblk, jblk, kblk, igroup)) THEN
    2926             : 
    2927      191355 :                         ALLOCATE (blk(bsizes1(iblk), bsizes2(jblk), bsizes3(kblk)))
    2928   187303822 :                         blk(:, :, :) = recv_buff(tag)%array(:, :, :)
    2929      267897 :                         CALL dbt_put_block(t3c_sub, [iblk, jblk, kblk], SHAPE(blk), blk)
    2930       38271 :                         DEALLOCATE (blk)
    2931             :                      END IF
    2932             :                   END DO
    2933             :                END DO
    2934             :             END DO
    2935             :          END DO
    2936             : 
    2937             :          !clean-up of the receive buffers and of the per-batch arrays
    2938      257310 :          DO i = 1, batch_size
    2939      257310 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    2940             :          END DO
    2941       23380 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    2942             :       END DO !i_batch
    2943       11690 :       CALL dbt_finalize(t3c_sub)
    2944             : 
    2945       23380 :    END SUBROUTINE copy_3c_to_subgroup
    2946             : 
    2947             : ! **************************************************************************************************
    2948             : !> \brief A routine that gathers the pieces of the KS matrix across the subgroups and puts them in the
    2949             : !>        main group. Each b_img, iatom, jatom tuple is on a single CPU
    2950             : !> \param ks_t KS matrix tensors of the main group, indexed (i_spin, b_img). Receive the blocks
    2951             : !> \param ks_t_sub KS matrix tensors of the subgroup, indexed (i_spin, b_img). The blocks are read from it
    2952             : !> \param group_size number of procs per subgroup (main group idx = igroup*group_size + subgroup idx)
    2953             : !> \param sparsity_pattern (iatom, jatom, b_img) index of the subgroup holding the block, negative if none
    2954             : !> \param para_env parallel environment of the main group, used for all the communication
    2955             : !> \param ri_data provides the AO block sizes (bsizes_AO) used to allocate the message buffers
    2956             : ! **************************************************************************************************
    2957         190 :    SUBROUTINE gather_ks_matrix(ks_t, ks_t_sub, group_size, sparsity_pattern, para_env, ri_data)
    2958             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t, ks_t_sub
    2959             :       INTEGER, INTENT(IN)                                :: group_size
    2960             :       INTEGER, DIMENSION(:, :, :), INTENT(IN)            :: sparsity_pattern
    2961             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    2962             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    2963             : 
    2964             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'gather_ks_matrix'
    2965             : 
    2966             :       INTEGER                                            :: b_img, dest, handle, i, i_spin, iatom, &
    2967             :                                                             igroup, ir, is, jatom, n_mess, natom, &
    2968             :                                                             nimg, nspins, source, tag
    2969             :       LOGICAL                                            :: found
    2970         190 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    2971         190 :       TYPE(cp_2d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: recv_buff, send_buff
    2972         190 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: recv_req, send_req
    2973             : 
    2974         190 :       CALL timeset(routineN, handle)
    2975             : 
    2976         190 :       nimg = SIZE(sparsity_pattern, 3)
    2977         190 :       natom = SIZE(sparsity_pattern, 2)
    2978         190 :       nspins = SIZE(ks_t, 1)
    2979             : 
    2980        4994 :       DO b_img = 1, nimg
    2981             :          n_mess = 0
    2982       11132 :          DO i_spin = 1, nspins
    2983       23788 :             DO jatom = 1, natom
    2984       44296 :                DO iatom = 1, natom
    2985       37968 :                   IF (sparsity_pattern(iatom, jatom, b_img) > -1) n_mess = n_mess + 1
    2986             :                END DO
    2987             :             END DO
    2988             :          END DO
    2989             : 
    2990       37136 :          ALLOCATE (send_buff(n_mess), recv_buff(n_mess))
    2991       41940 :          ALLOCATE (send_req(n_mess), recv_req(n_mess))
    2992        4804 :          ir = 0
    2993        4804 :          is = 0
    2994        4804 :          n_mess = 0
    2995        4804 :          tag = 0
    2996             : 
    2997       11132 :          DO i_spin = 1, nspins
    2998       23788 :             DO jatom = 1, natom
    2999       44296 :                DO iatom = 1, natom
    3000       25312 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    3001        8776 :                   n_mess = n_mess + 1
    3002        8776 :                   tag = tag + 1
    3003             : 
    3004             :                   !sending the message: from the block owner in the subgroup to the block owner in the main group
    3005       26328 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3006       26328 :                   CALL dbt_get_stored_coordinates(ks_t_sub(i_spin, b_img), [iatom, jatom], source) !source within sub
    3007        8776 :                   igroup = sparsity_pattern(iatom, jatom, b_img)
    3008        8776 :                   source = source + igroup*group_size
    3009        8776 :                   IF (para_env%mepos == source) THEN
    3010       13164 :                      CALL dbt_get_block(ks_t_sub(i_spin, b_img), [iatom, jatom], blk, found)
    3011        4388 :                      IF (source == dest) THEN
    3012        3547 :                         IF (found) CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3013             :                      ELSE
    3014       13764 :                         ALLOCATE (send_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3015      282773 :                         send_buff(n_mess)%array(:, :) = 0.0_dp
    3016        3441 :                         IF (found) THEN
    3017      194070 :                            send_buff(n_mess)%array(:, :) = blk(:, :)
    3018             :                         END IF
    3019        3441 :                         is = is + 1
    3020             :                         CALL para_env%isend(msgin=send_buff(n_mess)%array, dest=dest, &
    3021        3441 :                                             request=send_req(is), tag=tag)
    3022             :                      END IF
    3023        4388 :                      DEALLOCATE (blk)
    3024             :                   END IF
    3025             : 
    3026             :                   !receiving the message, only needed if source and destination procs differ
    3027       21432 :                   IF (para_env%mepos == dest .AND. source .NE. dest) THEN
    3028       13764 :                      ALLOCATE (recv_buff(n_mess)%array(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3029        3441 :                      ir = ir + 1
    3030             :                      CALL para_env%irecv(msgout=recv_buff(n_mess)%array, source=source, &
    3031        3441 :                                          request=recv_req(ir), tag=tag)
    3032             :                   END IF
    3033             :                END DO !iatom
    3034             :             END DO !jatom
    3035             :          END DO !ispin
    3036             : 
    3037        4804 :          CALL mp_waitall(send_req(1:is))
    3038        4804 :          CALL mp_waitall(recv_req(1:ir))
    3039             : 
    3040             :          !Copy the messages received into the KS matrix (same loop order, so n_mess matches the buffer idx)
    3041        4804 :          n_mess = 0
    3042       11132 :          DO i_spin = 1, nspins
    3043       23788 :             DO jatom = 1, natom
    3044       44296 :                DO iatom = 1, natom
    3045       25312 :                   IF (sparsity_pattern(iatom, jatom, b_img) < 0) CYCLE
    3046        8776 :                   n_mess = n_mess + 1
    3047             : 
    3048       26328 :                   CALL dbt_get_stored_coordinates(ks_t(i_spin, b_img), [iatom, jatom], dest)
    3049       21432 :                   IF (para_env%mepos == dest) THEN
    3050        4388 :                      IF (.NOT. ASSOCIATED(recv_buff(n_mess)%array)) CYCLE
    3051       13764 :                      ALLOCATE (blk(ri_data%bsizes_AO(iatom), ri_data%bsizes_AO(jatom)))
    3052      282773 :                      blk(:, :) = recv_buff(n_mess)%array(:, :)
    3053       17205 :                      CALL dbt_put_block(ks_t(i_spin, b_img), [iatom, jatom], SHAPE(blk), blk)
    3054        3441 :                      DEALLOCATE (blk)
    3055             :                   END IF
    3056             :                END DO
    3057             :             END DO
    3058             :          END DO
    3059             : 
    3060             :          !clean-up of the buffers of this image
    3061       13580 :          DO i = 1, n_mess
    3062        8776 :             IF (ASSOCIATED(send_buff(i)%array)) DEALLOCATE (send_buff(i)%array)
    3063       13580 :             IF (ASSOCIATED(recv_buff(i)%array)) DEALLOCATE (recv_buff(i)%array)
    3064             :          END DO
    3065        4994 :          DEALLOCATE (send_buff, recv_buff, send_req, recv_req)
    3066             :       END DO !b_img
    3067             : 
    3068         190 :       CALL timestop(handle)
    3069             : 
    3070         190 :    END SUBROUTINE gather_ks_matrix
    3071             : 
    3072             : ! **************************************************************************************************
    3073             : !> \brief create all required 2c tensors on the subgroups and copy the HFX potential from the main MPI group
    3074             : !> \param mat_2c_pot subgroup copies of ri_data%kp_mat_2c_pot(1, :), one matrix per image, created and filtered here
    3075             : !> \param t_2c_work subgroup (RI | RI) work tensors created here: (1) atomic RI blocks, (2) split RI blocks
    3076             : !> \param t_2c_ao_tmp subgroup (AO | AO) tensor with atomic AO blocks; only element 1 is created here
    3077             : !> \param ks_t_split subgroup (AO | AO) tensors with split AO blocks; elements 1 and 2 are created here
    3078             : !> \param ks_t_sub empty subgroup (AO | AO) tensors, one per (spin, image), created from t_2c_ao_tmp(1)
    3079             : !> \param group_size forwarded to copy_2c_to_subgroup (presumably the subgroup size; confirm with caller)
    3080             : !> \param ngroups forwarded to copy_2c_to_subgroup (presumably the number of subgroups; confirm with caller)
    3081             : !> \param para_env forwarded to copy_2c_to_subgroup (presumably the main MPI group; confirm with caller)
    3082             : !> \param para_env_sub subgroup parallel environment, used for the 2d process grid and the DBCSR distribution
    3083             : !> \param ri_data RI data; provides the block sizes, ncell_RI, kp_mat_2c_pot and filter_eps read here
    3084             : ! **************************************************************************************************
    3085         190 :    SUBROUTINE get_subgroup_2c_tensors(mat_2c_pot, t_2c_work, t_2c_ao_tmp, ks_t_split, ks_t_sub, &
    3086             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3087             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3088             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work, t_2c_ao_tmp, ks_t_split
    3089             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: ks_t_sub
    3090             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3091             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3092             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3093             : 
    3094             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_tensors'
    3095             : 
    3096             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, iproc, &
    3097             :                                                             j, natom, nblks, nimg, nspins
    3098             :       INTEGER(int_8)                                     :: nze
    3099             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3100         190 :                                                             dist1, dist2
    3101             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3102         380 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3103         190 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3104             :       REAL(dp)                                           :: occ
    3105             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3106         570 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3107        2470 :       TYPE(dbt_type)                                     :: work, work_sub
    3108             : 
    3109         190 :       CALL timeset(routineN, handle)
    3110             : 
    3111             :       !Create the 2d process grid on the subgroup (pdims_2d is zeroed here and its extents are used further down)
    3112         190 :       pdims_2d = 0
    3113         190 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3114             : 
    3115         190 :       natom = SIZE(ri_data%bsizes_RI)
    3116         190 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3117         570 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3118         570 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3119        1334 :       DO i_RI = 1, ri_data%ncell_RI
    3120        3432 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3121        6064 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3122             :       END DO
    3123             : 
    3124             :       !nRI x nRI 2c tensors, with the RI block sizes repeated ncell_RI times: (1) atomic blocks, (2) split blocks
    3125             :       CALL create_2c_tensor(t_2c_work(1), dist1, dist2, pgrid_2d, &
    3126             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3127             :                             name="(RI | RI)")
    3128         190 :       DEALLOCATE (dist1, dist2)
    3129             : 
    3130             :       CALL create_2c_tensor(t_2c_work(2), dist1, dist2, pgrid_2d, &
    3131             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3132         190 :                             name="(RI | RI)")
    3133         190 :       DEALLOCATE (dist1, dist2)
    3134             : 
    3135             :       !the AO based tensors: ks_t_split with split AO blocks, t_2c_ao_tmp and ks_t_sub with atomic AO blocks
    3136             :       CALL create_2c_tensor(ks_t_split(1), dist1, dist2, pgrid_2d, &
    3137             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3138             :                             name="(AO | AO)")
    3139         190 :       DEALLOCATE (dist1, dist2)
    3140         190 :       CALL dbt_create(ks_t_split(1), ks_t_split(2))
    3141             : 
    3142             :       CALL create_2c_tensor(t_2c_ao_tmp(1), dist1, dist2, pgrid_2d, &
    3143             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, &
    3144             :                             name="(AO | AO)")
    3145         190 :       DEALLOCATE (dist1, dist2)
    3146             : 
    3147         190 :       nspins = SIZE(ks_t_sub, 1)
    3148         190 :       nimg = SIZE(ks_t_sub, 2)
    3149        4994 :       DO i_img = 1, nimg
    3150       11322 :          DO i_spin = 1, nspins
    3151       11132 :             CALL dbt_create(t_2c_ao_tmp(1), ks_t_sub(i_spin, i_img))
    3152             :          END DO
    3153             :       END DO
    3154             : 
    3155             :       !Finally the HFX potential matrices (work_sub lives on the subgroup grid, work mirrors kp_mat_2c_pot(1, 1))
    3156             :       !For now, we do a convoluted thing where we go to tensors first, then back to matrices.
    3157             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3158             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3159             :                             name="(RI | RI)")
    3160         190 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3161             : 
    3162         760 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3163         190 :       iproc = 0
    3164         380 :       DO i = 0, pdims_2d(1) - 1
    3165         570 :          DO j = 0, pdims_2d(2) - 1
    3166         190 :             dbcsr_pgrid(i, j) = iproc
    3167         380 :             iproc = iproc + 1
    3168             :          END DO
    3169             :       END DO
    3170             : 
    3171             :       !We need to have the same exact 2d block dist as the tensors (dist1/dist2 are those returned for work_sub)
    3172         760 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3173         570 :       row_dist(:) = dist1(:)
    3174         570 :       col_dist(:) = dist2(:)
    3175             : 
    3176         380 :       ALLOCATE (RI_blk_size(natom))
    3177         570 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3178             : 
    3179             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3180         190 :                                   row_dist=row_dist, col_dist=col_dist)
    3181             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3182         190 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3183             : 
    3184        4994 :       DO i_img = 1, nimg
    3185        4804 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3186        4804 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3187        4804 :          CALL get_tensor_occupancy(work, nze, occ)
    3188        4804 :          IF (nze == 0) CYCLE
    3189             : 
    3190        3708 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3191        3708 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3192        3708 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3193        8702 :          CALL dbt_clear(work_sub)
    3194             :       END DO
    3195             : 
    3196         190 :       CALL dbt_destroy(work)
    3197         190 :       CALL dbt_destroy(work_sub)
    3198         190 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3199         190 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3200         190 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3201         190 :       CALL timestop(handle)
    3202             : 
    3203        1710 :    END SUBROUTINE get_subgroup_2c_tensors
    3204             : 
    3205             : ! **************************************************************************************************
    3206             : !> \brief create all required 3c tensors on the subgroups and copy their data from the main MPI group
    3207             : !> \param t_3c_int subgroup (RI | AO AO) integral tensors, one per image (nimg = SIZE(t_3c_int)), created and filled here
    3208             : !> \param t_3c_work_2 stacked (AO RI | AO) subgroup work tensors; elements 1 to 3 are created here
    3209             : !> \param t_3c_work_3 stacked (RI | AO AO) subgroup work tensors; elements 1 to 3 are created here
    3210             : !> \param t_3c_apc tensors indexed (spin, image); their data is moved to t_3c_apc_sub and they are destroyed here
    3211             : !> \param t_3c_apc_sub subgroup counterparts of t_3c_apc, created here from a split-block (AO RI | AO) template
    3212             : !> \param group_size forwarded to copy_3c_to_subgroup (presumably the subgroup size; confirm with caller)
    3213             : !> \param ngroups forwarded to copy_3c_to_subgroup (presumably the number of subgroups; confirm with caller)
    3214             : !> \param para_env forwarded to copy_3c_to_subgroup (presumably the main MPI group; confirm with caller)
    3215             : !> \param para_env_sub subgroup parallel environment, used to create the 3d process grids
    3216             : !> \param ri_data RI data; kp_t_3c_int is allocated and created here if it is not allocated on entry
    3217             : ! **************************************************************************************************
    3218         190 :    SUBROUTINE get_subgroup_3c_tensors(t_3c_int, t_3c_work_2, t_3c_work_3, t_3c_apc, t_3c_apc_sub, &
    3219             :                                       group_size, ngroups, para_env, para_env_sub, ri_data)
    3220             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_int, t_3c_work_2, t_3c_work_3
    3221             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_apc, t_3c_apc_sub
    3222             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3223             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3224             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3225             : 
    3226             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_tensors'
    3227             : 
    3228             :       INTEGER                                            :: batch_size, bfac, bo(2), handle, &
    3229             :                                                             handle2, i_blk, i_img, i_RI, i_spin, &
    3230             :                                                             ib, natom, nblks_AO, nblks_RI, nimg, &
    3231             :                                                             nspins
    3232             :       INTEGER(int_8)                                     :: nze
    3233         190 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3234         190 :                                                             bsizes_stack, bsizes_tmp, dist1, &
    3235         190 :                                                             dist2, dist3, dist_stack, idx_to_at
    3236             :       INTEGER, DIMENSION(3)                              :: pdims
    3237             :       REAL(dp)                                           :: occ
    3238        1710 :       TYPE(dbt_distribution_type)                        :: t_dist
    3239         570 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3240        4750 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3241             : 
    3242         190 :       CALL timeset(routineN, handle)
    3243             : 
    3244         190 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    3245         570 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3246        1334 :       DO i_RI = 1, ri_data%ncell_RI
    3247        6064 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3248             :       END DO
    3249             : 
    3250             :       !Preparing larger block sizes for efficient communication (fewer, bigger messages)
    3251             :       !we put 2 atoms per RI block (nblks_RI is redefined below as the number of such merged blocks, at least 1)
    3252         190 :       bfac = 2
    3253         190 :       natom = SIZE(ri_data%bsizes_RI)
    3254         190 :       nblks_RI = MAX(1, natom/bfac)
    3255         570 :       ALLOCATE (bsizes_tmp(nblks_RI))
    3256         380 :       DO i_blk = 1, nblks_RI
    3257         190 :          bo = get_limit(natom, nblks_RI, i_blk - 1)
    3258         760 :          bsizes_tmp(i_blk) = SUM(ri_data%bsizes_RI(bo(1):bo(2)))
    3259             :       END DO
    3260         570 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3261        1334 :       DO i_RI = 1, ri_data%ncell_RI
    3262        2478 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = bsizes_tmp(:)
    3263             :       END DO
    3264             : 
    3265         190 :       batch_size = ri_data%kp_stack_size
    3266         190 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3267         570 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3268        6270 :       DO ib = 1, batch_size
    3269       29886 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3270             :       END DO
    3271             : 
    3272             :       !Create the pgrid for the configuration corresponding to ri_data%t_3c_int_ctr_3
    3273         190 :       natom = SIZE(ri_data%bsizes_RI)
    3274         190 :       pdims = 0
    3275             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3276         760 :                             tensor_dims=[SIZE(bsizes_RI_ext_split), 1, batch_size*SIZE(ri_data%bsizes_AO_split)])
    3277             : 
    3278             :       !Create all required 3c tensors in that configuration (t_3c_int(1) is the template for the other images)
    3279             :       CALL create_3c_tensor(t_3c_int(1), dist1, dist2, dist3, &
    3280             :                             pgrid, bsizes_RI_ext_split, ri_data%bsizes_AO_split, &
    3281         190 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(RI | AO AO)")
    3282         190 :       nimg = SIZE(t_3c_int)
    3283        4804 :       DO i_img = 2, nimg
    3284        4804 :          CALL dbt_create(t_3c_int(1), t_3c_int(i_img))
    3285             :       END DO
    3286             : 
    3287             :       !The stacked work tensors, in a distribution that matches that of t_3c_int (dist3 repeated batch_size times)
    3288         380 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3289        6270 :       DO ib = 1, batch_size
    3290       29886 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3291             :       END DO
    3292             : 
    3293         190 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3294             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3295         190 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3296         190 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3297         190 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3298         190 :       CALL dbt_distribution_destroy(t_dist)
    3299         190 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3300             : 
    3301             :       !For more efficient communication, we use intermediate tensors with larger block size
    3302             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3303             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3304         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3305         190 :       DEALLOCATE (dist1, dist2, dist3)
    3306             : 
    3307             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3308             :                             ri_data%pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3309         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3310         190 :       DEALLOCATE (dist1, dist2, dist3)
    3311             : 
    3312             :       !Finally fill t_3c_int: move the data out of ri_data%kp_t_3c_int if allocated, else copy from t_3c_int_ctr_1
    3313         190 :       CALL timeset(routineN//"_ints", handle2)
    3314         190 :       IF (ALLOCATED(ri_data%kp_t_3c_int)) THEN
    3315        3208 :          DO i_img = 1, nimg
    3316        3208 :             CALL dbt_copy(ri_data%kp_t_3c_int(i_img), t_3c_int(i_img), move_data=.TRUE.)
    3317             :          END DO
    3318             :       ELSE
    3319        2486 :          ALLOCATE (ri_data%kp_t_3c_int(nimg))
    3320        1786 :          DO i_img = 1, nimg
    3321        1716 :             CALL dbt_create(t_3c_int(i_img), ri_data%kp_t_3c_int(i_img))
    3322        1716 :             CALL get_tensor_occupancy(ri_data%t_3c_int_ctr_1(1, i_img), nze, occ)
    3323        1716 :             IF (nze == 0) CYCLE
    3324        1236 :             CALL dbt_copy(ri_data%t_3c_int_ctr_1(1, i_img), work_atom_block, order=[2, 1, 3])
    3325        1236 :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, ngroups, para_env)
    3326        3022 :             CALL dbt_copy(work_atom_block_sub, t_3c_int(i_img), move_data=.TRUE.)
    3327             :          END DO
    3328             :       END IF
    3329         190 :       CALL timestop(handle2)
    3330         190 :       CALL dbt_pgrid_destroy(pgrid)
    3331         190 :       CALL dbt_destroy(work_atom_block)
    3332         190 :       CALL dbt_destroy(work_atom_block_sub)
    3333             : 
    3334             :       !Do the same for the t_3c_ctr_2 configuration
    3335         190 :       pdims = 0
    3336             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3337         760 :                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), batch_size*SIZE(ri_data%bsizes_AO_split)])
    3338             : 
    3339             :       !For more efficient communication, we use intermediate tensors with larger block size
    3340             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3341             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3342         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3343         190 :       DEALLOCATE (dist1, dist2, dist3)
    3344             : 
    3345             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3346             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3347         190 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3348         190 :       DEALLOCATE (dist1, dist2, dist3)
    3349             : 
    3350             :       !template for t_3c_apc_sub (its dist1/dist2/dist3 are reused just below for the stacked work tensors)
    3351             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3352             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3353         190 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3354             : 
    3355             :       !create t_3c_work_2 tensors in a distribution that matches the above
    3356         380 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3357        6270 :       DO ib = 1, batch_size
    3358       29886 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3359             :       END DO
    3360             : 
    3361         190 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3362             :       CALL dbt_create(t_3c_work_2(1), "work_2_stack", t_dist, [1], [2, 3], &
    3363         190 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3364         190 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3365         190 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3366         190 :       CALL dbt_distribution_destroy(t_dist)
    3367         190 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3368             : 
    3369             :       !Finally copy data from t_3c_apc to the subgroups; every t_3c_apc tensor is destroyed, even when it was empty
    3370         570 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3371         190 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3372         190 :       nspins = SIZE(t_3c_apc, 1)
    3373         190 :       CALL timeset(routineN//"_apc", handle2)
    3374        4994 :       DO i_img = 1, nimg
    3375       11132 :          DO i_spin = 1, nspins
    3376        6328 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3377        6328 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3378        6328 :             IF (nze == 0) CYCLE
    3379        5462 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3380             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3381        5462 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3382       16594 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3383             :          END DO
    3384       11322 :          DO i_spin = 1, nspins
    3385       11132 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3386             :          END DO
    3387             :       END DO
    3388         190 :       CALL timestop(handle2)
    3389         190 :       CALL dbt_pgrid_destroy(pgrid)
    3390         190 :       CALL dbt_destroy(tmp)
    3391         190 :       CALL dbt_destroy(work_atom_block)
    3392         190 :       CALL dbt_destroy(work_atom_block_sub)
    3393             : 
    3394         190 :       CALL timestop(handle)
    3395             : 
    3396         760 :    END SUBROUTINE get_subgroup_3c_tensors
    3397             : 
    3398             : ! **************************************************************************************************
    3399             : !> \brief copy all required 2c force tensors from the main MPI group to the subgroups
    3400             : !> \param t_2c_inv ...
    3401             : !> \param t_2c_bint ...
    3402             : !> \param t_2c_metric ...
    3403             : !> \param mat_2c_pot ...
    3404             : !> \param t_2c_work ...
    3405             : !> \param rho_ao_t ...
    3406             : !> \param rho_ao_t_sub ...
    3407             : !> \param t_2c_der_metric ...
    3408             : !> \param t_2c_der_metric_sub ...
    3409             : !> \param mat_der_pot ...
    3410             : !> \param mat_der_pot_sub ...
    3411             : !> \param group_size ...
    3412             : !> \param ngroups ...
    3413             : !> \param para_env ...
    3414             : !> \param para_env_sub ...
    3415             : !> \param ri_data ...
    3416             : !> \note Main MPI group tensors are deleted within this routine, for memory optimization
    3417             : ! **************************************************************************************************
    3418          84 :    SUBROUTINE get_subgroup_2c_derivs(t_2c_inv, t_2c_bint, t_2c_metric, mat_2c_pot, t_2c_work, rho_ao_t, &
    3419          42 :                                      rho_ao_t_sub, t_2c_der_metric, t_2c_der_metric_sub, mat_der_pot, &
    3420          42 :                                      mat_der_pot_sub, group_size, ngroups, para_env, para_env_sub, ri_data)
    3421             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_inv, t_2c_bint, t_2c_metric
    3422             :       TYPE(dbcsr_type), DIMENSION(:), INTENT(INOUT)      :: mat_2c_pot
    3423             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_work
    3424             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: rho_ao_t, rho_ao_t_sub, t_2c_der_metric, &
    3425             :                                                             t_2c_der_metric_sub
    3426             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot, mat_der_pot_sub
    3427             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3428             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3429             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3430             : 
    3431             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_2c_derivs'
    3432             : 
    3433             :       INTEGER                                            :: handle, i, i_img, i_RI, i_spin, i_xyz, &
    3434             :                                                             iatom, iproc, j, natom, nblks, nimg, &
    3435             :                                                             nspins
    3436             :       INTEGER(int_8)                                     :: nze
    3437             :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3438          42 :                                                             dist1, dist2
    3439             :       INTEGER, DIMENSION(2)                              :: pdims_2d
    3440          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_dist, RI_blk_size, row_dist
    3441          42 :       INTEGER, DIMENSION(:, :), POINTER                  :: dbcsr_pgrid
    3442             :       REAL(dp)                                           :: occ
    3443             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist_sub
    3444         126 :       TYPE(dbt_pgrid_type)                               :: pgrid_2d
    3445         546 :       TYPE(dbt_type)                                     :: work, work_sub
    3446             : 
    3447          42 :       CALL timeset(routineN, handle)
    3448             : 
    3449             :       !Note: a fair portion of this routine is copied from the energy version of it
    3450             :       !Create the 2d pgrid
    3451          42 :       pdims_2d = 0
    3452          42 :       CALL dbt_pgrid_create(para_env_sub, pdims_2d, pgrid_2d)
    3453             : 
    3454          42 :       natom = SIZE(ri_data%bsizes_RI)
    3455          42 :       nblks = SIZE(ri_data%bsizes_RI_split)
    3456         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*natom))
    3457         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks))
    3458         294 :       DO i_RI = 1, ri_data%ncell_RI
    3459         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    3460        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks + 1:i_RI*nblks) = ri_data%bsizes_RI_split(:)
    3461             :       END DO
    3462             : 
    3463             :       !nRI x nRI 2c tensors
    3464             :       CALL create_2c_tensor(t_2c_inv(1), dist1, dist2, pgrid_2d, &
    3465             :                             bsizes_RI_ext, bsizes_RI_ext, &
    3466             :                             name="(RI | RI)")
    3467          42 :       DEALLOCATE (dist1, dist2)
    3468             : 
    3469          42 :       CALL dbt_create(t_2c_inv(1), t_2c_bint(1))
    3470          42 :       CALL dbt_create(t_2c_inv(1), t_2c_metric(1))
    3471          84 :       DO iatom = 2, natom
    3472          42 :          CALL dbt_create(t_2c_inv(1), t_2c_inv(iatom))
    3473          42 :          CALL dbt_create(t_2c_inv(1), t_2c_bint(iatom))
    3474          84 :          CALL dbt_create(t_2c_inv(1), t_2c_metric(iatom))
    3475             :       END DO
    3476          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(1))
    3477          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(2))
    3478          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(3))
    3479          42 :       CALL dbt_create(t_2c_inv(1), t_2c_work(4))
    3480             : 
    3481             :       CALL create_2c_tensor(t_2c_work(5), dist1, dist2, pgrid_2d, &
    3482             :                             bsizes_RI_ext_split, bsizes_RI_ext_split, &
    3483          42 :                             name="(RI | RI)")
    3484          42 :       DEALLOCATE (dist1, dist2)
    3485             : 
    3486             :       !copy the data from the main group.
    3487         126 :       DO iatom = 1, natom
    3488          84 :          CALL copy_2c_to_subgroup(t_2c_inv(iatom), ri_data%t_2c_inv(1, iatom), group_size, ngroups, para_env)
    3489          84 :          CALL copy_2c_to_subgroup(t_2c_bint(iatom), ri_data%t_2c_int(1, iatom), group_size, ngroups, para_env)
    3490         126 :          CALL copy_2c_to_subgroup(t_2c_metric(iatom), ri_data%t_2c_pot(1, iatom), group_size, ngroups, para_env)
    3491             :       END DO
    3492             : 
    3493             :       !This includes the derivatives of the RI metric, for which there is one per atom
    3494         168 :       DO i_xyz = 1, 3
    3495         420 :          DO iatom = 1, natom
    3496         252 :             CALL dbt_create(t_2c_inv(1), t_2c_der_metric_sub(iatom, i_xyz))
    3497             :             CALL copy_2c_to_subgroup(t_2c_der_metric_sub(iatom, i_xyz), t_2c_der_metric(iatom, i_xyz), &
    3498         252 :                                      group_size, ngroups, para_env)
    3499         378 :             CALL dbt_destroy(t_2c_der_metric(iatom, i_xyz))
    3500             :          END DO
    3501             :       END DO
    3502             : 
    3503             :       !AO x AO 2c tensors
    3504             :       CALL create_2c_tensor(rho_ao_t_sub(1, 1), dist1, dist2, pgrid_2d, &
    3505             :                             ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    3506             :                             name="(AO | AO)")
    3507          42 :       DEALLOCATE (dist1, dist2)
    3508          42 :       nspins = SIZE(rho_ao_t, 1)
    3509          42 :       nimg = SIZE(rho_ao_t, 2)
    3510             : 
    3511        1052 :       DO i_img = 1, nimg
    3512        2228 :          DO i_spin = 1, nspins
    3513        1176 :             IF (.NOT. (i_img == 1 .AND. i_spin == 1)) &
    3514        1134 :                CALL dbt_create(rho_ao_t_sub(1, 1), rho_ao_t_sub(i_spin, i_img))
    3515             :             CALL copy_2c_to_subgroup(rho_ao_t_sub(i_spin, i_img), rho_ao_t(i_spin, i_img), &
    3516        1176 :                                      group_size, ngroups, para_env)
    3517        2186 :             CALL dbt_destroy(rho_ao_t(i_spin, i_img))
    3518             :          END DO
    3519             :       END DO
    3520             : 
    3521             :       !The RIxRI matrices, going through tensors
    3522             :       CALL create_2c_tensor(work_sub, dist1, dist2, pgrid_2d, &
    3523             :                             ri_data%bsizes_RI, ri_data%bsizes_RI, &
    3524             :                             name="(RI | RI)")
    3525          42 :       CALL dbt_create(ri_data%kp_mat_2c_pot(1, 1), work)
    3526             : 
    3527         168 :       ALLOCATE (dbcsr_pgrid(0:pdims_2d(1) - 1, 0:pdims_2d(2) - 1))
    3528          42 :       iproc = 0
    3529          84 :       DO i = 0, pdims_2d(1) - 1
    3530         126 :          DO j = 0, pdims_2d(2) - 1
    3531          42 :             dbcsr_pgrid(i, j) = iproc
    3532          84 :             iproc = iproc + 1
    3533             :          END DO
    3534             :       END DO
    3535             : 
    3536             :       !We need to have the same exact 2d block dist as the tensors
    3537         168 :       ALLOCATE (col_dist(natom), row_dist(natom))
    3538         126 :       row_dist(:) = dist1(:)
    3539         126 :       col_dist(:) = dist2(:)
    3540             : 
    3541          84 :       ALLOCATE (RI_blk_size(natom))
    3542         126 :       RI_blk_size(:) = ri_data%bsizes_RI(:)
    3543             : 
    3544             :       CALL dbcsr_distribution_new(dbcsr_dist_sub, group=para_env_sub%get_handle(), pgrid=dbcsr_pgrid, &
    3545          42 :                                   row_dist=row_dist, col_dist=col_dist)
    3546             :       CALL dbcsr_create(mat_2c_pot(1), dist=dbcsr_dist_sub, name="sub", matrix_type=dbcsr_type_no_symmetry, &
    3547          42 :                         row_blk_size=RI_blk_size, col_blk_size=RI_blk_size)
    3548             : 
    3549             :       !The HFX potential
    3550        1052 :       DO i_img = 1, nimg
    3551        1010 :          IF (i_img > 1) CALL dbcsr_create(mat_2c_pot(i_img), template=mat_2c_pot(1))
    3552        1010 :          CALL dbt_copy_matrix_to_tensor(ri_data%kp_mat_2c_pot(1, i_img), work)
    3553        1010 :          CALL get_tensor_occupancy(work, nze, occ)
    3554        1010 :          IF (nze == 0) CYCLE
    3555             : 
    3556         654 :          CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3557         654 :          CALL dbt_copy_tensor_to_matrix(work_sub, mat_2c_pot(i_img))
    3558         654 :          CALL dbcsr_filter(mat_2c_pot(i_img), ri_data%filter_eps)
    3559        1706 :          CALL dbt_clear(work_sub)
    3560             :       END DO
    3561             : 
    3562             :       !The derivatives of the HFX potential
    3563         168 :       DO i_xyz = 1, 3
    3564        3198 :          DO i_img = 1, nimg
    3565        3030 :             CALL dbcsr_create(mat_der_pot_sub(i_img, i_xyz), template=mat_2c_pot(1))
    3566        3030 :             CALL dbt_copy_matrix_to_tensor(mat_der_pot(i_img, i_xyz), work)
    3567        3030 :             CALL dbcsr_release(mat_der_pot(i_img, i_xyz))
    3568        3030 :             CALL get_tensor_occupancy(work, nze, occ)
    3569        3030 :             IF (nze == 0) CYCLE
    3570             : 
    3571        1958 :             CALL copy_2c_to_subgroup(work_sub, work, group_size, ngroups, para_env)
    3572        1958 :             CALL dbt_copy_tensor_to_matrix(work_sub, mat_der_pot_sub(i_img, i_xyz))
    3573        1958 :             CALL dbcsr_filter(mat_der_pot_sub(i_img, i_xyz), ri_data%filter_eps)
    3574        5114 :             CALL dbt_clear(work_sub)
    3575             :          END DO
    3576             :       END DO
    3577             : 
    3578          42 :       CALL dbt_destroy(work)
    3579          42 :       CALL dbt_destroy(work_sub)
    3580          42 :       CALL dbt_pgrid_destroy(pgrid_2d)
    3581          42 :       CALL dbcsr_distribution_release(dbcsr_dist_sub)
    3582          42 :       DEALLOCATE (col_dist, row_dist, RI_blk_size, dbcsr_pgrid)
    3583             : 
    3584          42 :       CALL timestop(handle)
    3585             : 
    3586         336 :    END SUBROUTINE get_subgroup_2c_derivs
    3587             : 
    3588             : ! **************************************************************************************************
    3589             : !> \brief copy the 3c derivative and t_3c_apc tensors from the main MPI group to the subgroups
    3590             : !> \param t_3c_work_2 work tensors (1:3) created here with the stacked structure of t_3c_apc_sub
    3591             : !> \param t_3c_work_3 work tensors (1:4) created here with the stacked structure of the 3c ints
    3592             : !> \param t_3c_der_AO main group derivative tensors (i_img, i_xyz), destroyed on exit
    3593             : !> \param t_3c_der_AO_sub subgroup copies of t_3c_der_AO, created here
    3594             : !> \param t_3c_der_RI main group derivative tensors (i_img, i_xyz), destroyed on exit
    3595             : !> \param t_3c_der_RI_sub subgroup copies of t_3c_der_RI, created here
    3596             : !> \param t_3c_apc main group tensors (i_spin, i_img), destroyed on exit
    3597             : !> \param t_3c_apc_sub subgroup copies of t_3c_apc, created here
    3598             : !> \param t_3c_der_stack tensors (1:6) created here with the same structure as t_3c_work_3
    3599             : !> \param group_size forwarded to copy_3c_to_subgroup
    3600             : !> \param ngroups forwarded to copy_3c_to_subgroup
    3601             : !> \param para_env forwarded to copy_3c_to_subgroup (presumably the main group environment)
    3602             : !> \param para_env_sub environment on which the subgroup process grids are created
    3603             : !> \param ri_data provides block sizes, image counts, stack size and template tensors
    3604             : !> \note the main MPI group tensors (derivatives and t_3c_apc) are destroyed here to save memory
    3605             : ! **************************************************************************************************
    3606          42 :    SUBROUTINE get_subgroup_3c_derivs(t_3c_work_2, t_3c_work_3, t_3c_der_AO, t_3c_der_AO_sub, &
    3607          42 :                                      t_3c_der_RI, t_3c_der_RI_sub, t_3c_apc, t_3c_apc_sub, &
    3608          42 :                                      t_3c_der_stack, group_size, ngroups, para_env, para_env_sub, &
    3609             :                                      ri_data)
    3610             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_work_2, t_3c_work_3
    3611             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_AO, t_3c_der_AO_sub, &
    3612             :                                                             t_3c_der_RI, t_3c_der_RI_sub, &
    3613             :                                                             t_3c_apc, t_3c_apc_sub
    3614             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_der_stack
    3615             :       INTEGER, INTENT(IN)                                :: group_size, ngroups
    3616             :       TYPE(mp_para_env_type), POINTER                    :: para_env, para_env_sub
    3617             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3618             : 
    3619             :       CHARACTER(len=*), PARAMETER :: routineN = 'get_subgroup_3c_derivs'
    3620             : 
    3621             :       INTEGER                                            :: batch_size, handle, i_img, i_RI, i_spin, &
    3622             :                                                             i_xyz, ib, nblks_AO, nblks_RI, nimg, &
    3623             :                                                             nspins, pdims(3)
    3624             :       INTEGER(int_8)                                     :: nze
    3625          42 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_RI_ext_split, &
    3626          42 :                                                             bsizes_stack, dist1, dist2, dist3, &
    3627          42 :                                                             dist_stack, idx_to_at
    3628             :       REAL(dp)                                           :: occ
    3629         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    3630         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    3631        1050 :       TYPE(dbt_type)                                     :: tmp, work_atom_block, work_atom_block_sub
    3632             : 
    3633          42 :       CALL timeset(routineN, handle)
    3634             : 
    3635             :       !We use intermediate tensors with larger block size for more optimized communication
    3636          42 :       nblks_RI = SIZE(ri_data%bsizes_RI)
    3637         126 :       ALLOCATE (bsizes_RI_ext(ri_data%ncell_RI*nblks_RI))
    3638         294 :       DO i_RI = 1, ri_data%ncell_RI
    3639         798 :          bsizes_RI_ext((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI(:)
    3640             :       END DO
    3641             : 
    3642          42 :       CALL dbt_get_info(ri_data%kp_t_3c_int(1), pdims=pdims)
    3643          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3644             : 
    3645             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3646             :                             pgrid, bsizes_RI_ext, ri_data%bsizes_AO, &
    3647          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3648          42 :       DEALLOCATE (dist1, dist2, dist3)
    3649             : 
    3650             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3651             :                             ri_data%pgrid_2, bsizes_RI_ext, ri_data%bsizes_AO, &
    3652          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(RI | AO AO)")
    3653          42 :       DEALLOCATE (dist1, dist2, dist3)
    3654          42 :       CALL dbt_pgrid_destroy(pgrid)
    3655             : 
    3656             :       !We use the 3c integrals on the subgroup as template for the derivatives
    3657          42 :       nimg = ri_data%nimg
    3658         168 :       DO i_xyz = 1, 3
    3659        3156 :          DO i_img = 1, nimg
    3660        3030 :             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_AO_sub(i_img, i_xyz))
    3661        3030 :             CALL get_tensor_occupancy(t_3c_der_AO(i_img, i_xyz), nze, occ)
    3662        3030 :             IF (nze == 0) CYCLE
    3663             : 
    3664        1930 :             CALL dbt_copy(t_3c_der_AO(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
    3665             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
    3666        1930 :                                      group_size, ngroups, para_env)
    3667        5086 :             CALL dbt_copy(work_atom_block_sub, t_3c_der_AO_sub(i_img, i_xyz), move_data=.TRUE.)
    3668             :          END DO
    3669             : 
    3670        3156 :          DO i_img = 1, nimg
    3671        3030 :             CALL dbt_create(ri_data%kp_t_3c_int(1), t_3c_der_RI_sub(i_img, i_xyz))
    3672        3030 :             CALL get_tensor_occupancy(t_3c_der_RI(i_img, i_xyz), nze, occ)
    3673        3030 :             IF (nze == 0) CYCLE
    3674             : 
    3675        1910 :             CALL dbt_copy(t_3c_der_RI(i_img, i_xyz), work_atom_block, move_data=.TRUE.)
    3676             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, &
    3677        1910 :                                      group_size, ngroups, para_env)
    3678        5066 :             CALL dbt_copy(work_atom_block_sub, t_3c_der_RI_sub(i_img, i_xyz), move_data=.TRUE.)
    3679             :          END DO
    3680             : 
    3681        3198 :          DO i_img = 1, nimg
    3682        3030 :             CALL dbt_destroy(t_3c_der_RI(i_img, i_xyz))
    3683        3156 :             CALL dbt_destroy(t_3c_der_AO(i_img, i_xyz))
    3684             :          END DO
    3685             :       END DO
    3686          42 :       CALL dbt_destroy(work_atom_block_sub)
    3687          42 :       CALL dbt_destroy(work_atom_block)
    3688             : 
    3689             :       !Deal with t_3c_apc
    3690          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    3691         126 :       ALLOCATE (bsizes_RI_ext_split(ri_data%ncell_RI*nblks_RI))
    3692         294 :       DO i_RI = 1, ri_data%ncell_RI
    3693        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    3694             :       END DO
    3695          42 :       batch_size = ri_data%kp_stack_size !must be defined before its use in tensor_dims below
    3696          42 :       pdims = 0
    3697             :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid, &
    3698         168 :                             tensor_dims=[1, SIZE(bsizes_RI_ext_split), batch_size*SIZE(ri_data%bsizes_AO_split)])
    3699             : 
    3700             :       CALL create_3c_tensor(work_atom_block_sub, dist1, dist2, dist3, &
    3701             :                             pgrid, ri_data%bsizes_AO, bsizes_RI_ext, &
    3702          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3703          42 :       DEALLOCATE (dist1, dist2, dist3)
    3704             : 
    3705             :       CALL create_3c_tensor(work_atom_block, dist1, dist2, dist3, &
    3706             :                             ri_data%pgrid_1, ri_data%bsizes_AO, bsizes_RI_ext, &
    3707          42 :                             ri_data%bsizes_AO, [1], [2, 3], name="(AO RI | AO)")
    3708          42 :       DEALLOCATE (dist1, dist2, dist3)
    3709             : 
    3710             :       CALL create_3c_tensor(tmp, dist1, dist2, dist3, &
    3711             :                             pgrid, ri_data%bsizes_AO_split, bsizes_RI_ext_split, &
    3712          42 :                             ri_data%bsizes_AO_split, [1], [2, 3], name="(AO RI | AO)")
    3713          42 :       DEALLOCATE (dist1, dist2, dist3)
    3714             : 
    3715         126 :       ALLOCATE (idx_to_at(SIZE(ri_data%bsizes_AO)))
    3716          42 :       CALL get_idx_to_atom(idx_to_at, ri_data%bsizes_AO, ri_data%bsizes_AO)
    3717          42 :       nspins = SIZE(t_3c_apc, 1)
    3718        1052 :       DO i_img = 1, nimg
    3719        2186 :          DO i_spin = 1, nspins
    3720        1176 :             CALL dbt_create(tmp, t_3c_apc_sub(i_spin, i_img))
    3721        1176 :             CALL get_tensor_occupancy(t_3c_apc(i_spin, i_img), nze, occ)
    3722        1176 :             IF (nze == 0) CYCLE
    3723        1152 :             CALL dbt_copy(t_3c_apc(i_spin, i_img), work_atom_block, move_data=.TRUE.)
    3724             :             CALL copy_3c_to_subgroup(work_atom_block_sub, work_atom_block, group_size, &
    3725        1152 :                                      ngroups, para_env, ri_data%iatom_to_subgroup, 1, idx_to_at)
    3726        3338 :             CALL dbt_copy(work_atom_block_sub, t_3c_apc_sub(i_spin, i_img), move_data=.TRUE.)
    3727             :          END DO
    3728        2228 :          DO i_spin = 1, nspins
    3729        2186 :             CALL dbt_destroy(t_3c_apc(i_spin, i_img))
    3730             :          END DO
    3731             :       END DO
    3732          42 :       CALL dbt_destroy(tmp)
    3733          42 :       CALL dbt_destroy(work_atom_block)
    3734          42 :       CALL dbt_destroy(work_atom_block_sub)
    3735          42 :       CALL dbt_pgrid_destroy(pgrid)
    3736             : 
    3737             :       !t_3c_work_3 based on structure of 3c integrals/derivs
    3738             :       !(batch_size was already set to ri_data%kp_stack_size before the t_3c_apc pgrid creation)
    3739          42 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    3740         126 :       ALLOCATE (bsizes_stack(batch_size*nblks_AO))
    3741        1386 :       DO ib = 1, batch_size
    3742        6826 :          bsizes_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = ri_data%bsizes_AO_split(:)
    3743             :       END DO
    3744             : 
    3745         294 :       ALLOCATE (dist1(ri_data%ncell_RI*nblks_RI), dist2(nblks_AO), dist3(nblks_AO))
    3746             :       CALL dbt_get_info(ri_data%kp_t_3c_int(1), proc_dist_1=dist1, proc_dist_2=dist2, &
    3747          42 :                         proc_dist_3=dist3, pdims=pdims)
    3748             : 
    3749         126 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3750        1386 :       DO ib = 1, batch_size
    3751        6826 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3752             :       END DO
    3753             : 
    3754          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3755          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3756             :       CALL dbt_create(t_3c_work_3(1), "work_3_stack", t_dist, [1], [2, 3], &
    3757          42 :                       bsizes_RI_ext_split, ri_data%bsizes_AO_split, bsizes_stack)
    3758          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(2))
    3759          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(3))
    3760          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_work_3(4))
    3761          42 :       CALL dbt_distribution_destroy(t_dist)
    3762          42 :       CALL dbt_pgrid_destroy(pgrid)
    3763          42 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3764             : 
    3765             :       !the derivatives are stacked in the same way
    3766          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(1))
    3767          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(2))
    3768          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(3))
    3769          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(4))
    3770          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(5))
    3771          42 :       CALL dbt_create(t_3c_work_3(1), t_3c_der_stack(6))
    3772             : 
    3773             :       !t_3c_work_2 based on structure of t_3c_apc
    3774         294 :       ALLOCATE (dist1(nblks_AO), dist2(ri_data%ncell_RI*nblks_RI), dist3(nblks_AO))
    3775             :       CALL dbt_get_info(t_3c_apc_sub(1, 1), proc_dist_1=dist1, proc_dist_2=dist2, &
    3776          42 :                         proc_dist_3=dist3, pdims=pdims)
    3777             : 
    3778         126 :       ALLOCATE (dist_stack(batch_size*nblks_AO))
    3779        1386 :       DO ib = 1, batch_size
    3780        6826 :          dist_stack((ib - 1)*nblks_AO + 1:ib*nblks_AO) = dist3(:)
    3781             :       END DO
    3782             : 
    3783          42 :       CALL dbt_pgrid_create(para_env_sub, pdims, pgrid)
    3784          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist1, dist2, dist_stack)
    3785             :       CALL dbt_create(t_3c_work_2(1), "work_3_stack", t_dist, [1], [2, 3], &
    3786          42 :                       ri_data%bsizes_AO_split, bsizes_RI_ext_split, bsizes_stack)
    3787          42 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(2))
    3788          42 :       CALL dbt_create(t_3c_work_2(1), t_3c_work_2(3))
    3789          42 :       CALL dbt_distribution_destroy(t_dist)
    3790          42 :       CALL dbt_pgrid_destroy(pgrid)
    3791          42 :       DEALLOCATE (dist1, dist2, dist3, dist_stack)
    3792             : 
    3793          42 :       CALL timestop(handle)
    3794             : 
    3795          84 :    END SUBROUTINE get_subgroup_3c_derivs
    3796             : 
    3797             : ! **************************************************************************************************
    3798             : !> \brief A routine that reorders the t_3c_int tensors such that all items which are fully empty
    3799             : !>        are bunched together. This way, we can get much more efficient screening based on NZE
    3800             : !> \param t_3c_ints one tensor per image; on exit non-empty tensors come first, empty ones last
    3801             : !> \param ri_data nimg is read; idx_to_img and img_to_idx are allocated and nimg_nze is set here
    3802             : ! **************************************************************************************************
    3803          70 :    SUBROUTINE reorder_3c_ints(t_3c_ints, ri_data)
    3804             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_ints
    3805             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3806             : 
    3807             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_ints'
    3808             : 
    3809             :       INTEGER                                            :: handle, i_img, idx, idx_empty, idx_full, &
    3810             :                                                             nimg
    3811             :       INTEGER(int_8)                                     :: nze
    3812             :       REAL(dp)                                           :: occ
    3813          70 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_3c_tmp
    3814             : 
    3815          70 :       CALL timeset(routineN, handle)
    3816             :       !Move every tensor into a temporary of identical structure, so t_3c_ints can be refilled
    3817          70 :       nimg = ri_data%nimg
    3818        2486 :       ALLOCATE (t_3c_tmp(nimg))
    3819        1786 :       DO i_img = 1, nimg
    3820        1716 :          CALL dbt_create(t_3c_ints(i_img), t_3c_tmp(i_img))
    3821        1786 :          CALL dbt_copy(t_3c_ints(i_img), t_3c_tmp(i_img), move_data=.TRUE.)
    3822             :       END DO
    3823             : 
    3824             :       !Loop over the images, check if ints have NZE == 0, and put them at the start (non-empty) or
    3825             :       !end (empty) of the initial tensor array. Keep the position -> image mapping in idx_to_img
    3826         210 :       ALLOCATE (ri_data%idx_to_img(nimg))
    3827          70 :       idx_full = 0
    3828          70 :       idx_empty = nimg + 1
    3829             : 
    3830        1786 :       DO i_img = 1, nimg
    3831        1716 :          CALL get_tensor_occupancy(t_3c_tmp(i_img), nze, occ)
    3832        1716 :          IF (nze == 0) THEN
    3833         480 :             idx_empty = idx_empty - 1
    3834         480 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_ints(idx_empty), move_data=.TRUE.)
    3835         480 :             ri_data%idx_to_img(idx_empty) = i_img
    3836             :          ELSE
    3837        1236 :             idx_full = idx_full + 1
    3838        1236 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_ints(idx_full), move_data=.TRUE.)
    3839        1236 :             ri_data%idx_to_img(idx_full) = i_img
    3840             :          END IF
    3841        3502 :          CALL dbt_destroy(t_3c_tmp(i_img))
    3842             :       END DO
    3843             : 
    3844             :       !store the highest position in the reordered array that holds non-zero integrals
    3845          70 :       ri_data%nimg_nze = idx_full
    3846             :       !build the inverse mapping, from original image index to position in the reordered array
    3847         140 :       ALLOCATE (ri_data%img_to_idx(nimg))
    3848        1786 :       DO idx = 1, nimg
    3849        1786 :          ri_data%img_to_idx(ri_data%idx_to_img(idx)) = idx
    3850             :       END DO
    3851             : 
    3852          70 :       CALL timestop(handle)
    3853             : 
    3854        1856 :    END SUBROUTINE reorder_3c_ints
    3855             : 
    3856             : ! **************************************************************************************************
    3857             : !> \brief A routine that reorders the 3c derivatives, the same way that the integrals are, also to
    3858             : !>        increase efficiency of screening
    3859             : !> \param t_3c_derivs derivative tensors (i_img, i_xyz), reordered in place along the image index
    3860             : !> \param ri_data img_to_idx is read; nimg_nze is raised to the highest index with non-zero data
    3861             : ! **************************************************************************************************
    3862          84 :    SUBROUTINE reorder_3c_derivs(t_3c_derivs, ri_data)
    3863             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_derivs
    3864             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3865             : 
    3866             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reorder_3c_derivs'
    3867             : 
    3868             :       INTEGER                                            :: handle, i_img, i_xyz, idx, nimg
    3869             :       INTEGER(int_8)                                     :: nze
    3870             :       REAL(dp)                                           :: occ
    3871          84 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: t_3c_tmp
    3872             : 
    3873          84 :       CALL timeset(routineN, handle)
    3874             :       !One temporary per image, all with the structure of t_3c_derivs(1, 1)
    3875          84 :       nimg = ri_data%nimg
    3876        2944 :       ALLOCATE (t_3c_tmp(nimg))
    3877        2104 :       DO i_img = 1, nimg
    3878        2104 :          CALL dbt_create(t_3c_derivs(1, 1), t_3c_tmp(i_img))
    3879             :       END DO
    3880             :       !For each direction, park all images in the temporaries, then move them to their new index
    3881         336 :       DO i_xyz = 1, 3
    3882        6312 :          DO i_img = 1, nimg
    3883        6312 :             CALL dbt_copy(t_3c_derivs(i_img, i_xyz), t_3c_tmp(i_img), move_data=.TRUE.)
    3884             :          END DO
    3885        6396 :          DO i_img = 1, nimg
    3886        6060 :             idx = ri_data%img_to_idx(i_img)
    3887        6060 :             CALL dbt_copy(t_3c_tmp(i_img), t_3c_derivs(idx, i_xyz), move_data=.TRUE.)
    3888        6060 :             CALL get_tensor_occupancy(t_3c_derivs(idx, i_xyz), nze, occ)
    3889        6312 :             IF (nze > 0) ri_data%nimg_nze = MAX(idx, ri_data%nimg_nze)
    3890             :          END DO
    3891             :       END DO
    3892             : 
    3893        2104 :       DO i_img = 1, nimg
    3894        2104 :          CALL dbt_destroy(t_3c_tmp(i_img))
    3895             :       END DO
    3896             : 
    3897          84 :       CALL timestop(handle)
    3898             : 
    3899        2188 :    END SUBROUTINE reorder_3c_derivs
    3900             : 
    3901             : ! **************************************************************************************************
    3902             : !> \brief Get the sparsity pattern related to the non-symmetric AO basis overlap neighbor list
    3903             : !> \param pattern (natom, natom, nimg) array, on exit -1 for unoccupied and 0 for occupied blocks
    3904             : !> \param ri_data nimg and present_images are read
    3905             : !> \param qs_env source of the k-point info, neighbor lists, para_env and natom
    3906             : ! **************************************************************************************************
    3907         232 :    SUBROUTINE get_sparsity_pattern(pattern, ri_data, qs_env)
    3908             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: pattern
    3909             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    3910             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    3911             : 
    3912             :       INTEGER                                            :: iatom, j_img, jatom, mj_img, natom, nimg
    3913         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bins
    3914         232 :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: tmp_pattern
    3915             :       INTEGER, DIMENSION(3)                              :: cell_j
    3916         232 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    3917         232 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    3918             :       TYPE(dft_control_type), POINTER                    :: dft_control
    3919             :       TYPE(kpoint_type), POINTER                         :: kpoints
    3920             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    3921             :       TYPE(neighbor_list_iterator_p_type), &
    3922         232 :          DIMENSION(:), POINTER                           :: nl_iterator
    3923             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    3924         232 :          POINTER                                         :: nl_2c
    3925             : 
    3926         232 :       NULLIFY (nl_2c, nl_iterator, kpoints, cell_to_index, dft_control, index_to_cell, para_env)
    3927             : 
    3928         232 :       CALL get_qs_env(qs_env, kpoints=kpoints, dft_control=dft_control, para_env=para_env, natom=natom)
    3929         232 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index, index_to_cell=index_to_cell, sab_nl=nl_2c)
    3930             : 
    3931         232 :       nimg = ri_data%nimg
    3932       40930 :       pattern(:, :, :) = 0
    3933             : 
    3934             :       !We use the symmetric nl for all images that have an opposite cell
    3935         232 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    3936       10017 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    3937        9785 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    3938             : 
    3939        9785 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    3940        9785 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    3941             : 
    3942        6948 :          mj_img = get_opp_index(j_img, qs_env)
    3943        6948 :          IF (mj_img > nimg .OR. mj_img < 1) CYCLE
    3944             : 
    3945        6659 :          IF (ri_data%present_images(j_img) == 0) CYCLE
    3946             : 
    3947        9785 :          pattern(iatom, jatom, j_img) = 1
    3948             :       END DO
    3949         232 :       CALL neighbor_list_iterator_release(nl_iterator)
    3950             : 
    3951             :       !If there is no opposite cell present, then we take into account the non-symmetric nl
    3952         232 :       CALL get_kpoint_info(kpoints, sab_nl_nosym=nl_2c)
    3953             : 
    3954         232 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    3955       13102 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    3956       12870 :          CALL get_iterator_info(nl_iterator, iatom=iatom, jatom=jatom, cell=cell_j)
    3957             : 
    3958       12870 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    3959       12870 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    3960             : 
    3961        8872 :          mj_img = get_opp_index(j_img, qs_env)
    3962        8872 :          IF (mj_img .LE. nimg .AND. mj_img > 0) CYCLE
    3963             : 
    3964         298 :          IF (ri_data%present_images(j_img) == 0) CYCLE
    3965             : 
    3966       12870 :          pattern(iatom, jatom, j_img) = 1
    3967             :       END DO
    3968         232 :       CALL neighbor_list_iterator_release(nl_iterator)
    3969             :       !Sum the pattern over all ranks of para_env; entries are only tested for zero/non-zero below
    3970       81628 :       CALL para_env%sum(pattern)
    3971             : 
    3972             :       !If the opposite image is considered, then there is no need to compute diagonal twice
    3973        5814 :       DO j_img = 2, nimg
    3974       16978 :          DO iatom = 1, natom
    3975       16746 :             IF (pattern(iatom, iatom, j_img) .NE. 0) THEN
    3976        3788 :                mj_img = get_opp_index(j_img, qs_env)
    3977        3788 :                IF (mj_img > nimg .OR. mj_img < 1) CYCLE
    3978        3788 :                pattern(iatom, iatom, mj_img) = 0
    3979             :             END IF
    3980             :          END DO
    3981             :       END DO
    3982             : 
    3983             :       ! We want to equilibrate the sparsity pattern such that there are same amount of blocks
    3984             :       ! for each atom i of i,j pairs. bins(i) counts the blocks assigned to atom i so far
    3985         696 :       ALLOCATE (bins(natom))
    3986         696 :       bins(:) = 0
    3987             : 
    3988        1160 :       ALLOCATE (tmp_pattern(natom, natom, nimg))
    3989       40930 :       tmp_pattern(:, :, :) = 0
    3990        6046 :       DO j_img = 1, nimg
    3991       17674 :          DO jatom = 1, natom
    3992       40698 :             DO iatom = 1, natom
    3993       23256 :                IF (pattern(iatom, jatom, j_img) == 0) CYCLE
    3994        7622 :                mj_img = get_opp_index(j_img, qs_env)
    3995             : 
    3996             :                !Should we take the i,j,b or the j,i,-b atomic block?
    3997       19250 :                IF (mj_img > nimg .OR. mj_img < 1) THEN
    3998             :                   !No opposite image, no choice
    3999         198 :                   bins(iatom) = bins(iatom) + 1
    4000         198 :                   tmp_pattern(iatom, jatom, j_img) = 1
    4001             :                ELSE
    4002             :                   !Assign the block to whichever of the two atoms currently has fewer blocks
    4003        7424 :                   IF (bins(iatom) > bins(jatom)) THEN
    4004        1498 :                      bins(jatom) = bins(jatom) + 1
    4005        1498 :                      tmp_pattern(jatom, iatom, mj_img) = 1
    4006             :                   ELSE
    4007        5926 :                      bins(iatom) = bins(iatom) + 1
    4008        5926 :                      tmp_pattern(iatom, jatom, j_img) = 1
    4009             :                   END IF
    4010             :                END IF
    4011             :             END DO
    4012             :          END DO
    4013             :       END DO
    4014             : 
    4015             :       ! -1 => unoccupied, 0 => occupied
    4016       40930 :       pattern(:, :, :) = tmp_pattern(:, :, :) - 1
    4017             : 
    4018         464 :    END SUBROUTINE get_sparsity_pattern
    4019             : 
    4020             : ! **************************************************************************************************
    4021             : !> \brief Distribute the iatom, jatom, b_img triplets over the subgroups to spread the load;
    4022             : !>        the group id for each triplet is passed as the value of sparsity_pattern(i, j, b),
    4023             : !>        with -1 being an unoccupied block
    4024             : !> \param sparsity_pattern in: -1 marks unoccupied blocks; out: occupied blocks hold their 0-based group index
    4025             : !> \param ngroups number of subgroups the occupied blocks are distributed over
    4026             : !> \param ri_data reads kp_cost and bsizes_AO; iatom_to_subgroup is allocated and filled if not yet allocated
    4027             : ! **************************************************************************************************
    4028         232 :    SUBROUTINE get_sub_dist(sparsity_pattern, ngroups, ri_data)
    4029             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: sparsity_pattern
    4030             :       INTEGER, INTENT(IN)                                :: ngroups
    4031             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4032             : 
    4033             :       INTEGER                                            :: b_img, ctr, iat, iatom, igroup, jatom, &
    4034             :                                                             natom, nimg, ub
    4035         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: max_at_per_group
    4036             :       REAL(dp)                                           :: cost
    4037         232 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4038             : 
    4039         232 :       natom = SIZE(sparsity_pattern, 2)
    4040         232 :       nimg = SIZE(sparsity_pattern, 3)
    4041             : 
    4042             :       !To avoid unnecessary data replication across the subgroups, we want to have a limited number
    4043             :       !of subgroups with the data of a given iatom. At the minimum, all groups have 1 atom
    4044             :       !We assume that the cost associated to each iatom is roughly the same
    4045         232 :       IF (.NOT. ALLOCATED(ri_data%iatom_to_subgroup)) THEN
    4046         350 :          ALLOCATE (ri_data%iatom_to_subgroup(natom), max_at_per_group(ngroups))
    4047         150 :          DO iatom = 1, natom
    4048         100 :             NULLIFY (ri_data%iatom_to_subgroup(iatom)%array)
    4049         200 :             ALLOCATE (ri_data%iatom_to_subgroup(iatom)%array(ngroups))
    4050         350 :             ri_data%iatom_to_subgroup(iatom)%array(:) = .FALSE.
    4051             :          END DO
    4052             :          !ub is natom/ngroups rounded up; every group gets at least one atom
    4053          50 :          ub = natom/ngroups
    4054          50 :          IF (ub*ngroups < natom) ub = ub + 1
    4055         150 :          max_at_per_group(:) = MAX(1, ub)
    4056             : 
    4057             :          !We want each atom to be present the same amount of times. Some groups might have more atoms
    4058             :          !than others to achieve this.
    4059             :          ctr = 0
    4060         150 :          DO WHILE (MODULO(SUM(max_at_per_group), natom) .NE. 0)
    4061           0 :             igroup = MODULO(ctr, ngroups) + 1
    4062           0 :             max_at_per_group(igroup) = max_at_per_group(igroup) + 1
    4063          50 :             ctr = ctr + 1
    4064             :          END DO
    4065             :          !Assign atoms to groups round-robin: consecutive atoms (cyclic in natom) fill each group in turn
    4066             :          ctr = 0
    4067         150 :          DO igroup = 1, ngroups
    4068         250 :             DO iat = 1, max_at_per_group(igroup)
    4069         100 :                iatom = MODULO(ctr, natom) + 1
    4070         100 :                ri_data%iatom_to_subgroup(iatom)%array(igroup) = .TRUE.
    4071         200 :                ctr = ctr + 1
    4072             :             END DO
    4073             :          END DO
    4074             :       END IF
    4075             :       !Greedy balancing: each occupied block goes to the currently least loaded group holding its iatom
    4076         696 :       ALLOCATE (bins(ngroups))
    4077         696 :       bins = 0.0_dp
    4078        6046 :       DO b_img = 1, nimg
    4079       17674 :          DO jatom = 1, natom
    4080       40698 :             DO iatom = 1, natom
    4081       23256 :                IF (sparsity_pattern(iatom, jatom, b_img) == -1) CYCLE
    4082       38110 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4083             :                !NOTE(review): the ANY() test below does not depend on the loop indices; it could be hoisted
    4084             :                !Use cost information from previous SCF if available
    4085      554266 :                IF (ANY(ri_data%kp_cost > EPSILON(0.0_dp))) THEN
    4086        5404 :                   cost = ri_data%kp_cost(iatom, jatom, b_img)
    4087             :                ELSE
    4088        2218 :                   cost = REAL(ri_data%bsizes_AO(iatom)*ri_data%bsizes_AO(jatom), dp)
    4089             :                END IF
    4090        7622 :                bins(igroup + 1) = bins(igroup + 1) + cost
    4091       34884 :                sparsity_pattern(iatom, jatom, b_img) = igroup
    4092             :             END DO
    4093             :          END DO
    4094             :       END DO
    4095             : 
    4096         232 :    END SUBROUTINE get_sub_dist
    4097             : 
    4098             : ! **************************************************************************************************
    4099             : !> \brief A routine that updates the sparsity pattern for force calculation, where all i,j,b combinations
    4100             : !>        are visited.
    4101             : !> \param force_pattern receives the 0-based group index of the blocks (i,j,b) added for the force calculation
    4102             : !> \param scf_pattern pattern used during the SCF; -1 marks unoccupied blocks (only read in this routine)
    4103             : !> \param ngroups number of subgroups
    4104             : !> \param ri_data provides iatom_to_subgroup and the kp_cost of the blocks
    4105             : !> \param qs_env passed to get_opp_index, whose result is used as the index of image -b
    4106             : ! **************************************************************************************************
    4107          42 :    SUBROUTINE update_pattern_to_forces(force_pattern, scf_pattern, ngroups, ri_data, qs_env)
    4108             :       INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: force_pattern, scf_pattern
    4109             :       INTEGER, INTENT(IN)                                :: ngroups
    4110             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4111             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4112             : 
    4113             :       INTEGER                                            :: b_img, iatom, igroup, jatom, mb_img, &
    4114             :                                                             natom, nimg
    4115          42 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: bins
    4116             : 
    4117          42 :       natom = SIZE(scf_pattern, 2)
    4118          42 :       nimg = SIZE(scf_pattern, 3)
    4119             : 
    4120         126 :       ALLOCATE (bins(ngroups))
    4121         126 :       bins = 0.0_dp
    4122             :       !Blocks whose opposite image index mb_img lies outside 1..nimg are left untouched in force_pattern
    4123        1052 :       DO b_img = 1, nimg
    4124        1010 :          mb_img = get_opp_index(b_img, qs_env)
    4125        3072 :          DO jatom = 1, natom
    4126        7070 :             DO iatom = 1, natom
    4127             :                !Important: same distribution as the KS matrix, because t_3c_apc is reused
    4128       20200 :                igroup = MINLOC(bins, 1, MASK=ri_data%iatom_to_subgroup(iatom)%array) - 1
    4129             : 
    4130             :                !check that block not already treated
    4131        4040 :                IF (scf_pattern(iatom, jatom, b_img) > -1) CYCLE
    4132             : 
    4133             :                !If not, take the cost of block j, i, -b (same energy contribution)
    4134        4902 :                IF (mb_img > 0 .AND. mb_img .LE. nimg) THEN
    4135        2486 :                   IF (scf_pattern(jatom, iatom, mb_img) == -1) CYCLE
    4136        1038 :                   bins(igroup + 1) = bins(igroup + 1) + ri_data%kp_cost(jatom, iatom, mb_img)
    4137        1038 :                   force_pattern(iatom, jatom, b_img) = igroup
    4138             :                END IF
    4139             :             END DO
    4140             :          END DO
    4141             :       END DO
    4142             : 
    4143          42 :    END SUBROUTINE update_pattern_to_forces
    4144             : 
    4145             : ! **************************************************************************************************
    4146             : !> \brief A routine that determines the extent of the KP RI-HFX periodic images, including for the
    4147             : !>        extension of the RI basis
    4148             : !> \param ri_data receives the ranges, the image count, present_images and the image <-> RI cell maps
    4149             : !> \param qs_env environment providing the basis sets, k-point info, parallel environment and input
    4150             : ! **************************************************************************************************
    4151          70 :    SUBROUTINE get_kp_and_ri_images(ri_data, qs_env)
    4152             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4153             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4154             : 
    4155             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_kp_and_ri_images'
    4156             : 
    4157             :       INTEGER :: cell_j(3), cell_k(3), handle, i_img, iatom, ikind, j_img, jatom, jcell, katom, &
    4158             :          kcell, kp_index_lbounds(3), kp_index_ubounds(3), natom, ngroups, nimg, nkind, pcoord(3), &
    4159             :          pdims(3)
    4160          70 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dist_AO_1, dist_AO_2, dist_RI, &
    4161          70 :                                                             nRI_per_atom, present_img, RI_cells
    4162          70 :       INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
    4163             :       REAL(dp)                                           :: bump_fact, dij, dik, image_range, &
    4164             :                                                             RI_range, rij(3), rik(3)
    4165         490 :       TYPE(dbt_type)                                     :: t_dummy
    4166             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4167             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4168             :       TYPE(distribution_3d_type)                         :: dist_3d
    4169             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4170          70 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4171             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4172          70 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4173             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4174             :       TYPE(neighbor_list_3c_iterator_type)               :: nl_3c_iter
    4175             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4176             :       TYPE(neighbor_list_iterator_p_type), &
    4177          70 :          DIMENSION(:), POINTER                           :: nl_iterator
    4178             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4179          70 :          POINTER                                         :: nl_2c
    4180          70 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4181          70 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4182             :       TYPE(section_vals_type), POINTER                   :: hfx_section
    4183             : 
    4184          70 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, nl_iterator, dft_control, &
    4185          70 :                particle_set, kpoints, para_env, cell_to_index, hfx_section)
    4186             : 
    4187          70 :       CALL timeset(routineN, handle)
    4188             : 
    4189             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, &
    4190             :                       dft_control=dft_control, particle_set=particle_set, kpoints=kpoints, &
    4191          70 :                       para_env=para_env, natom=natom)
    4192          70 :       nimg = dft_control%nimages
    4193          70 :       CALL get_kpoint_info(kpoints, cell_to_index=cell_to_index)
    4194         280 :       kp_index_lbounds = LBOUND(cell_to_index)
    4195         280 :       kp_index_ubounds = UBOUND(cell_to_index)
    4196             : 
    4197          70 :       hfx_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC%HF%RI")
    4198          70 :       CALL section_vals_val_get(hfx_section, "KP_NGROUPS", i_val=ngroups)
    4199             : 
    4200         496 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4201          70 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4202          70 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4203             : 
    4204             :       !In case of shortrange HFX potential, it is important to be consistent with the rest of the KP
    4205             :       !code, and use EPS_SCHWARZ to determine the range (rather than eps_filter_2c in normal RI-HFX)
    4206          70 :       IF (ri_data%hfx_pot%potential_type == do_potential_short) THEN
    4207           0 :          CALL erfc_cutoff(ri_data%eps_schwarz, ri_data%hfx_pot%omega, ri_data%hfx_pot%cutoff_radius)
    4208             :       END IF
    4209             : 
    4210             :       !Determine the range for contributing periodic images, and for the RI basis extension
    4211          70 :       ri_data%kp_RI_range = 0.0_dp
    4212          70 :       ri_data%kp_image_range = 0.0_dp
    4213         178 :       DO ikind = 1, nkind
    4214             : 
    4215         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4216         108 :          CALL get_gto_basis_set(basis_set_AO(ikind)%gto_basis_set, kind_radius=RI_range)
    4217         108 :          ri_data%kp_RI_range = MAX(RI_range, ri_data%kp_RI_range)
    4218             :          !NOTE(review): the AO radii were initialised with the same eps_pgf_orb just above; the repeat looks redundant - confirm
    4219         108 :          CALL init_interaction_radii_orb_basis(basis_set_AO(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4220         108 :          CALL init_interaction_radii_orb_basis(basis_set_RI(ikind)%gto_basis_set, ri_data%eps_pgf_orb)
    4221         108 :          CALL get_gto_basis_set(basis_set_RI(ikind)%gto_basis_set, kind_radius=image_range)
    4222             : 
    4223         108 :          image_range = 2.0_dp*image_range + cutoff_screen_factor*ri_data%hfx_pot%cutoff_radius
    4224         286 :          ri_data%kp_image_range = MAX(image_range, ri_data%kp_image_range)
    4225             :       END DO
    4226             : 
    4227          70 :       CALL section_vals_val_get(hfx_section, "KP_RI_BUMP_FACTOR", r_val=bump_fact)
    4228          70 :       ri_data%kp_bump_rad = bump_fact*ri_data%kp_RI_range
    4229             : 
    4230             :       !For the extent of the KP RI-HFX images, we are limited by the RI-HFX potential in
    4231             :       !(mu^0 sigma^a|P^0) (P^0|Q^b) (Q^b|nu^b lambda^a+c), if there is no contact between
    4232             :       !any P^0 and Q^b, then image b does not contribute
    4233             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4234          70 :                                    "HFX_2c_nl_RI", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4235             : 
    4236         210 :       ALLOCATE (present_img(nimg))
    4237        3120 :       present_img = 0
    4238          70 :       ri_data%nimg = 0
    4239          70 :       CALL neighbor_list_iterator_create(nl_iterator, nl_2c)
    4240        1568 :       DO WHILE (neighbor_list_iterate(nl_iterator) == 0)
    4241        1498 :          CALL get_iterator_info(nl_iterator, r=rij, cell=cell_j)
    4242             : 
    4243        5992 :          dij = NORM2(rij)
    4244             :          !NOTE(review): unlike the 3c loop below, cell_j is not checked against the bounds of cell_to_index - confirm
    4245        1498 :          j_img = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4246        1498 :          IF (j_img > nimg .OR. j_img < 1) CYCLE
    4247             : 
    4248        1466 :          IF (dij > ri_data%kp_image_range) CYCLE
    4249             : 
    4250        1466 :          ri_data%nimg = MAX(j_img, ri_data%nimg)
    4251        1498 :          present_img(j_img) = 1
    4252             : 
    4253             :       END DO
    4254          70 :       CALL neighbor_list_iterator_release(nl_iterator)
    4255          70 :       CALL release_neighbor_list_sets(nl_2c)
    4256          70 :       CALL para_env%max(ri_data%nimg)
    4257          70 :       IF (ri_data%nimg > nimg) &
    4258           0 :          CPABORT("Make sure the smallest exponent of the RI-HFX basis is larger than that of the ORB basis.")
    4259             : 
    4260             :       !Keep track of which images will not contribute, so that they can be ignored before calculation
    4261          70 :       CALL para_env%sum(present_img)
    4262         210 :       ALLOCATE (ri_data%present_images(ri_data%nimg))
    4263        1786 :       ri_data%present_images = 0
    4264        1786 :       DO i_img = 1, ri_data%nimg
    4265        1786 :          IF (present_img(i_img) > 0) ri_data%present_images(i_img) = 1
    4266             :       END DO
    4267             :       !The dummy tensor is only created to obtain dist_AO_1, dist_AO_2 and dist_RI, which are used to build dist_3d
    4268             :       CALL create_3c_tensor(t_dummy, dist_AO_1, dist_AO_2, dist_RI, &
    4269             :                             ri_data%pgrid, ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4270          70 :                             map1=[1, 2], map2=[3], name="(AO AO | RI)")
    4271             : 
    4272          70 :       CALL dbt_mp_environ_pgrid(ri_data%pgrid, pdims, pcoord)
    4273          70 :       CALL mp_comm_t3c%create(ri_data%pgrid%mp_comm_2d, 3, pdims)
    4274             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4275          70 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4276          70 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4277          70 :       CALL dbt_destroy(t_dummy)
    4278             : 
    4279             :       !For the extension of the RI basis P in (mu^0 sigma^a |P^i), we consider an atom if the distance
    4280             :       !between mu^0 and P^i is smaller or equal to the kind radius of mu^0
    4281             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, &
    4282             :                                    ri_data%ri_metric, "HFX_3c_nl", qs_env, op_pos=2, sym_ij=.FALSE., &
    4283          70 :                                    own_dist=.TRUE.)
    4284             : 
    4285         140 :       ALLOCATE (RI_cells(nimg))
    4286        3120 :       RI_cells = 0
    4287             : 
    4288         210 :       ALLOCATE (nRI_per_atom(natom))
    4289         210 :       nRI_per_atom = 0
    4290             : 
    4291          70 :       CALL neighbor_list_3c_iterator_create(nl_3c_iter, nl_3c)
    4292       58508 :       DO WHILE (neighbor_list_3c_iterate(nl_3c_iter) == 0)
    4293             :          CALL get_3c_iterator_info(nl_3c_iter, cell_k=cell_k, rik=rik, cell_j=cell_j, &
    4294       58438 :                                    iatom=iatom, jatom=jatom, katom=katom)
    4295      233752 :          dik = NORM2(rik)
    4296             : 
    4297      409066 :          IF (ANY([cell_j(1), cell_j(2), cell_j(3)] < kp_index_lbounds) .OR. &
    4298             :              ANY([cell_j(1), cell_j(2), cell_j(3)] > kp_index_ubounds)) CYCLE
    4299             : 
    4300       58438 :          jcell = cell_to_index(cell_j(1), cell_j(2), cell_j(3))
    4301       58438 :          IF (jcell > nimg .OR. jcell < 1) CYCLE
    4302             : 
    4303      385667 :          IF (ANY([cell_k(1), cell_k(2), cell_k(3)] < kp_index_lbounds) .OR. &
    4304             :              ANY([cell_k(1), cell_k(2), cell_k(3)] > kp_index_ubounds)) CYCLE
    4305             : 
    4306       51169 :          kcell = cell_to_index(cell_k(1), cell_k(2), cell_k(3))
    4307       51169 :          IF (kcell > nimg .OR. kcell < 1) CYCLE
    4308             : 
    4309       43523 :          IF (dik > ri_data%kp_RI_range) CYCLE
    4310        5791 :          RI_cells(kcell) = 1
    4311             : 
    4312        5861 :          IF (jcell == 1 .AND. iatom == jatom) nRI_per_atom(iatom) = nRI_per_atom(iatom) + ri_data%bsizes_RI(katom)
    4313             :       END DO
    4314          70 :       CALL neighbor_list_3c_iterator_destroy(nl_3c_iter)
    4315          70 :       CALL neighbor_list_3c_destroy(nl_3c)
    4316          70 :       CALL para_env%sum(RI_cells)
    4317          70 :       CALL para_env%sum(nRI_per_atom)
    4318             :       !Consecutively number the images flagged in RI_cells; img_to_RI_cell stays 0 for all other images
    4319         140 :       ALLOCATE (ri_data%img_to_RI_cell(nimg))
    4320          70 :       ri_data%ncell_RI = 0
    4321        3120 :       ri_data%img_to_RI_cell = 0
    4322        3120 :       DO i_img = 1, nimg
    4323        3120 :          IF (RI_cells(i_img) > 0) THEN
    4324         436 :             ri_data%ncell_RI = ri_data%ncell_RI + 1
    4325         436 :             ri_data%img_to_RI_cell(i_img) = ri_data%ncell_RI
    4326             :          END IF
    4327             :       END DO
    4328             :       !Inverse map: RI cell index -> image index
    4329         210 :       ALLOCATE (ri_data%RI_cell_to_img(ri_data%ncell_RI))
    4330        3120 :       DO i_img = 1, nimg
    4331        3120 :          IF (ri_data%img_to_RI_cell(i_img) > 0) ri_data%RI_cell_to_img(ri_data%img_to_RI_cell(i_img)) = i_img
    4332             :       END DO
    4333             : 
    4334             :       !Print some info
    4335          70 :       IF (ri_data%unit_nr > 0) THEN
    4336             :          WRITE (ri_data%unit_nr, FMT="(/T3,A,I29)") &
    4337          35 :             "KP-HFX_RI_INFO| Number of RI-KP parallel groups:", ngroups
    4338             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F31.3,A)") &
    4339          35 :             "KP-HFX_RI_INFO| RI basis extension radius:", ri_data%kp_RI_range*angstrom, " Ang"
    4340             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F12.3,A, F6.3, A)") &
    4341          35 :             "KP-HFX_RI_INFO| RI basis bump factor and bump radius:", bump_fact, " /", &
    4342          70 :             ri_data%kp_bump_rad*angstrom, " Ang"
    4343             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I16,A)") &
    4344          35 :             "KP-HFX_RI_INFO| The extended RI bases cover up to ", ri_data%ncell_RI, " unit cells"
    4345             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I18)") &
    4346         105 :             "KP-HFX_RI_INFO| Average number of sgf in extended RI bases:", SUM(nRI_per_atom)/natom
    4347             :          WRITE (ri_data%unit_nr, FMT="(T3,A,F13.3,A)") &
    4348          35 :             "KP-HFX_RI_INFO| Consider all image cells within a radius of ", ri_data%kp_image_range*angstrom, " Ang"
    4349             :          WRITE (ri_data%unit_nr, FMT="(T3,A,I27/)") &
    4350          35 :             "KP-HFX_RI_INFO| Number of image cells considered: ", ri_data%nimg
    4351          35 :          CALL m_flush(ri_data%unit_nr)
    4352             :       END IF
    4353             : 
    4354          70 :       CALL timestop(handle)
    4355             : 
    4356         840 :    END SUBROUTINE get_kp_and_ri_images
    4357             : 
    4358             : ! **************************************************************************************************
    4359             : !> \brief A routine that creates the tensor structures for rho_ao and 3c_ints in a stacked format for
    4360             : !>        the efficient contractions of rho_sigma^0,lambda^c * (mu^0 sigma^a | P) => TAS tensors
    4361             : !> \param res_stack result tensors, created with ints_stack(1) and ints_stack(2) as templates
    4362             : !> \param rho_stack (1): stacked distribution of rho_template; (2): distribution from create_2c_tensor
    4363             : !> \param ints_stack (1): distribution of ints_template, stacked in dim 3; (2): on a newly created 3d pgrid
    4364             : !> \param rho_template 2c tensor whose process distributions are repeated stack_size times
    4365             : !> \param ints_template 3c tensor providing block counts, distributions and the block sizes of dim 2
    4366             : !> \param stack_size number of times the AO block sizes are repeated in the stacked dimensions
    4367             : !> \param ri_data provides bsizes_AO_split, pgrid_2d and pgrid_1
    4368             : !> \param qs_env used to retrieve para_env
    4369             : !> \note The result tensor has the exact same shape and distribution as the integral tensor
    4370             : ! **************************************************************************************************
    4371         232 :    SUBROUTINE get_stack_tensors(res_stack, rho_stack, ints_stack, rho_template, ints_template, &
    4372             :                                 stack_size, ri_data, qs_env)
    4373             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: res_stack, rho_stack, ints_stack
    4374             :       TYPE(dbt_type), INTENT(INOUT)                      :: rho_template, ints_template
    4375             :       INTEGER, INTENT(IN)                                :: stack_size
    4376             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4377             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4378             : 
    4379             :       INTEGER                                            :: is, nblks, nblks_3c(3), pdims_3d(3)
    4380         232 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bsizes_RI_ext, bsizes_stack, dist1, &
    4381         232 :                                                             dist2, dist3, dist_stack1, &
    4382         232 :                                                             dist_stack2, dist_stack3
    4383        2088 :       TYPE(dbt_distribution_type)                        :: t_dist
    4384         696 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4385             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4386             : 
    4387         232 :       NULLIFY (para_env)
    4388             : 
    4389         232 :       CALL get_qs_env(qs_env, para_env=para_env)
    4390             :       !Stacked block sizes: the split AO block sizes repeated stack_size times
    4391         232 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4392         696 :       ALLOCATE (bsizes_stack(stack_size*nblks))
    4393        3082 :       DO is = 1, stack_size
    4394       11582 :          bsizes_stack((is - 1)*nblks + 1:is*nblks) = ri_data%bsizes_AO_split(:)
    4395             :       END DO
    4396             : 
    4397        2088 :       ALLOCATE (dist1(nblks), dist2(nblks), dist_stack1(stack_size*nblks), dist_stack2(stack_size*nblks))
    4398         232 :       CALL dbt_get_info(rho_template, proc_dist_1=dist1, proc_dist_2=dist2)
    4399        3082 :       DO is = 1, stack_size
    4400       11350 :          dist_stack1((is - 1)*nblks + 1:is*nblks) = dist1(:)
    4401       11582 :          dist_stack2((is - 1)*nblks + 1:is*nblks) = dist2(:)
    4402             :       END DO
    4403             : 
    4404             :       !First 2c tensor matches the distribution of template
    4405             :       !It is stacked in both directions
    4406         232 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_2d, dist_stack1, dist_stack2)
    4407         232 :       CALL dbt_create(rho_stack(1), "RHO_stack", t_dist, [1], [2], bsizes_stack, bsizes_stack)
    4408         232 :       CALL dbt_distribution_destroy(t_dist)
    4409         232 :       DEALLOCATE (dist1, dist2, dist_stack1, dist_stack2)
    4410             : 
    4411             :       !Second 2c tensor has optimal distribution on the 2d pgrid
    4412         232 :       CALL create_2c_tensor(rho_stack(2), dist1, dist2, ri_data%pgrid_2d, bsizes_stack, bsizes_stack, name="RHO_stack")
    4413         232 :       DEALLOCATE (dist1, dist2)
    4414             :       !NOTE(review): presumably nblks_3c(3) == nblks, since bsizes_stack is used with dist_stack3 below - confirm
    4415         232 :       CALL dbt_get_info(ints_template, nblks_total=nblks_3c)
    4416        1624 :       ALLOCATE (dist1(nblks_3c(1)), dist2(nblks_3c(2)), dist3(nblks_3c(3)))
    4417        1160 :       ALLOCATE (dist_stack3(stack_size*nblks_3c(3)), bsizes_RI_ext(nblks_3c(2)))
    4418             :       CALL dbt_get_info(ints_template, proc_dist_1=dist1, proc_dist_2=dist2, &
    4419         232 :                         proc_dist_3=dist3, blk_size_2=bsizes_RI_ext)
    4420        3082 :       DO is = 1, stack_size
    4421       11582 :          dist_stack3((is - 1)*nblks_3c(3) + 1:is*nblks_3c(3)) = dist3(:)
    4422             :       END DO
    4423             : 
    4424             :       !First 3c tensor matches the distribution of template
    4425         232 :       CALL dbt_distribution_new(t_dist, ri_data%pgrid_1, dist1, dist2, dist_stack3)
    4426             :       CALL dbt_create(ints_stack(1), "ints_stack", t_dist, [1, 2], [3], ri_data%bsizes_AO_split, &
    4427         232 :                       bsizes_RI_ext, bsizes_stack)
    4428         232 :       CALL dbt_distribution_destroy(t_dist)
    4429         232 :       DEALLOCATE (dist1, dist2, dist3, dist_stack3)
    4430             : 
    4431             :       !Second 3c tensor has optimal pgrid
    4432         232 :       pdims_3d = 0
    4433         928 :       CALL dbt_pgrid_create(para_env, pdims_3d, pgrid, tensor_dims=[nblks_3c(1), nblks_3c(2), stack_size*nblks_3c(3)])
    4434             :       CALL create_3c_tensor(ints_stack(2), dist1, dist2, dist3, pgrid, ri_data%bsizes_AO_split, &
    4435         232 :                             bsizes_RI_ext, bsizes_stack, [1, 2], [3], name="ints_stack")
    4436         232 :       DEALLOCATE (dist1, dist2, dist3)
    4437         232 :       CALL dbt_pgrid_destroy(pgrid)
    4438             : 
    4439             :       !The result tensor has the same shape and dist as the integral tensor
    4440         232 :       CALL dbt_create(ints_stack(1), res_stack(1))
    4441         232 :       CALL dbt_create(ints_stack(2), res_stack(2))
    4442             : 
    4443         464 :    END SUBROUTINE get_stack_tensors
    4444             : 
    4445             : ! **************************************************************************************************
    4446             : !> \brief Fill the stack of 3c tensors according to the order in the images input
    4447             : !> \param t_3c_stack ...
    4448             : !> \param t_3c_in ...
    4449             : !> \param images ...
    4450             : !> \param stack_dim ...
    4451             : !> \param ri_data ...
    4452             : !> \param filter_at ...
    4453             : !> \param filter_dim ...
    4454             : !> \param idx_to_at ...
    4455             : !> \param img_bounds ...
    4456             : ! **************************************************************************************************
    4457       22326 :    SUBROUTINE fill_3c_stack(t_3c_stack, t_3c_in, images, stack_dim, ri_data, filter_at, filter_dim, &
    4458       22326 :                             idx_to_at, img_bounds)
    4459             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_stack
    4460             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_3c_in
    4461             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4462             :       INTEGER, INTENT(IN)                                :: stack_dim
    4463             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4464             :       INTEGER, INTENT(IN), OPTIONAL                      :: filter_at, filter_dim
    4465             :       INTEGER, DIMENSION(:), INTENT(INOUT), OPTIONAL     :: idx_to_at
    4466             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2)
    4467             : 
    4468             :       INTEGER                                            :: dest(3), i_img, idx, ind(3), lb, nblks, &
    4469             :                                                             nimg, offset, ub
    4470             :       LOGICAL                                            :: do_filter, found
    4471       22326 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4472             :       TYPE(dbt_iterator_type)                            :: iter
    4473             : 
    4474             :       !We loop over the a images from the ac_pairs, then copy the 3c ints to the correct spot in
    4475             :       !in the stack tensor (corresponding to pair index). Distributions match by construction
    4476       22326 :       nimg = ri_data%nimg
    4477       22326 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4478             : 
    4479       22326 :       do_filter = .FALSE.
    4480       21806 :       IF (PRESENT(filter_at) .AND. PRESENT(filter_dim) .AND. PRESENT(idx_to_at)) do_filter = .TRUE.
    4481             : 
    4482       22326 :       lb = 1
    4483       22326 :       ub = nimg
    4484       22326 :       offset = 0
    4485       22326 :       IF (PRESENT(img_bounds)) THEN
    4486       22326 :          lb = img_bounds(1)
    4487       22326 :          ub = img_bounds(2) - 1
    4488       22326 :          offset = lb - 1
    4489             :       END IF
    4490             : 
    4491      449976 :       DO idx = lb, ub
    4492      427650 :          i_img = images(idx)
    4493      427650 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4494             : 
    4495             : !$OMP PARALLEL DEFAULT(NONE) &
    4496             : !$OMP SHARED(idx,i_img,t_3c_in,t_3c_stack,nblks,stack_dim,filter_at,filter_dim,idx_to_at,do_filter,offset) &
    4497      449976 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4498             :          CALL dbt_iterator_start(iter, t_3c_in(i_img))
    4499             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4500             :             CALL dbt_iterator_next_block(iter, ind)
    4501             :             CALL dbt_get_block(t_3c_in(i_img), ind, blk, found)
    4502             :             IF (.NOT. found) CYCLE
    4503             : 
    4504             :             IF (do_filter) THEN
    4505             :                IF (.NOT. idx_to_at(ind(filter_dim)) == filter_at) CYCLE
    4506             :             END IF
    4507             : 
    4508             :             IF (stack_dim == 1) THEN
    4509             :                dest = [(idx - offset - 1)*nblks + ind(1), ind(2), ind(3)]
    4510             :             ELSE IF (stack_dim == 2) THEN
    4511             :                dest = [ind(1), (idx - offset - 1)*nblks + ind(2), ind(3)]
    4512             :             ELSE
    4513             :                dest = [ind(1), ind(2), (idx - offset - 1)*nblks + ind(3)]
    4514             :             END IF
    4515             : 
    4516             :             CALL dbt_put_block(t_3c_stack, dest, SHAPE(blk), blk)
    4517             :             DEALLOCATE (blk)
    4518             :          END DO
    4519             :          CALL dbt_iterator_stop(iter)
    4520             : !$OMP END PARALLEL
    4521             :       END DO !i_img
    4522       22326 :       CALL dbt_finalize(t_3c_stack)
    4523             : 
    4524       44652 :    END SUBROUTINE fill_3c_stack
    4525             : 
    4526             : ! **************************************************************************************************
    4527             : !> \brief Fill the stack of 2c tensors based on the content of images input
    4528             : !> \param t_2c_stack the stacked tensor receiving the copied blocks, finalized on exit
    4529             : !> \param t_2c_in the input 2c tensors, indexed by image index
    4530             : !> \param images images(idx) is the index in t_2c_in of the tensor copied to the stack slot of idx; entries that are 0 or > ri_data%nimg are skipped
    4531             : !> \param stack_dim the tensor dimension to stack along: 1, or any other value for the 2nd dimension
    4532             : !> \param ri_data provides nimg and bsizes_AO_split, whose size is the block shift between two slots
    4533             : !> \param img_bounds if present, only idx = img_bounds(1) to img_bounds(2) - 1 is treated, and idx = img_bounds(1) goes to the first slot
    4534             : !> \param shift slot used along the dimension that is not stack_dim (block shift (shift - 1)*nblks), defaults to 1
    4535             : ! **************************************************************************************************
    4536       15228 :    SUBROUTINE fill_2c_stack(t_2c_stack, t_2c_in, images, stack_dim, ri_data, img_bounds, shift)
    4537             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_stack
    4538             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_in
    4539             :       INTEGER, DIMENSION(:), INTENT(INOUT)               :: images
    4540             :       INTEGER, INTENT(IN)                                :: stack_dim
    4541             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4542             :       INTEGER, INTENT(IN), OPTIONAL                      :: img_bounds(2), shift
    4543             : 
    4544             :       INTEGER                                            :: dest(2), i_img, idx, ind(2), lb, &
    4545             :                                                             my_shift, nblks, nimg, offset, ub
    4546             :       LOGICAL                                            :: found
    4547       15228 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: blk
    4548             :       TYPE(dbt_iterator_type)                            :: iter
    4549             : 
    4550             :       !We loop over the entries idx of the images array, then copy the 2c blocks of image images(idx)
    4551             :       !to the correct spot in the stack tensor (slot idx - offset). Distributions match by construction
    4552       15228 :       nimg = ri_data%nimg
    4553       15228 :       nblks = SIZE(ri_data%bsizes_AO_split)
    4554             :       !Default: treat idx = 1 to nimg, with idx = 1 going to the first slot of the stack
    4555       15228 :       lb = 1
    4556       15228 :       ub = nimg
    4557       15228 :       offset = 0
    4558       15228 :       IF (PRESENT(img_bounds)) THEN
    4559       15228 :          lb = img_bounds(1)
    4560       15228 :          ub = img_bounds(2) - 1
    4561       15228 :          offset = lb - 1
    4562             :       END IF
    4563             : 
    4564       15228 :       my_shift = 1
    4565       15228 :       IF (PRESENT(shift)) my_shift = shift
    4566             : 
    4567      202454 :       DO idx = lb, ub
    4568      187226 :          i_img = images(idx)
    4569      187226 :          IF (i_img == 0 .OR. i_img > nimg) CYCLE
    4570             : 
    4571             : !$OMP PARALLEL DEFAULT(NONE) SHARED(idx,i_img,t_2c_in,t_2c_stack,nblks,stack_dim,offset,my_shift) &
    4572      202454 : !$OMP PRIVATE(iter,ind,blk,found,dest)
    4573             :          CALL dbt_iterator_start(iter, t_2c_in(i_img))
    4574             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4575             :             CALL dbt_iterator_next_block(iter, ind)
    4576             :             CALL dbt_get_block(t_2c_in(i_img), ind, blk, found)
    4577             :             IF (.NOT. found) CYCLE
    4578             :             !The stacked dimension is shifted by the slot of idx, the other one by the fixed my_shift slot
    4579             :             IF (stack_dim == 1) THEN
    4580             :                dest = [(idx - offset - 1)*nblks + ind(1), (my_shift - 1)*nblks + ind(2)]
    4581             :             ELSE
    4582             :                dest = [(my_shift - 1)*nblks + ind(1), (idx - offset - 1)*nblks + ind(2)]
    4583             :             END IF
    4584             : 
    4585             :             CALL dbt_put_block(t_2c_stack, dest, SHAPE(blk), blk)
    4586             :             DEALLOCATE (blk)
    4587             :          END DO
    4588             :          CALL dbt_iterator_stop(iter)
    4589             : !$OMP END PARALLEL
    4590             :       END DO !idx
    4591       15228 :       CALL dbt_finalize(t_2c_stack)
    4592             : 
    4593       30456 :    END SUBROUTINE fill_2c_stack
    4594             : 
    4595             : ! **************************************************************************************************
    4596             : !> \brief Extracts the blocks of slot idx from a 3c tensor stacked along its 3rd dimension into t_3c_apc
    4597             : !> \param t_3c_apc the tensor receiving the blocks of slot idx, its total block count along dim 3 defines the slot width
    4598             : !> \param t_stacked the stacked tensor the blocks are read from
    4599             : !> \param idx the slot to extract (1-based)
    4600             : ! **************************************************************************************************
    4601       17830 :    SUBROUTINE unstack_t_3c_apc(t_3c_apc, t_stacked, idx)
    4602             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_apc, t_stacked
    4603             :       INTEGER, INTENT(IN)                                :: idx
    4604             : 
    4605             :       INTEGER                                            :: current_idx
    4606             :       INTEGER, DIMENSION(3)                              :: ind, nblks_3c
    4607             :       LOGICAL                                            :: found
    4608       17830 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4609             :       TYPE(dbt_iterator_type)                            :: iter
    4610             : 
    4611             :       !Note: t_3c_apc and t_stacked must have the same distribution
    4612       17830 :       CALL dbt_get_info(t_3c_apc, nblks_total=nblks_3c)
    4613             : 
    4614       17830 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_apc,t_stacked,idx,nblks_3c) PRIVATE(iter,ind,blk,found,current_idx)
    4615             :       CALL dbt_iterator_start(iter, t_stacked)
    4616             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4617             :          CALL dbt_iterator_next_block(iter, ind)
    4618             : 
    4619             :          !tensor is stacked along the 3rd dimension: get the slot of this block, skip it if it is not idx
    4620             :          current_idx = (ind(3) - 1)/nblks_3c(3) + 1
    4621             :          IF (.NOT. idx == current_idx) CYCLE
    4622             : 
    4623             :          CALL dbt_get_block(t_stacked, ind, blk, found)
    4624             :          IF (.NOT. found) CYCLE
    4625             :          !Undo the stacking shift of the 3rd block index
    4626             :          CALL dbt_put_block(t_3c_apc, [ind(1), ind(2), ind(3) - (idx - 1)*nblks_3c(3)], SHAPE(blk), blk)
    4627             :          DEALLOCATE (blk)
    4628             :       END DO
    4629             :       CALL dbt_iterator_stop(iter)
    4630             : !$OMP END PARALLEL
    4631             :       !NOTE(review): t_3c_apc is not finalized in this routine; presumably done by the caller - confirm
    4632       17830 :    END SUBROUTINE unstack_t_3c_apc
    4633             : 
    4634             : ! **************************************************************************************************
    4635             : !> \brief copies the 3c integrals corresponding to a single atom mu from the general (P^0| mu^0 sigma^a)
    4636             : !> \param t_3c_at the tensor receiving the selected blocks (at unchanged block indices), finalized on exit
    4637             : !> \param t_3c_ints the tensor the blocks are read from
    4638             : !> \param iatom only blocks with idx_to_at(ind(dim_at)) == iatom are copied
    4639             : !> \param dim_at the tensor dimension whose block index is tested
    4640             : !> \param idx_to_at maps a block index along dim_at to the value compared with iatom
    4641             : ! **************************************************************************************************
    4642           0 :    SUBROUTINE get_atom_3c_ints(t_3c_at, t_3c_ints, iatom, dim_at, idx_to_at)
    4643             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_at, t_3c_ints
    4644             :       INTEGER, INTENT(IN)                                :: iatom, dim_at
    4645             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: idx_to_at
    4646             : 
    4647             :       INTEGER, DIMENSION(3)                              :: ind
    4648             :       LOGICAL                                            :: found
    4649           0 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: blk
    4650             :       TYPE(dbt_iterator_type)                            :: iter
    4651             : 
    4652           0 : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_3c_ints,t_3c_at,iatom,idx_to_at,dim_at) PRIVATE(iter,ind,blk,found)
    4653             :       CALL dbt_iterator_start(iter, t_3c_ints)
    4654             :       DO WHILE (dbt_iterator_blocks_left(iter))
    4655             :          CALL dbt_iterator_next_block(iter, ind)
    4656             :          IF (.NOT. idx_to_at(ind(dim_at)) == iatom) CYCLE
    4657             :          !Only blocks passing the test above are fetched and copied
    4658             :          CALL dbt_get_block(t_3c_ints, ind, blk, found)
    4659             :          IF (.NOT. found) CYCLE
    4660             : 
    4661             :          CALL dbt_put_block(t_3c_at, ind, SHAPE(blk), blk)
    4662             :          DEALLOCATE (blk)
    4663             :       END DO
    4664             :       CALL dbt_iterator_stop(iter)
    4665             : !$OMP END PARALLEL
    4666           0 :       CALL dbt_finalize(t_3c_at)
    4667             : 
    4668           0 :    END SUBROUTINE get_atom_3c_ints
    4669             : 
    4670             : ! **************************************************************************************************
    4671             : !> \brief Precalculate the 3c and 2c derivatives tensors
    4672             : !> \param t_3c_der_RI created here; t_3c_der_RI(i_img, i_xyz) receives the 3c derivatives wrt P^i
    4673             : !> \param t_3c_der_AO created here; t_3c_der_AO(i_img, i_xyz) receives the 3c derivatives wrt mu^0
    4674             : !> \param mat_der_pot created here; mat_der_pot(i_img, i_xyz) receives the 2c derivatives built with ri_data%hfx_pot
    4675             : !> \param t_2c_der_metric created here; t_2c_der_metric(iatom, i_xyz) receives the ri_data%ri_metric 2c derivatives via get_ext_2c_int
    4676             : !> \param ri_data provides image counts, block sizes, process grids, potentials and filter thresholds
    4677             : !> \param qs_env the environment from which kinds, particles, distribution and para_env are taken
    4678             : ! **************************************************************************************************
    4679          42 :    SUBROUTINE precalc_derivatives(t_3c_der_RI, t_3c_der_AO, mat_der_pot, t_2c_der_metric, ri_data, qs_env)
    4680             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_3c_der_RI, t_3c_der_AO
    4681             :       TYPE(dbcsr_type), DIMENSION(:, :), INTENT(INOUT)   :: mat_der_pot
    4682             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(INOUT)     :: t_2c_der_metric
    4683             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4684             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4685             : 
    4686             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'precalc_derivatives'
    4687             : 
    4688             :       INTEGER                                            :: handle, handle2, i_img, i_mem, i_RI, &
    4689             :                                                             i_xyz, iatom, n_mem, natom, nblks_RI, &
    4690             :                                                             ncell_RI, nimg, nkind, nthreads
    4691             :       INTEGER(int_8)                                     :: nze
    4692          42 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: bsizes_RI_ext, bsizes_RI_ext_split, dist_AO_1, &
    4693          84 :          dist_AO_2, dist_RI, dist_RI_ext, dummy_end, dummy_start, end_blocks, start_blocks
    4694             :       INTEGER, DIMENSION(3)                              :: pcoord, pdims
    4695          84 :       INTEGER, DIMENSION(:), POINTER                     :: col_bsize, row_bsize
    4696             :       REAL(dp)                                           :: occ
    4697             :       TYPE(dbcsr_distribution_type)                      :: dbcsr_dist
    4698             :       TYPE(dbcsr_type)                                   :: dbcsr_template
    4699          42 :       TYPE(dbcsr_type), ALLOCATABLE, DIMENSION(:, :)     :: mat_der_metric
    4700         378 :       TYPE(dbt_distribution_type)                        :: t_dist
    4701         126 :       TYPE(dbt_pgrid_type)                               :: pgrid
    4702         378 :       TYPE(dbt_type)                                     :: t_3c_template
    4703          42 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :, :)    :: t_3c_der_AO_prv, t_3c_der_RI_prv
    4704             :       TYPE(dft_control_type), POINTER                    :: dft_control
    4705             :       TYPE(distribution_2d_type), POINTER                :: dist_2d
    4706             :       TYPE(distribution_3d_type)                         :: dist_3d
    4707             :       TYPE(gto_basis_set_p_type), ALLOCATABLE, &
    4708          42 :          DIMENSION(:), TARGET                            :: basis_set_AO, basis_set_RI
    4709          42 :       TYPE(mp_cart_type)                                 :: mp_comm_t3c
    4710             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    4711             :       TYPE(neighbor_list_3c_type)                        :: nl_3c
    4712             :       TYPE(neighbor_list_set_p_type), DIMENSION(:), &
    4713          42 :          POINTER                                         :: nl_2c
    4714          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    4715          42 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
    4716             : 
    4717          42 :       NULLIFY (qs_kind_set, dist_2d, nl_2c, particle_set, dft_control, para_env, row_bsize, col_bsize)
    4718             : 
    4719          42 :       CALL timeset(routineN, handle)
    4720             : 
    4721             :       CALL get_qs_env(qs_env, nkind=nkind, qs_kind_set=qs_kind_set, distribution_2d=dist_2d, natom=natom, &
    4722          42 :                       particle_set=particle_set, dft_control=dft_control, para_env=para_env)
    4723             : 
    4724          42 :       nimg = ri_data%nimg
    4725          42 :       ncell_RI = ri_data%ncell_RI
    4726             :       !Set up the RI and AO basis set lists, one entry per kind
    4727         300 :       ALLOCATE (basis_set_RI(nkind), basis_set_AO(nkind))
    4728          42 :       CALL basis_set_list_setup(basis_set_RI, ri_data%ri_basis_type, qs_kind_set)
    4729          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_RI)
    4730          42 :       CALL basis_set_list_setup(basis_set_AO, ri_data%orb_basis_type, qs_kind_set)
    4731          42 :       CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_AO)
    4732             :       !Dealing with the 3c derivatives
    4733             :       !NOTE(review): outside a parallel region omp_get_num_threads() returns 1; omp_get_max_threads() intended? confirm
    4734          42 :       nthreads = 1
    4735          42 : !$    nthreads = omp_get_num_threads()
    4736          42 :       pdims = 0
    4737         168 :       CALL dbt_pgrid_create(para_env, pdims, pgrid, tensor_dims=[MAX(1, natom/(ri_data%n_mem*nthreads)), natom, natom])
    4738             :       !This temporary tensor is only created to obtain the distribution arrays on pgrid
    4739             :       CALL create_3c_tensor(t_3c_template, dist_AO_1, dist_AO_2, dist_RI, pgrid, &
    4740             :                             ri_data%bsizes_AO, ri_data%bsizes_AO, ri_data%bsizes_RI, &
    4741          42 :                             map1=[1, 2], map2=[3], name="tmp")
    4742          42 :       CALL dbt_destroy(t_3c_template)
    4743             : 
    4744             :       !We stack the RI basis images. Keep consistent distribution
    4745          42 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    4746         126 :       ALLOCATE (dist_RI_ext(natom*ncell_RI))
    4747          84 :       ALLOCATE (bsizes_RI_ext(natom*ncell_RI))
    4748         126 :       ALLOCATE (bsizes_RI_ext_split(nblks_RI*ncell_RI))
    4749         294 :       DO i_RI = 1, ncell_RI
    4750         756 :          bsizes_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = ri_data%bsizes_RI(:)
    4751         756 :          dist_RI_ext((i_RI - 1)*natom + 1:i_RI*natom) = dist_RI(:)
    4752        1334 :          bsizes_RI_ext_split((i_RI - 1)*nblks_RI + 1:i_RI*nblks_RI) = ri_data%bsizes_RI_split(:)
    4753             :       END DO
    4754             :       !Template (AO AO | RI_ext) with atomic block sizes, used for the private derivative tensors
    4755          42 :       CALL dbt_distribution_new(t_dist, pgrid, dist_AO_1, dist_AO_2, dist_RI_ext)
    4756             :       CALL dbt_create(t_3c_template, "KP_3c_der", t_dist, [1, 2], [3], &
    4757          42 :                       ri_data%bsizes_AO, ri_data%bsizes_AO, bsizes_RI_ext)
    4758          42 :       CALL dbt_distribution_destroy(t_dist)
    4759             : 
    4760        7404 :       ALLOCATE (t_3c_der_RI_prv(nimg, 1, 3), t_3c_der_AO_prv(nimg, 1, 3))
    4761         168 :       DO i_xyz = 1, 3
    4762        3198 :          DO i_img = 1, nimg
    4763        3030 :             CALL dbt_create(t_3c_template, t_3c_der_RI_prv(i_img, 1, i_xyz))
    4764        3156 :             CALL dbt_create(t_3c_template, t_3c_der_AO_prv(i_img, 1, i_xyz))
    4765             :          END DO
    4766             :       END DO
    4767          42 :       CALL dbt_destroy(t_3c_template)
    4768             :       !3d distribution on the same process grid, used for the 3c neighbor lists
    4769          42 :       CALL dbt_mp_environ_pgrid(pgrid, pdims, pcoord)
    4770          42 :       CALL mp_comm_t3c%create(pgrid%mp_comm_2d, 3, pdims)
    4771             :       CALL distribution_3d_create(dist_3d, dist_AO_1, dist_AO_2, dist_RI, &
    4772          42 :                                   nkind, particle_set, mp_comm_t3c, own_comm=.TRUE.)
    4773          42 :       DEALLOCATE (dist_RI, dist_AO_1, dist_AO_2)
    4774          42 :       CALL dbt_pgrid_destroy(pgrid)
    4775             : 
    4776             :       CALL build_3c_neighbor_lists(nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, dist_3d, ri_data%ri_metric, &
    4777          42 :                                    "HFX_3c_nl", qs_env, op_pos=2, sym_jk=.FALSE., own_dist=.TRUE.)
    4778             :       !Batches over the RI blocks: only the block ranges are kept, the other two outputs are discarded
    4779          42 :       n_mem = ri_data%n_mem
    4780             :       CALL create_tensor_batches(ri_data%bsizes_RI, n_mem, dummy_start, dummy_end, &
    4781             :                                  start_blocks, end_blocks)
    4782          42 :       DEALLOCATE (dummy_start, dummy_end)
    4783             :       !Template (RI_ext | AO AO) with split block sizes on pgrid_2, used for the output tensors
    4784             :       CALL create_3c_tensor(t_3c_template, dist_RI, dist_AO_1, dist_AO_2, ri_data%pgrid_2, &
    4785             :                             bsizes_RI_ext_split, ri_data%bsizes_AO_split, ri_data%bsizes_AO_split, &
    4786          42 :                             map1=[1], map2=[2, 3], name="der (RI | AO AO)")
    4787         168 :       DO i_xyz = 1, 3
    4788        3198 :          DO i_img = 1, nimg
    4789        3030 :             CALL dbt_create(t_3c_template, t_3c_der_RI(i_img, i_xyz))
    4790        3156 :             CALL dbt_create(t_3c_template, t_3c_der_AO(i_img, i_xyz))
    4791             :          END DO
    4792             :       END DO
    4793             :       !Build the derivatives batch by batch and accumulate them into the output tensors
    4794         116 :       DO i_mem = 1, n_mem
    4795             :          CALL build_3c_derivatives(t_3c_der_AO_prv, t_3c_der_RI_prv, ri_data%filter_eps, qs_env, &
    4796             :                                    nl_3c, basis_set_AO, basis_set_AO, basis_set_RI, &
    4797             :                                    ri_data%ri_metric, der_eps=ri_data%eps_schwarz_forces, op_pos=2, &
    4798             :                                    do_kpoints=.TRUE., do_hfx_kpoints=.TRUE., &
    4799             :                                    bounds_k=[start_blocks(i_mem), end_blocks(i_mem)], &
    4800         222 :                                    RI_range=ri_data%kp_RI_range, img_to_RI_cell=ri_data%img_to_RI_cell)
    4801             : 
    4802          74 :          CALL timeset(routineN//"_cpy", handle2)
    4803             :          !We go from (mu^0 sigma^i | P^j) to (P^i| sigma^j mu^0) and finally to (P^i| mu^0 sigma^j)
    4804        1994 :          DO i_img = 1, nimg
    4805        7754 :             DO i_xyz = 1, 3
    4806             :                !derivative wrt mu^0 (empty tensors are skipped)
    4807        5760 :                CALL get_tensor_occupancy(t_3c_der_AO_prv(i_img, 1, i_xyz), nze, occ)
    4808        5760 :                IF (nze > 0) THEN
    4809             :                   CALL dbt_copy(t_3c_der_AO_prv(i_img, 1, i_xyz), t_3c_template, &
    4810        3434 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4811        3434 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4812             :                   CALL dbt_copy(t_3c_template, t_3c_der_AO(i_img, i_xyz), &
    4813        3434 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4814             :                END IF
    4815             : 
    4816             :                !derivative wrt P^i (empty tensors are skipped)
    4817        5760 :                CALL get_tensor_occupancy(t_3c_der_RI_prv(i_img, 1, i_xyz), nze, occ)
    4818       13440 :                IF (nze > 0) THEN
    4819             :                   CALL dbt_copy(t_3c_der_RI_prv(i_img, 1, i_xyz), t_3c_template, &
    4820        3424 :                                 order=[3, 2, 1], move_data=.TRUE.)
    4821        3424 :                   CALL dbt_filter(t_3c_template, ri_data%filter_eps)
    4822             :                   CALL dbt_copy(t_3c_template, t_3c_der_RI(i_img, i_xyz), &
    4823        3424 :                                 order=[1, 3, 2], move_data=.TRUE., summation=.TRUE.)
    4824             :                END IF
    4825             :             END DO
    4826             :          END DO
    4827         190 :          CALL timestop(handle2)
    4828             :       END DO
    4829          42 :       CALL dbt_destroy(t_3c_template)
    4830             :       !Clean up the neighbor list and the private derivative tensors
    4831          42 :       CALL neighbor_list_3c_destroy(nl_3c)
    4832         168 :       DO i_xyz = 1, 3
    4833        3198 :          DO i_img = 1, nimg
    4834        3030 :             CALL dbt_destroy(t_3c_der_RI_prv(i_img, 1, i_xyz))
    4835        3156 :             CALL dbt_destroy(t_3c_der_AO_prv(i_img, 1, i_xyz))
    4836             :          END DO
    4837             :       END DO
    4838        6102 :       DEALLOCATE (t_3c_der_RI_prv, t_3c_der_AO_prv)
    4839             : 
    4840             :       !Reorder 3c derivatives to be consistent with ints
    4841          42 :       CALL reorder_3c_derivs(t_3c_der_RI, ri_data)
    4842          42 :       CALL reorder_3c_derivs(t_3c_der_AO, ri_data)
    4843             : 
    4844          42 :       CALL timeset(routineN//"_2c", handle2)
    4845             :       !The 2-center derivatives
    4846          42 :       CALL cp_dbcsr_dist2d_to_dist(dist_2d, dbcsr_dist)
    4847         126 :       ALLOCATE (row_bsize(SIZE(ri_data%bsizes_RI)))
    4848          84 :       ALLOCATE (col_bsize(SIZE(ri_data%bsizes_RI)))
    4849         126 :       row_bsize(:) = ri_data%bsizes_RI
    4850         126 :       col_bsize(:) = ri_data%bsizes_RI
    4851             :       !Non-symmetric (RI | RI) template with atomic block sizes
    4852             :       CALL dbcsr_create(dbcsr_template, "2c_der", dbcsr_dist, dbcsr_type_no_symmetry, &
    4853          42 :                         row_bsize, col_bsize)
    4854          42 :       CALL dbcsr_distribution_release(dbcsr_dist)
    4855          42 :       DEALLOCATE (col_bsize, row_bsize)
    4856             : 
    4857        3282 :       ALLOCATE (mat_der_metric(nimg, 3))
    4858         168 :       DO i_xyz = 1, 3
    4859        3198 :          DO i_img = 1, nimg
    4860        3030 :             CALL dbcsr_create(mat_der_pot(i_img, i_xyz), template=dbcsr_template)
    4861        3156 :             CALL dbcsr_create(mat_der_metric(i_img, i_xyz), template=dbcsr_template)
    4862             :          END DO
    4863             :       END DO
    4864          42 :       CALL dbcsr_release(dbcsr_template)
    4865             : 
    4866             :       !HFX potential derivatives
    4867             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%hfx_pot, &
    4868          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4869             :       CALL build_2c_derivatives(mat_der_pot, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4870          42 :                                 basis_set_RI, basis_set_RI, ri_data%hfx_pot, do_kpoints=.TRUE.)
    4871          42 :       CALL release_neighbor_list_sets(nl_2c)
    4872             : 
    4873             :       !RI metric derivatives (the neighbor list reuses the name "HFX_2c_nl_pot")
    4874             :       CALL build_2c_neighbor_lists(nl_2c, basis_set_RI, basis_set_RI, ri_data%ri_metric, &
    4875          42 :                                    "HFX_2c_nl_pot", qs_env, sym_ij=.FALSE., dist_2d=dist_2d)
    4876             :       CALL build_2c_derivatives(mat_der_metric, ri_data%filter_eps_2c, qs_env, nl_2c, &
    4877          42 :                                 basis_set_RI, basis_set_RI, ri_data%ri_metric, do_kpoints=.TRUE.)
    4878          42 :       CALL release_neighbor_list_sets(nl_2c)
    4879             : 
    4880             :       !Get into extended RI basis and tensor format, then release the metric derivative matrices
    4881         168 :       DO i_xyz = 1, 3
    4882         378 :          DO iatom = 1, natom
    4883         252 :             CALL dbt_create(ri_data%t_2c_inv(1, 1), t_2c_der_metric(iatom, i_xyz))
    4884             :             CALL get_ext_2c_int(t_2c_der_metric(iatom, i_xyz), mat_der_metric(:, i_xyz), &
    4885         378 :                                 iatom, iatom, 1, ri_data, qs_env)
    4886             :          END DO
    4887        3198 :          DO i_img = 1, nimg
    4888        3156 :             CALL dbcsr_release(mat_der_metric(i_img, i_xyz))
    4889             :          END DO
    4890             :       END DO
    4891          42 :       CALL timestop(handle2)
    4892             : 
    4893          42 :       CALL timestop(handle)
    4894             : 
    4895         252 :    END SUBROUTINE precalc_derivatives
    4896             : 
    4897             : ! **************************************************************************************************
    4898             : !> \brief Update the forces due to the derivative of a 2-center product d/dR (Q|R)
    4899             : !> \param force ...
    4900             : !> \param t_2c_contr A precontracted tensor containing sum_abcdPS (ab|P)(P|Q)^-1 (R|S)^-1 (S|cd) P_ac P_bd
    4901             : !> \param t_2c_der the d/dR (Q|R) tensor, in all 3 cartesian directions
    4902             : !> \param atom_of_kind ...
    4903             : !> \param kind_of ...
    4904             : !> \param img in which periodic image the second center of the tensor is
    4905             : !> \param pref ...
    4906             : !> \param ri_data ...
    4907             : !> \param qs_env ...
    4908             : !> \param work_virial ...
    4909             : !> \param cell ...
    4910             : !> \param particle_set ...
    4911             : !> \param diag ...
    4912             : !> \param offdiag ...
    4913             : !> \note IMPORTANT: t_tc_contr and t_2c_der need to have the same distribution. Atomic block sizes are
    4914             : !>                  assumed
    4915             : ! **************************************************************************************************
    4916        2895 :    SUBROUTINE get_2c_der_force(force, t_2c_contr, t_2c_der, atom_of_kind, kind_of, img, pref, &
    4917             :                                ri_data, qs_env, work_virial, cell, particle_set, diag, offdiag)
    4918             : 
    4919             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    4920             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_2c_contr
    4921             :       TYPE(dbt_type), DIMENSION(:), INTENT(INOUT)        :: t_2c_der
    4922             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of
    4923             :       INTEGER, INTENT(IN)                                :: img
    4924             :       REAL(dp), INTENT(IN)                               :: pref
    4925             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    4926             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    4927             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    4928             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    4929             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    4930             :          POINTER                                         :: particle_set
    4931             :       LOGICAL, INTENT(IN), OPTIONAL                      :: diag, offdiag
    4932             : 
    4933             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'get_2c_der_force'
    4934             : 
    4935             :       INTEGER                                            :: handle, i_img, i_RI, i_xyz, iat, &
    4936             :                                                             iat_of_kind, ikind, j_img, j_RI, &
    4937             :                                                             j_xyz, jat, jat_of_kind, jkind, natom
    4938             :       INTEGER, DIMENSION(2)                              :: ind
    4939        2895 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    4940             :       LOGICAL                                            :: found, my_diag, my_offdiag, use_virial
    4941             :       REAL(dp)                                           :: new_force
    4942        2895 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :), TARGET     :: contr_blk, der_blk
    4943             :       REAL(dp), DIMENSION(3)                             :: scoord
    4944             :       TYPE(dbt_iterator_type)                            :: iter
    4945             :       TYPE(kpoint_type), POINTER                         :: kpoints
    4946             : 
    4947        2895 :       NULLIFY (kpoints, index_to_cell)
    4948             : 
    4949             :       !Loop over the blocks of d/dR (Q|R), contract with the corresponding block of t_2c_contr and
    4950             :       !update the relevant force
    4951             : 
    4952        2895 :       CALL timeset(routineN, handle)
    4953             : 
    4954        2895 :       use_virial = .FALSE.
    4955        2895 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    4956             : 
    4957        2895 :       my_diag = .FALSE.
    4958        2895 :       IF (PRESENT(diag)) my_diag = diag
    4959             : 
    4960        2316 :       my_offdiag = .FALSE.
    4961        2316 :       IF (PRESENT(diag)) my_offdiag = offdiag
    4962             : 
    4963        2895 :       CALL get_qs_env(qs_env, kpoints=kpoints, natom=natom)
    4964        2895 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    4965             : 
    4966             : !$OMP PARALLEL DEFAULT(NONE) &
    4967             : !$OMP SHARED(t_2c_der,t_2c_contr,work_virial,force,use_virial,natom,index_to_cell,ri_data,img) &
    4968             : !$OMP SHARED(pref,atom_of_kind,kind_of,particle_set,cell,my_diag,my_offdiag) &
    4969             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk,contr_blk,found,new_force,i_RI,i_img,j_RI,j_img) &
    4970        2895 : !$OMP PRIVATE(iat,jat,iat_of_kind,jat_of_kind,ikind,jkind,scoord)
    4971             :       DO i_xyz = 1, 3
    4972             :          CALL dbt_iterator_start(iter, t_2c_der(i_xyz))
    4973             :          DO WHILE (dbt_iterator_blocks_left(iter))
    4974             :             CALL dbt_iterator_next_block(iter, ind)
    4975             : 
    4976             :             !Only take forecs due to block diagonal or block off-diagonal, depending on arguments
    4977             :             IF ((my_diag .AND. .NOT. my_offdiag) .OR. (.NOT. my_diag .AND. my_offdiag)) THEN
    4978             :                IF (my_diag .AND. (ind(1) .NE. ind(2))) CYCLE
    4979             :                IF (my_offdiag .AND. (ind(1) == ind(2))) CYCLE
    4980             :             END IF
    4981             : 
    4982             :             CALL dbt_get_block(t_2c_der(i_xyz), ind, der_blk, found)
    4983             :             CPASSERT(found)
    4984             :             CALL dbt_get_block(t_2c_contr, ind, contr_blk, found)
    4985             : 
    4986             :             IF (found) THEN
    4987             : 
    4988             :                !an element of d/dR (Q|R) corresponds to 2 things because of translational invariance
    4989             :                !(Q'| R) = - (Q| R'), once wrt the center on Q, and once on R
    4990             :                new_force = pref*SUM(der_blk(:, :)*contr_blk(:, :))
    4991             : 
    4992             :                i_RI = (ind(1) - 1)/natom + 1
    4993             :                i_img = ri_data%RI_cell_to_img(i_RI)
    4994             :                iat = ind(1) - (i_RI - 1)*natom
    4995             :                iat_of_kind = atom_of_kind(iat)
    4996             :                ikind = kind_of(iat)
    4997             : 
    4998             :                j_RI = (ind(2) - 1)/natom + 1
    4999             :                j_img = ri_data%RI_cell_to_img(j_RI)
    5000             :                jat = ind(2) - (j_RI - 1)*natom
    5001             :                jat_of_kind = atom_of_kind(jat)
    5002             :                jkind = kind_of(jat)
    5003             : 
    5004             :                !Force on iatom (first center)
    5005             : !$OMP ATOMIC
    5006             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5007             :                                                           + new_force
    5008             : 
    5009             :                IF (use_virial) THEN
    5010             : 
    5011             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5012             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_img), dp)
    5013             : 
    5014             :                   DO j_xyz = 1, 3
    5015             : !$OMP ATOMIC
    5016             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5017             :                   END DO
    5018             :                END IF
    5019             : 
    5020             :                !Force on jatom (second center)
    5021             : !$OMP ATOMIC
    5022             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5023             :                                                           - new_force
    5024             : 
    5025             :                IF (use_virial) THEN
    5026             : 
    5027             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5028             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, j_img) + index_to_cell(:, img), dp)
    5029             : 
    5030             :                   DO j_xyz = 1, 3
    5031             : !$OMP ATOMIC
    5032             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) - new_force*scoord(j_xyz)
    5033             :                   END DO
    5034             :                END IF
    5035             : 
    5036             :                DEALLOCATE (contr_blk)
    5037             :             END IF
    5038             : 
    5039             :             DEALLOCATE (der_blk)
    5040             :          END DO !iter
    5041             :          CALL dbt_iterator_stop(iter)
    5042             : 
    5043             :       END DO !i_xyz
    5044             : !$OMP END PARALLEL
    5045        2895 :       CALL timestop(handle)
    5046             : 
    5047        5790 :    END SUBROUTINE get_2c_der_force
    5048             : 
    5049             : ! **************************************************************************************************
    5050             : !> \brief This routine calculates the force contribution from a trace over 3D tensors, i.e.
    5051             : !>        force = sum_ijk A_ijk B_ijk. The B tensor is (P^0| sigma^0 lambda^img), with P in the
    5052             : !>        extended RI basis. Note that all tensors are stacked along the 3rd dimension
    5053             : !> \param force force container; its fock_4c component is updated for the atoms of all 3 centers
    5054             : !> \param t_3c_contr tensor whose blocks are iterated over and contracted with the derivative blocks
    5055             : !> \param t_3c_der_1 derivative tensor (3 cartesian directions) giving the force on the 1st (RI) center
    5056             : !> \param t_3c_der_2 derivative tensor (3 cartesian directions) giving the force on the 2nd (AO) center
    5057             : !> \param atom_of_kind for each atom, its index within its kind (second index of fock_4c)
    5058             : !> \param kind_of for each atom, the index of its kind (index into force)
    5059             : !> \param idx_to_at_RI maps an RI block index (within one RI cell) to its atom
    5060             : !> \param idx_to_at_AO maps an AO block index to its atom
    5061             : !> \param i_images image indices; entry lb_img - 1 + idx is the image of the idx-th stacked tensor
    5062             : !> \param lb_img offset into i_images for the first tensor of the stack
    5063             : !> \param pref prefactor multiplying each block contraction
    5064             : !> \param ri_data provides bsizes_RI_split, bsizes_AO_split and RI_cell_to_img
    5065             : !> \param qs_env provides the kpoints (index_to_cell)
    5066             : !> \param work_virial virial accumulator; only updated if cell and particle_set are present too
    5067             : !> \param cell cell used to get scaled atomic coordinates for the virial
    5068             : !> \param particle_set atomic positions used for the virial
    5069             : ! **************************************************************************************************
    5070        1666 :    SUBROUTINE get_force_from_3c_trace(force, t_3c_contr, t_3c_der_1, t_3c_der_2, atom_of_kind, kind_of, &
    5071        3332 :                                       idx_to_at_RI, idx_to_at_AO, i_images, lb_img, pref, &
    5072             :                                       ri_data, qs_env, work_virial, cell, particle_set)
    5073             : 
    5074             :       TYPE(qs_force_type), DIMENSION(:), POINTER         :: force
    5075             :       TYPE(dbt_type), INTENT(INOUT)                      :: t_3c_contr
    5076             :       TYPE(dbt_type), DIMENSION(3), INTENT(INOUT)        :: t_3c_der_1, t_3c_der_2
    5077             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_of_kind, kind_of, idx_to_at_RI, &
    5078             :                                                             idx_to_at_AO, i_images
    5079             :       INTEGER, INTENT(IN)                                :: lb_img
    5080             :       REAL(dp), INTENT(IN)                               :: pref
    5081             :       TYPE(hfx_ri_type), INTENT(INOUT)                   :: ri_data
    5082             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    5083             :       REAL(dp), DIMENSION(3, 3), INTENT(INOUT), OPTIONAL :: work_virial
    5084             :       TYPE(cell_type), OPTIONAL, POINTER                 :: cell
    5085             :       TYPE(particle_type), DIMENSION(:), OPTIONAL, &
    5086             :          POINTER                                         :: particle_set
    5087             : 
    5088             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'get_force_from_3c_trace'
    5089             : 
    5090             :       INTEGER :: handle, i_RI, i_xyz, iat, iat_of_kind, idx, ikind, j_xyz, jat, jat_of_kind, &
    5091             :          jkind, kat, kat_of_kind, kkind, nblks_AO, nblks_RI, RI_img
    5092             :       INTEGER, DIMENSION(3)                              :: ind
    5093        1666 :       INTEGER, DIMENSION(:, :), POINTER                  :: index_to_cell
    5094             :       LOGICAL                                            :: found, found_1, found_2, use_virial
    5095             :       REAL(dp)                                           :: new_force
    5096        1666 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :), TARGET  :: contr_blk, der_blk_1, der_blk_2, &
    5097        1666 :                                                             der_blk_3
    5098             :       REAL(dp), DIMENSION(3)                             :: scoord
    5099             :       TYPE(dbt_iterator_type)                            :: iter
    5100             :       TYPE(kpoint_type), POINTER                         :: kpoints
    5101             : 
    5102        1666 :       NULLIFY (kpoints, index_to_cell)
    5103             : 
    5104        1666 :       CALL timeset(routineN, handle)
    5105             : 
    5106        1666 :       CALL get_qs_env(qs_env, kpoints=kpoints)
    5107        1666 :       CALL get_kpoint_info(kpoints, index_to_cell=index_to_cell)
    5108             : 
    5109        1666 :       nblks_RI = SIZE(ri_data%bsizes_RI_split)
    5110        1666 :       nblks_AO = SIZE(ri_data%bsizes_AO_split)
    5111             : 
    5112        1666 :       use_virial = .FALSE.
    5113        1666 :       IF (PRESENT(work_virial) .AND. PRESENT(cell) .AND. PRESENT(particle_set)) use_virial = .TRUE.
    5114             : 
    5115             : !$OMP PARALLEL DEFAULT(NONE) &
    5116             : !$OMP SHARED(t_3c_der_1, t_3c_der_2,t_3c_contr,work_virial,force,use_virial,index_to_cell,i_images,lb_img) &
    5117             : !$OMP SHARED(pref,idx_to_at_AO,atom_of_kind,kind_of,particle_set,cell,idx_to_at_RI,ri_data,nblks_RI,nblks_AO) &
    5118             : !$OMP PRIVATE(i_xyz,j_xyz,iter,ind,der_blk_1,contr_blk,found,new_force,iat,iat_of_kind,ikind,scoord) &
    5119        1666 : !$OMP PRIVATE(jat,kat,jat_of_kind,kat_of_kind,jkind,kkind,i_RI,RI_img,der_blk_2,der_blk_3,found_1,found_2,idx)
    5120             :       CALL dbt_iterator_start(iter, t_3c_contr)
    5121             :       DO WHILE (dbt_iterator_blocks_left(iter))
    5122             :          CALL dbt_iterator_next_block(iter, ind)
    5123             : 
    5124             :          CALL dbt_get_block(t_3c_contr, ind, contr_blk, found)
    5125             :          IF (found) THEN
    5126             : 
    5127             :             DO i_xyz = 1, 3
    5128             :                CALL dbt_get_block(t_3c_der_1(i_xyz), ind, der_blk_1, found_1)
    5129             :                IF (.NOT. found_1) THEN !missing derivative block: use zeros with the shape of contr_blk
    5130             :                   DEALLOCATE (der_blk_1)
    5131             :                   ALLOCATE (der_blk_1(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5132             :                   der_blk_1(:, :, :) = 0.0_dp
    5133             :                END IF
    5134             :                CALL dbt_get_block(t_3c_der_2(i_xyz), ind, der_blk_2, found_2)
    5135             :                IF (.NOT. found_2) THEN !missing derivative block: use zeros with the shape of contr_blk
    5136             :                   DEALLOCATE (der_blk_2)
    5137             :                   ALLOCATE (der_blk_2(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5138             :                   der_blk_2(:, :, :) = 0.0_dp
    5139             :                END IF
    5140             : 
    5141             :                ALLOCATE (der_blk_3(SIZE(contr_blk, 1), SIZE(contr_blk, 2), SIZE(contr_blk, 3)))
    5142             :                der_blk_3(:, :, :) = -(der_blk_1(:, :, :) + der_blk_2(:, :, :)) !3rd center: minus the sum of the others
    5143             : 
    5144             :                !We assume the tensors are in the format (P^0| sigma^0 mu^a+c-b), with P a member of the
    5145             :                !extended RI basis set
    5146             : 
    5147             :                !Force for the first center (RI extended basis, zero cell)
    5148             :                new_force = pref*SUM(der_blk_1(:, :, :)*contr_blk(:, :, :))
    5149             : 
    5150             :                i_RI = (ind(1) - 1)/nblks_RI + 1
    5151             :                RI_img = ri_data%RI_cell_to_img(i_RI)
    5152             :                iat = idx_to_at_RI(ind(1) - (i_RI - 1)*nblks_RI)
    5153             :                iat_of_kind = atom_of_kind(iat)
    5154             :                ikind = kind_of(iat)
    5155             : 
    5156             : !$OMP ATOMIC
    5157             :                force(ikind)%fock_4c(i_xyz, iat_of_kind) = force(ikind)%fock_4c(i_xyz, iat_of_kind) &
    5158             :                                                           + new_force
    5159             : 
    5160             :                IF (use_virial) THEN
    5161             : 
    5162             :                   CALL real_to_scaled(scoord, pbc(particle_set(iat)%r, cell), cell)
    5163             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, RI_img), dp)
    5164             : 
    5165             :                   DO j_xyz = 1, 3
    5166             : !$OMP ATOMIC
    5167             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5168             :                   END DO
    5169             :                END IF
    5170             : 
    5171             :                !Force with respect to the second center (AO basis, zero cell)
    5172             :                new_force = pref*SUM(der_blk_2(:, :, :)*contr_blk(:, :, :))
    5173             :                jat = idx_to_at_AO(ind(2))
    5174             :                jat_of_kind = atom_of_kind(jat)
    5175             :                jkind = kind_of(jat)
    5176             : 
    5177             : !$OMP ATOMIC
    5178             :                force(jkind)%fock_4c(i_xyz, jat_of_kind) = force(jkind)%fock_4c(i_xyz, jat_of_kind) &
    5179             :                                                           + new_force
    5180             : 
    5181             :                IF (use_virial) THEN
    5182             : 
    5183             :                   CALL real_to_scaled(scoord, pbc(particle_set(jat)%r, cell), cell)
    5184             : 
    5185             :                   DO j_xyz = 1, 3
    5186             : !$OMP ATOMIC
    5187             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5188             :                   END DO
    5189             :                END IF
    5190             : 
    5191             :                !Force with respect to the third center (AO basis, apc_img - b_img)
    5192             :                !Note: tensors are stacked along the 3rd direction
    5193             :                new_force = pref*SUM(der_blk_3(:, :, :)*contr_blk(:, :, :))
    5194             :                idx = (ind(3) - 1)/nblks_AO + 1
    5195             :                kat = idx_to_at_AO(ind(3) - (idx - 1)*nblks_AO)
    5196             :                kat_of_kind = atom_of_kind(kat)
    5197             :                kkind = kind_of(kat)
    5198             : 
    5199             : !$OMP ATOMIC
    5200             :                force(kkind)%fock_4c(i_xyz, kat_of_kind) = force(kkind)%fock_4c(i_xyz, kat_of_kind) &
    5201             :                                                           + new_force
    5202             : 
    5203             :                IF (use_virial) THEN
    5204             :                   CALL real_to_scaled(scoord, pbc(particle_set(kat)%r, cell), cell)
    5205             :                   scoord(:) = scoord(:) + REAL(index_to_cell(:, i_images(lb_img - 1 + idx)), dp)
    5206             : 
    5207             :                   DO j_xyz = 1, 3
    5208             : !$OMP ATOMIC
    5209             :                      work_virial(i_xyz, j_xyz) = work_virial(i_xyz, j_xyz) + new_force*scoord(j_xyz)
    5210             :                   END DO
    5211             :                END IF
    5212             : 
    5213             :                DEALLOCATE (der_blk_1, der_blk_2, der_blk_3)
    5214             :             END DO !i_xyz
    5215             :             DEALLOCATE (contr_blk)
    5216             :          END IF !found
    5217             :       END DO !iter
    5218             :       CALL dbt_iterator_stop(iter)
    5219             : !$OMP END PARALLEL
    5220        1666 :       CALL timestop(handle)
    5221             : 
    5222        3332 :    END SUBROUTINE get_force_from_3c_trace
    5223             : 
    5224             : END MODULE

Generated by: LCOV version 1.15