LCOV - code coverage report
Current view: top level - src - xas_tdp_correction.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 655 661 99.1 %
Date: 2024-12-21 06:28:57 Functions: 11 11 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Second order perturbation correction to XAS_TDP spectra (i.e. shift)
      10             : !> \author A. Bussy (01.2020)
      11             : ! **************************************************************************************************
      12             : 
      13             : MODULE xas_tdp_correction
      14             :    USE admm_types,                      ONLY: admm_type
      15             :    USE admm_utils,                      ONLY: admm_correct_for_eigenvalues
      16             :    USE bibliography,                    ONLY: Bussy2021b,&
      17             :                                               Shigeta2001,&
      18             :                                               cite_reference
      19             :    USE cp_array_utils,                  ONLY: cp_1d_i_p_type,&
      20             :                                               cp_1d_r_p_type
      21             :    USE cp_blacs_env,                    ONLY: cp_blacs_env_type
      22             :    USE cp_cfm_types,                    ONLY: cp_cfm_create,&
      23             :                                               cp_cfm_get_submatrix,&
      24             :                                               cp_cfm_release,&
      25             :                                               cp_cfm_type,&
      26             :                                               cp_fm_to_cfm
      27             :    USE cp_control_types,                ONLY: dft_control_type
      28             :    USE cp_dbcsr_api,                    ONLY: &
      29             :         dbcsr_copy, dbcsr_create, dbcsr_distribution_get, dbcsr_distribution_new, &
      30             :         dbcsr_distribution_release, dbcsr_distribution_type, dbcsr_get_info, dbcsr_p_type, &
      31             :         dbcsr_release, dbcsr_type
      32             :    USE cp_dbcsr_operations,             ONLY: copy_fm_to_dbcsr,&
      33             :                                               cp_dbcsr_sm_fm_multiply,&
      34             :                                               dbcsr_deallocate_matrix_set
      35             :    USE cp_fm_diag,                      ONLY: choose_eigv_solver
      36             :    USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
      37             :                                               cp_fm_struct_p_type,&
      38             :                                               cp_fm_struct_release,&
      39             :                                               cp_fm_struct_type
      40             :    USE cp_fm_types,                     ONLY: cp_fm_create,&
      41             :                                               cp_fm_get_diag,&
      42             :                                               cp_fm_get_submatrix,&
      43             :                                               cp_fm_release,&
      44             :                                               cp_fm_to_fm,&
      45             :                                               cp_fm_to_fm_submat,&
      46             :                                               cp_fm_type
      47             :    USE cp_log_handling,                 ONLY: cp_logger_get_default_io_unit
      48             :    USE dbt_api,                         ONLY: &
      49             :         dbt_contract, dbt_copy, dbt_copy_matrix_to_tensor, dbt_create, dbt_default_distvec, &
      50             :         dbt_destroy, dbt_distribution_destroy, dbt_distribution_new, dbt_distribution_type, &
      51             :         dbt_finalize, dbt_get_block, dbt_get_info, dbt_iterator_blocks_left, &
      52             :         dbt_iterator_next_block, dbt_iterator_start, dbt_iterator_stop, dbt_iterator_type, &
      53             :         dbt_pgrid_create, dbt_pgrid_destroy, dbt_pgrid_type, dbt_put_block, dbt_type
      54             :    USE hfx_admm_utils,                  ONLY: create_admm_xc_section
      55             :    USE input_section_types,             ONLY: section_vals_create,&
      56             :                                               section_vals_get_subs_vals,&
      57             :                                               section_vals_release,&
      58             :                                               section_vals_retain,&
      59             :                                               section_vals_set_subs_vals,&
      60             :                                               section_vals_type
      61             :    USE kinds,                           ONLY: dp
      62             :    USE machine,                         ONLY: m_flush
      63             :    USE mathlib,                         ONLY: complex_diag
      64             :    USE message_passing,                 ONLY: mp_para_env_type
      65             :    USE parallel_gemm_api,               ONLY: parallel_gemm
      66             :    USE physcon,                         ONLY: evolt
      67             :    USE qs_environment_types,            ONLY: get_qs_env,&
      68             :                                               qs_environment_type
      69             :    USE qs_ks_methods,                   ONLY: qs_ks_build_kohn_sham_matrix
      70             :    USE qs_mo_types,                     ONLY: deallocate_mo_set,&
      71             :                                               duplicate_mo_set,&
      72             :                                               get_mo_set,&
      73             :                                               mo_set_type,&
      74             :                                               reassign_allocated_mos
      75             :    USE util,                            ONLY: get_limit
      76             :    USE xas_tdp_kernel,                  ONLY: contract2_AO_to_doMO,&
      77             :                                               ri_all_blocks_mm
      78             :    USE xas_tdp_types,                   ONLY: donor_state_type,&
      79             :                                               xas_tdp_control_type,&
      80             :                                               xas_tdp_env_type
      81             : 
      82             : !$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num
      83             : #include "./base/base_uses.f90"
      84             : 
      85             :    IMPLICIT NONE
      86             :    PRIVATE
      87             : 
      88             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'xas_tdp_correction'
      89             : 
      90             :    PUBLIC :: gw2x_shift, get_soc_splitting
      91             : 
      92             : CONTAINS
      93             : 
      94             : ! **************************************************************************************************
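      95             : !> \brief Computes the ionization potential using the GW2X method of Shigeta et al. The result can
      96             : !>        be used for XAS correction (shift) or XPS directly.
      97             : !> \param donor_state core donor state (MO indices, excited atom index, contraction coefficients)
      98             : !> \param xas_tdp_env provides the LUMO coefficients; the generalized Fock matrix is stored in it
      99             : !> \param xas_tdp_control control options (pseudo_canonical, do_uks/do_roks, n_search)
     100             : !> \param qs_env source of the MOs and KS matrices; temporarily modified to build the Fock matrix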
     101             : ! **************************************************************************************************
     102          30 :    SUBROUTINE GW2X_shift(donor_state, xas_tdp_env, xas_tdp_control, qs_env)
     103             : 
     104             :       TYPE(donor_state_type), POINTER                    :: donor_state
     105             :       TYPE(xas_tdp_env_type), POINTER                    :: xas_tdp_env
     106             :       TYPE(xas_tdp_control_type), POINTER                :: xas_tdp_control
     107             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     108             : 
     109             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'GW2X_shift'
     110             : 
     111             :       INTEGER :: ex_idx, exat, first_domo(2), handle, i, ido_mo, iloc, ilocat, ispin, jspin, &
     112             :          locat, nao, natom, ndo_mo, nhomo(2), nlumo(2), nonloc, nspins, start_sgf
     113          30 :       INTEGER, DIMENSION(:), POINTER                     :: nsgf_blk
     114             :       LOGICAL                                            :: pseudo_canonical
     115             :       REAL(dp)                                           :: og_hfx_frac
     116          30 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: contract_coeffs_backup
     117             :       TYPE(admm_type), POINTER                           :: admm_env
     118          30 :       TYPE(cp_1d_r_p_type), ALLOCATABLE, DIMENSION(:)    :: homo_evals, lumo_evals
     119             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
     120             :       TYPE(cp_fm_struct_p_type), ALLOCATABLE, &
     121          30 :          DIMENSION(:)                                    :: all_struct, homo_struct, lumo_struct
     122             :       TYPE(cp_fm_struct_type), POINTER                   :: hoho_struct, lulu_struct
     123             :       TYPE(cp_fm_type)                                   :: hoho_fock, hoho_work, homo_work, &
     124             :                                                             lulu_fock, lulu_work, lumo_work
     125          30 :       TYPE(cp_fm_type), ALLOCATABLE, DIMENSION(:)        :: all_coeffs, homo_coeffs, lumo_coeffs
     126             :       TYPE(cp_fm_type), POINTER                          :: mo_coeff
     127          30 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: dbcsr_work, fock_matrix, matrix_ks
     128          30 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: ja_X, oI_Y
     129          30 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :)       :: mo_template
     130             :       TYPE(dft_control_type), POINTER                    :: dft_control
     131          30 :       TYPE(mo_set_type), DIMENSION(:), POINTER           :: mos
     132             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     133             :       TYPE(section_vals_type), POINTER                   :: xc_fun_empty, xc_fun_original, xc_section
     134             : 
     135          30 :       NULLIFY (xc_fun_empty, xc_fun_original, xc_section, mos, dft_control, dbcsr_work, &
     136          30 :                fock_matrix, matrix_ks, para_env, mo_coeff, blacs_env, nsgf_blk)
     137             : 
     138          30 :       CALL cite_reference(Shigeta2001)
     139          30 :       CALL cite_reference(Bussy2021b)
     140             : 
     141          30 :       CALL timeset(routineN, handle)
     142             : 
     143             :       !The GW2X correction we want to compute goes like this, where omega is the corrected epsilon_I:
     144             :       !omega = eps_I + 0.5 * sum_ajk |<Ia||jk>|^2/(omega + eps_a - eps_j - eps_k)
     145             :       !              + 0.5 * sum_jab |<Ij||ab>|^2/(omega + eps_j - eps_a - eps_b)
     146             :       ! j,k denote occupied spin-orbitals and a,b denote virtual spin orbitals
     147             : 
     148             :       !The strategy is the following (we assume restricted closed-shell):
     149             :       !1) Get the LUMOs from xas_tdp_env
     150             :       !2) Get the HOMOs from qs_env
     151             :       !3) Compute or fetch the generalized Fock matrix
     152             :       !4) Diagonalize it in the subspace of HOMOs and LUMOs (or just take diagonal matrix elements)
     153             :       !5) Build the full HOMO-LUMO basis that we will use and compute eigenvalues
     154             :       !6) Iterate over GW2X steps to compute the self energy
     155             : 
     156             :       !We implement 2 approaches => diagonal elements of Fock matrix with original MOs and
     157             :       !pseudo-canonical MOs
     158          30 :       pseudo_canonical = xas_tdp_control%pseudo_canonical
     159             : 
     160             :       !Get donor state info
     161          30 :       ndo_mo = donor_state%ndo_mo
     162          30 :       nspins = 1; IF (xas_tdp_control%do_uks .OR. xas_tdp_control%do_roks) nspins = 2
     163             : 
     164             :       !1) Get the LUMO coefficients from the xas_tdp_env, that have been precomputed
     165             : 
     166             :       CALL get_qs_env(qs_env, matrix_ks=matrix_ks, mos=mos, para_env=para_env, &
     167          30 :                       blacs_env=blacs_env, natom=natom)
     168             : 
     169         158 :       ALLOCATE (lumo_struct(nspins), lumo_coeffs(nspins))
     170             : 
     171          64 :       DO ispin = 1, nspins
     172          34 :          CALL get_mo_set(mos(ispin), homo=nhomo(ispin), nao=nao)
     173          34 :          nlumo(ispin) = nao - nhomo(ispin)
     174             : 
     175             :          CALL cp_fm_struct_create(lumo_struct(ispin)%struct, para_env=para_env, context=blacs_env, &
     176          34 :                                   ncol_global=nlumo(ispin), nrow_global=nao)
     177             : 
     178          34 :          CALL cp_fm_create(lumo_coeffs(ispin), lumo_struct(ispin)%struct)
     179          98 :          CALL cp_fm_to_fm(xas_tdp_env%lumo_evecs(ispin), lumo_coeffs(ispin))
     180             :       END DO
     181             : 
     182             :       !2) get the HOMO coeffs. Reminder: keep all non-localized MOs + those localized on core atom
     183             :       !   For this to work, it is assumed that the LOCALIZE keyword is used
     184         188 :       ALLOCATE (homo_struct(nspins), homo_coeffs(nspins))
     185             :       !nonloc: MOs beyond the first n_search ones; locat: searched MOs flagged 1 for the excited atom
     186          64 :       DO ispin = 1, nspins
     187          34 :          nonloc = nhomo(ispin) - xas_tdp_control%n_search
     188          34 :          exat = donor_state%at_index
     189          80 :          ex_idx = MINLOC(ABS(xas_tdp_env%ex_atom_indices - exat), 1)
     190         156 :          locat = COUNT(xas_tdp_env%mos_of_ex_atoms(:, ex_idx, ispin) == 1)
     191             : 
     192             :          CALL cp_fm_struct_create(homo_struct(ispin)%struct, para_env=para_env, context=blacs_env, &
     193          34 :                                   ncol_global=locat + nonloc, nrow_global=nao)
     194          34 :          CALL cp_fm_create(homo_coeffs(ispin), homo_struct(ispin)%struct)
     195             : 
     196          34 :          CALL get_mo_set(mos(ispin), mo_coeff=mo_coeff)
     197             :          CALL cp_fm_to_fm_submat(mo_coeff, homo_coeffs(ispin), nrow=nao, ncol=nonloc, s_firstrow=1, &
     198          34 :                                  s_firstcol=xas_tdp_control%n_search + 1, t_firstrow=1, t_firstcol=locat + 1)
     199             : 
     200             :          !this bit is taken from xas_tdp_methods
     201          34 :          ilocat = 1
     202         156 :          DO iloc = 1, xas_tdp_control%n_search
     203         122 :             IF (xas_tdp_env%mos_of_ex_atoms(iloc, ex_idx, ispin) == -1) CYCLE
     204             :             CALL cp_fm_to_fm_submat(mo_coeff, homo_coeffs(ispin), nrow=nao, ncol=1, s_firstrow=1, &
     205          82 :                                     s_firstcol=iloc, t_firstrow=1, t_firstcol=ilocat)
     206             :             !keep track of donor MO index
     207          82 :             IF (iloc == donor_state%mo_indices(1, ispin)) first_domo(ispin) = ilocat !first donor MO
     208             : 
     209         156 :             ilocat = ilocat + 1
     210             :          END DO
     211          64 :          nhomo(ispin) = locat + nonloc
     212             :       END DO
     213             : 
     214             :       !3) Computing the generalized Fock Matrix, if not there already
     215          30 :       IF (ASSOCIATED(xas_tdp_env%fock_matrix)) THEN
     216          12 :          fock_matrix => xas_tdp_env%fock_matrix
     217             :       ELSE
     218          18 :          BLOCK
     219          18 :             TYPE(mo_set_type), DIMENSION(:), ALLOCATABLE :: backup_mos
     220             : 
     221          56 :             ALLOCATE (xas_tdp_env%fock_matrix(nspins))
     222          18 :             fock_matrix => xas_tdp_env%fock_matrix
     223             : 
     224             :             ! remove the xc_functionals and set HF fraction to 1
     225          18 :             xc_section => section_vals_get_subs_vals(qs_env%input, "DFT%XC")
     226          18 :             xc_fun_original => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
     227          18 :             CALL section_vals_retain(xc_fun_original)
     228          18 :             CALL section_vals_create(xc_fun_empty, xc_fun_original%section)
     229          18 :             CALL section_vals_set_subs_vals(xc_section, "XC_FUNCTIONAL", xc_fun_empty)
     230          18 :             CALL section_vals_release(xc_fun_empty)
     231          18 :             og_hfx_frac = qs_env%x_data(1, 1)%general_parameter%fraction
     232          54 :             qs_env%x_data(:, :)%general_parameter%fraction = 1.0_dp
     233             : 
     234             :             !In case of ADMM, we need to re-create the admm XC section for the new hfx_fraction
     235             :             !We also need to make a backup of the MOs as they are modified
     236          18 :             CALL get_qs_env(qs_env, dft_control=dft_control, admm_env=admm_env)
     237          18 :             IF (dft_control%do_admm) THEN
     238           2 :                IF (ASSOCIATED(admm_env%xc_section_primary)) CALL section_vals_release(admm_env%xc_section_primary)
     239           2 :                IF (ASSOCIATED(admm_env%xc_section_aux)) CALL section_vals_release(admm_env%xc_section_aux)
     240           2 :                CALL create_admm_xc_section(qs_env%x_data, xc_section, admm_env)
     241             : 
     242          10 :                ALLOCATE (backup_mos(SIZE(mos)))
     243           4 :                DO i = 1, SIZE(mos)
     244           4 :                   CALL duplicate_mo_set(backup_mos(i), mos(i))
     245             :                END DO
     246             :             END IF
     247             :             !Keep a copy of the current KS matrices, they are put back after the rebuild below
     248          56 :             ALLOCATE (dbcsr_work(nspins))
     249          38 :             DO ispin = 1, nspins
     250          20 :                ALLOCATE (dbcsr_work(ispin)%matrix)
     251          38 :                CALL dbcsr_copy(dbcsr_work(ispin)%matrix, matrix_ks(ispin)%matrix)
     252             :             END DO
     253             : 
     254             :             !both spins treated internally
     255          18 :             CALL qs_ks_build_kohn_sham_matrix(qs_env, calculate_forces=.FALSE., just_energy=.FALSE.)
     256             :             !Save the rebuilt matrix as the Fock matrix, then copy the saved KS matrix back
     257          38 :             DO ispin = 1, nspins
     258          20 :                ALLOCATE (fock_matrix(ispin)%matrix)
     259          20 :                CALL dbcsr_copy(fock_matrix(ispin)%matrix, matrix_ks(ispin)%matrix, name="FOCK MATRIX")
     260          20 :                CALL dbcsr_release(matrix_ks(ispin)%matrix)
     261          38 :                CALL dbcsr_copy(matrix_ks(ispin)%matrix, dbcsr_work(ispin)%matrix)
     262             :             END DO
     263          18 :             CALL dbcsr_deallocate_matrix_set(dbcsr_work)
     264             : 
     265             :             !In case of ADMM, we want to correct for eigenvalues
     266          18 :             IF (dft_control%do_admm) THEN
     267           4 :                DO ispin = 1, nspins
     268           4 :                   CALL admm_correct_for_eigenvalues(ispin, admm_env, fock_matrix(ispin)%matrix)
     269             :                END DO
     270             :             END IF
     271             : 
     272             :             !restore xc and HF fraction
     273          18 :             CALL section_vals_set_subs_vals(xc_section, "XC_FUNCTIONAL", xc_fun_original)
     274          18 :             CALL section_vals_release(xc_fun_original)
     275          54 :             qs_env%x_data(:, :)%general_parameter%fraction = og_hfx_frac
     276             : 
     277          20 :             IF (dft_control%do_admm) THEN
     278           2 :                IF (ASSOCIATED(admm_env%xc_section_primary)) CALL section_vals_release(admm_env%xc_section_primary)
     279           2 :                IF (ASSOCIATED(admm_env%xc_section_aux)) CALL section_vals_release(admm_env%xc_section_aux)
     280           2 :                CALL create_admm_xc_section(qs_env%x_data, xc_section, admm_env)
     281             :                !NOTE(review): presumably restores the MOs from their backup - confirm reassign_allocated_mos
     282           4 :                DO i = 1, SIZE(mos)
     283           2 :                   CALL reassign_allocated_mos(mos(i), backup_mos(i))
     284           4 :                   CALL deallocate_mo_set(backup_mos(i))
     285             :                END DO
     286           2 :                DEALLOCATE (backup_mos)
     287             :             END IF
     288             :          END BLOCK
     289             :       END IF
     290             : 
     291             :       !4,5) Build pseudo-canonical MOs if needed + get related Fock matrix elements
     292         188 :       ALLOCATE (all_struct(nspins), all_coeffs(nspins))
     293         158 :       ALLOCATE (homo_evals(nspins), lumo_evals(nspins))
     294          30 :       CALL dbcsr_get_info(matrix_ks(1)%matrix, row_blk_size=nsgf_blk)
     295         120 :       ALLOCATE (contract_coeffs_backup(nsgf_blk(exat), nspins*ndo_mo))
     296             :       !contract_coeffs_backup is only filled (and later restored) in the pseudo-canonical case
     297          64 :       DO ispin = 1, nspins
     298             :          CALL cp_fm_struct_create(hoho_struct, para_env=para_env, context=blacs_env, &
     299          34 :                                   ncol_global=nhomo(ispin), nrow_global=nhomo(ispin))
     300             :          CALL cp_fm_struct_create(lulu_struct, para_env=para_env, context=blacs_env, &
     301          34 :                                   ncol_global=nlumo(ispin), nrow_global=nlumo(ispin))
     302             : 
     303          34 :          CALL cp_fm_create(hoho_work, hoho_struct)
     304          34 :          CALL cp_fm_create(lulu_work, lulu_struct)
     305          34 :          CALL cp_fm_create(homo_work, homo_struct(ispin)%struct)
     306          34 :          CALL cp_fm_create(lumo_work, lumo_struct(ispin)%struct)
     307             : 
     308          34 :          IF (pseudo_canonical) THEN
     309             :             !That is where we rotate the MOs to make them pseudo canonical
     310             :             !The eigenvalues we get from the diagonalization
     311             : 
     312             :             !The Fock matrix in the HOMO subspace
     313          32 :             CALL cp_fm_create(hoho_fock, hoho_struct)
     314          32 :             NULLIFY (homo_evals(ispin)%array)
     315          96 :             ALLOCATE (homo_evals(ispin)%array(nhomo(ispin)))
     316             :             CALL cp_dbcsr_sm_fm_multiply(fock_matrix(ispin)%matrix, homo_coeffs(ispin), &
     317          32 :                                          homo_work, ncol=nhomo(ispin))
     318             :             CALL parallel_gemm('T', 'N', nhomo(ispin), nhomo(ispin), nao, 1.0_dp, homo_coeffs(ispin), &
     319          32 :                                homo_work, 0.0_dp, hoho_fock)
     320             : 
     321             :             !diagonalize and get pseudo-canonical MOs
     322          32 :             CALL choose_eigv_solver(hoho_fock, hoho_work, homo_evals(ispin)%array)
     323             :             CALL parallel_gemm('N', 'N', nao, nhomo(ispin), nhomo(ispin), 1.0_dp, homo_coeffs(ispin), &
     324          32 :                                hoho_work, 0.0_dp, homo_work)
     325          32 :             CALL cp_fm_to_fm(homo_work, homo_coeffs(ispin))
     326             : 
     327             :             !back up the donor_state's contract coeffs, then overwrite them with the rotated HOMO ones
     328             :             contract_coeffs_backup(:, (ispin - 1)*ndo_mo + 1:ispin*ndo_mo) = &
     329        1052 :                donor_state%contract_coeffs(:, (ispin - 1)*ndo_mo + 1:ispin*ndo_mo)
     330          46 :             start_sgf = SUM(nsgf_blk(1:exat - 1)) + 1
     331             :             CALL cp_fm_get_submatrix(homo_coeffs(ispin), &
     332             :                                      donor_state%contract_coeffs(:, (ispin - 1)*ndo_mo + 1:ispin*ndo_mo), &
     333             :                                      start_row=start_sgf, start_col=first_domo(ispin), &
     334          32 :                                      n_rows=nsgf_blk(exat), n_cols=ndo_mo)
     335             : 
     336             :             !do the same for the pseudo-LUMOs
     337          32 :             CALL cp_fm_create(lulu_fock, lulu_struct)
     338          32 :             NULLIFY (lumo_evals(ispin)%array)
     339          96 :             ALLOCATE (lumo_evals(ispin)%array(nlumo(ispin)))
     340             :             CALL cp_dbcsr_sm_fm_multiply(fock_matrix(ispin)%matrix, lumo_coeffs(ispin), &
     341          32 :                                          lumo_work, ncol=nlumo(ispin))
     342             :             CALL parallel_gemm('T', 'N', nlumo(ispin), nlumo(ispin), nao, 1.0_dp, lumo_coeffs(ispin), &
     343          32 :                                lumo_work, 0.0_dp, lulu_fock)
     344             : 
     345             :             !diagonalize and get pseudo-canonical MOs
     346          32 :             CALL choose_eigv_solver(lulu_fock, lulu_work, lumo_evals(ispin)%array)
     347             :             CALL parallel_gemm('N', 'N', nao, nlumo(ispin), nlumo(ispin), 1.0_dp, lumo_coeffs(ispin), &
     348          32 :                                lulu_work, 0.0_dp, lumo_work)
     349          32 :             CALL cp_fm_to_fm(lumo_work, lumo_coeffs(ispin))
     350             : 
     351          32 :             CALL cp_fm_release(lulu_fock)
     352          96 :             CALL cp_fm_release(hoho_fock)
     353             : 
     354             :          ELSE !using the generalized Fock matrix diagonal elements
     355             : 
     356             :             !Compute their Fock matrix diagonal
     357           6 :             ALLOCATE (homo_evals(ispin)%array(nhomo(ispin)))
     358             :             CALL cp_dbcsr_sm_fm_multiply(fock_matrix(ispin)%matrix, homo_coeffs(ispin), &
     359           2 :                                          homo_work, ncol=nhomo(ispin))
     360             :             CALL parallel_gemm('T', 'N', nhomo(ispin), nhomo(ispin), nao, 1.0_dp, homo_coeffs(ispin), &
     361           2 :                                homo_work, 0.0_dp, hoho_work)
     362           2 :             CALL cp_fm_get_diag(hoho_work, homo_evals(ispin)%array)
     363             : 
     364           6 :             ALLOCATE (lumo_evals(ispin)%array(nlumo(ispin)))
     365             :             CALL cp_dbcsr_sm_fm_multiply(fock_matrix(ispin)%matrix, lumo_coeffs(ispin), &
     366           2 :                                          lumo_work, ncol=nlumo(ispin))
     367             :             CALL parallel_gemm('T', 'N', nlumo(ispin), nlumo(ispin), nao, 1.0_dp, lumo_coeffs(ispin), &
     368           2 :                                lumo_work, 0.0_dp, lulu_work)
     369           2 :             CALL cp_fm_get_diag(lulu_work, lumo_evals(ispin)%array)
     370             : 
     371             :          END IF
     372          34 :          CALL cp_fm_release(homo_work)
     373          34 :          CALL cp_fm_release(hoho_work)
     374          34 :          CALL cp_fm_struct_release(hoho_struct)
     375          34 :          CALL cp_fm_release(lumo_work)
     376          34 :          CALL cp_fm_release(lulu_work)
     377          34 :          CALL cp_fm_struct_release(lulu_struct)
     378             : 
     379             :          !Put back homo and lumo coeffs together, to fit tensor structure
     380             :          CALL cp_fm_struct_create(all_struct(ispin)%struct, para_env=para_env, context=blacs_env, &
     381          34 :                                   ncol_global=nhomo(ispin) + nlumo(ispin), nrow_global=nao)
     382          34 :          CALL cp_fm_create(all_coeffs(ispin), all_struct(ispin)%struct)
     383             :          CALL cp_fm_to_fm(homo_coeffs(ispin), all_coeffs(ispin), ncol=nhomo(ispin), &
     384          34 :                           source_start=1, target_start=1)
     385             :          CALL cp_fm_to_fm(lumo_coeffs(ispin), all_coeffs(ispin), ncol=nlumo(ispin), &
     386          98 :                           source_start=1, target_start=nhomo(ispin) + 1)
     387             : 
     388             :       END DO !ispin
     389             : 
     390             :       !get semi-contracted tensor (AOs to MOs, keep RI uncontracted)
     391             :       CALL contract_AOs_to_MOs(ja_X, oI_Y, mo_template, all_coeffs, nhomo, nlumo, &
     392          30 :                                donor_state, xas_tdp_env, xas_tdp_control, qs_env)
     393             : 
     394             :       !intermediate clean-up
     395          64 :       DO ispin = 1, nspins
     396          34 :          CALL cp_fm_release(all_coeffs(ispin))
     397          34 :          CALL cp_fm_release(homo_coeffs(ispin))
     398          34 :          CALL cp_fm_release(lumo_coeffs(ispin))
     399          34 :          CALL cp_fm_struct_release(all_struct(ispin)%struct)
     400          34 :          CALL cp_fm_struct_release(lumo_struct(ispin)%struct)
     401          64 :          CALL cp_fm_struct_release(homo_struct(ispin)%struct)
     402             :       END DO
     403             : 
     404             :       !6) GW2X iterations
     405             : 
     406          30 :       IF (nspins == 1) THEN
     407             :          !restricted-closed shell: only alpha spin
     408             :          CALL GW2X_rcs_iterations(first_domo(1), ja_X(1), oI_Y, mo_template(1, 1), homo_evals(1)%array, &
     409          26 :                                   lumo_evals(1)%array, donor_state, xas_tdp_control, qs_env)
     410             :       ELSE
     411             :          !open-shell, need both spins
     412             :          CALL GW2X_os_iterations(first_domo, ja_X, oI_Y, mo_template, homo_evals, lumo_evals, &
     413           4 :                                  donor_state, xas_tdp_control, qs_env)
     414             :       END IF
     415             : 
     416             :       !restore proper contract_coeffs
     417          30 :       IF (pseudo_canonical) THEN
     418        1048 :          donor_state%contract_coeffs(:, :) = contract_coeffs_backup(:, :)
     419             :       END IF
     420             : 
     421             :       !Final clean-up
     422          72 :       DO ido_mo = 1, nspins*ndo_mo
     423          72 :          CALL dbt_destroy(oI_Y(ido_mo))
     424             :       END DO
     425          64 :       DO ispin = 1, nspins
     426          34 :          CALL dbt_destroy(ja_X(ispin))
     427          34 :          DEALLOCATE (homo_evals(ispin)%array)
     428          34 :          DEALLOCATE (lumo_evals(ispin)%array)
     429         106 :          DO jspin = 1, nspins
     430          76 :             CALL dbt_destroy(mo_template(ispin, jspin))
     431             :          END DO
     432             :       END DO
     433          72 :       DEALLOCATE (oI_Y, homo_evals, lumo_evals)
     434             : 
     435          30 :       CALL timestop(handle)
     436             : 
     437         166 :    END SUBROUTINE GW2X_shift
     438             : 
     439             : ! **************************************************************************************************
     440             : !> \brief Performs the GW2X iterations in the restricted closed-shell formalism according to the
     441             : !>        Newton-Raphson method
     442             : !> \param first_domo index of the first core donor MO to consider
     443             : !> \param ja_X semi-contracted tensor with j: occupied MO, a: virtual MO, X: RI basis element
     444             : !> \param oI_Y semi-contracted tensors with o: all MOs, I donor core MO, Y: RI basis element
     445             : !> \param mo_template tensor template for fully MO contracted tensor
     446             : !> \param homo_evals presumably the occupied MO eigenvalues (size nhomo); eps_I is read from it
     447             : !> \param lumo_evals presumably the virtual MO eigenvalues (size nlumo)
     448             : !> \param donor_state provides ndo_mo and mo_indices; gw2x_evals(:, 1) receives the final omega
     449             : !> \param xas_tdp_control provides gw2x_eps, max_gw2x_iter, c_os, c_ss and batch_size
     450             : !> \param qs_env only used here to retrieve para_env
     451             : ! **************************************************************************************************
     452          26 :    SUBROUTINE GW2X_rcs_iterations(first_domo, ja_X, oI_Y, mo_template, homo_evals, lumo_evals, &
     453             :                                   donor_state, xas_tdp_control, qs_env)
     454             : 
     455             :       INTEGER, INTENT(IN)                                :: first_domo
     456             :       TYPE(dbt_type), INTENT(inout)                      :: ja_X
     457             :       TYPE(dbt_type), DIMENSION(:), INTENT(inout)        :: oI_Y
     458             :       TYPE(dbt_type), INTENT(inout)                      :: mo_template
     459             :       REAL(dp), DIMENSION(:), INTENT(IN)                 :: homo_evals, lumo_evals
     460             :       TYPE(donor_state_type), POINTER                    :: donor_state
     461             :       TYPE(xas_tdp_control_type), POINTER                :: xas_tdp_control
     462             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     463             : 
     464             :       CHARACTER(len=*), PARAMETER :: routineN = 'GW2X_rcs_iterations'
     465             : 
     466             :       INTEGER :: batch_size, bounds_1d(2), bounds_2d(2, 2), handle, i, ibatch, ido_mo, iloop, &
     467             :          max_iter, nbatch_occ, nbatch_virt, nblk_occ, nblk_virt, nblks(3), ndo_mo, nhomo, nlumo, &
     468             :          occ_bo(2), output_unit, tmp_sum, virt_bo(2)
     469          26 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: mo_blk_size
     470             :       REAL(dp)                                           :: c_os, c_ss, dg, diff, ds1, ds2, eps_I, &
     471             :                                                             eps_iter, g, omega_k, parts(4), s1, s2
     472         858 :       TYPE(dbt_type)                                     :: aj_Ib, aj_Ib_diff, aj_X, ja_Ik, &
     473         234 :                                                             ja_Ik_diff
     474             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     475             : 
     476          26 :       CALL timeset(routineN, handle)
     477             : 
     478          26 :       eps_iter = xas_tdp_control%gw2x_eps
     479          26 :       max_iter = xas_tdp_control%max_gw2x_iter
     480          26 :       c_os = xas_tdp_control%c_os
     481          26 :       c_ss = xas_tdp_control%c_ss
     482          26 :       batch_size = xas_tdp_control%batch_size
     483             : 
     484          26 :       ndo_mo = donor_state%ndo_mo
     485          26 :       output_unit = cp_logger_get_default_io_unit()
     486             : 
     487          26 :       nhomo = SIZE(homo_evals)
     488          26 :       nlumo = SIZE(lumo_evals)
     489             : 
     490          26 :       CALL get_qs_env(qs_env, para_env=para_env)
     491             : 
     492             :       !We use the Newton-Raphson method to find the zero of the function:
     493             :       !g(omega) = eps_I - omega + mp2 terms, dg(omega) = -1 + d/d_omega (mp2 terms)
     494             :       !We simply compute at each iteration: omega_k+1 = omega_k - g(omega_k)/dg(omega_k)
     495             : 
     496             :       !need transposed tensor of (ja|X) for optimal contraction scheme (s.t. (aj|X) block is on same
     497             :       !processor as (ja|X))
     498          26 :       CALL dbt_create(ja_X, aj_X)
     499          26 :       CALL dbt_copy(ja_X, aj_X, order=[2, 1, 3])
     500             : 
     501             :       !split the MO blocks into batches for memory friendly batched contraction, so that
     502             :       !huge dense tensors never need to be stored. NOTE(review): nblk_occ and nblk_virt are only set
     503          26 :       CALL dbt_get_info(ja_X, nblks_total=nblks)
     504          78 :       ALLOCATE (mo_blk_size(nblks(1)))
     505          26 :       CALL dbt_get_info(ja_X, blk_size_1=mo_blk_size)
     506             : 
     507          26 :       tmp_sum = 0
     508          52 :       DO i = 1, nblks(1)
     509          52 :          tmp_sum = tmp_sum + mo_blk_size(i)
     510          52 :          IF (tmp_sum == nhomo) THEN
     511          26 :             nblk_occ = i
     512          26 :             nblk_virt = nblks(1) - i
     513          26 :             EXIT
     514             :          END IF
     515             :       END DO
     516          26 :       nbatch_occ = MAX(1, nblk_occ/batch_size)
     517          26 :       nbatch_virt = MAX(1, nblk_virt/batch_size)
     518             : 
     519             :       !Loop over the ndo_mo donor MOs, each one gets its own Newton-Raphson iteration
     520          60 :       DO ido_mo = 1, ndo_mo
     521          34 :          IF (output_unit > 0) THEN
     522             :             WRITE (UNIT=output_unit, FMT="(/,T5,A,I2,A,I4,A,/,T5,A)") &
     523          17 :                "- GW2X correction for donor MO with spin ", 1, &
     524          17 :                " and MO index ", donor_state%mo_indices(ido_mo, 1), ":", &
     525          34 :                "                         iteration                convergence (eV)"
     526          17 :             CALL m_flush(output_unit)
     527             :          END IF
     528             : 
     529             :          !starting values: omega starts at eps_I, diff is set above eps_iter to enter the loop
     530          34 :          eps_I = homo_evals(first_domo + ido_mo - 1)
     531          34 :          omega_k = eps_I
     532          34 :          iloop = 0
     533          34 :          diff = 2.0_dp*eps_iter
     534             : 
     535         168 :          DO WHILE (ABS(diff) > eps_iter)
     536         134 :             iloop = iloop + 1
     537             : 
     538             :             !Compute the mp2 terms and their first derivative, accumulated in parts(1:4)
     539         134 :             parts = 0.0_dp
     540             : 
     541             :             !We do batched contraction for (ja|Ik) and (ja|Ib) to never have to carry the full tensor
     542         276 :             DO ibatch = 1, nbatch_occ
     543             : 
     544         142 :                occ_bo = get_limit(nblk_occ, nbatch_occ, ibatch - 1)
     545         710 :                bounds_1d = [SUM(mo_blk_size(1:occ_bo(1) - 1)) + 1, SUM(mo_blk_size(1:occ_bo(2)))]
     546             : 
     547         142 :                CALL dbt_create(mo_template, ja_Ik)
     548             :                CALL dbt_contract(alpha=1.0_dp, tensor_1=ja_X, tensor_2=oI_Y(ido_mo), &
     549             :                                  beta=0.0_dp, tensor_3=ja_Ik, contract_1=[3], &
     550             :                                  notcontract_1=[1, 2], contract_2=[2], notcontract_2=[1], &
     551         142 :                                  map_1=[1, 2], map_2=[3], bounds_3=bounds_1d)
     552             : 
     553             :                !opposite-spin contribution, accumulated in parts(1) and parts(2)
     554             :                CALL calc_os_oov_contrib(parts(1), parts(2), ja_Ik, homo_evals, lumo_evals, homo_evals, &
     555         142 :                                         omega_k, c_os, nhomo)
     556             : 
     557         426 :                bounds_2d(:, 2) = bounds_1d
     558         142 :                bounds_2d(1, 1) = nhomo + 1
     559         142 :                bounds_2d(2, 1) = nhomo + nlumo
     560             : 
     561             :                !same-spin contribution. Contraction only needed if c_ss != 0
     562             :                !directly compute the difference (ja|Ik) - (ka|Ij)
     563         142 :                IF (ABS(c_ss) > EPSILON(1.0_dp)) THEN
     564             : 
     565         142 :                   CALL dbt_create(ja_Ik, ja_Ik_diff, map1_2d=[1], map2_2d=[2, 3])
     566         142 :                   CALL dbt_copy(ja_Ik, ja_Ik_diff, move_data=.TRUE.)
     567             : 
     568             :                   CALL dbt_contract(alpha=-1.0_dp, tensor_1=oI_Y(ido_mo), tensor_2=aj_X, &
     569             :                                     beta=1.0_dp, tensor_3=ja_Ik_diff, contract_1=[2], &
     570             :                                     notcontract_1=[1], contract_2=[3], notcontract_2=[1, 2], &
     571         426 :                                     map_1=[1], map_2=[2, 3], bounds_2=[1, nhomo], bounds_3=bounds_2d)
     572             : 
     573         142 :                   CALL calc_ss_oov_contrib(parts(1), parts(2), ja_Ik_diff, homo_evals, lumo_evals, omega_k, c_ss)
     574             : 
     575         142 :                   CALL dbt_destroy(ja_Ik_diff)
     576             :                END IF !c_ss != 0
     577             : 
     578         276 :                CALL dbt_destroy(ja_Ik)
     579             :             END DO
     580             : 
     581         268 :             DO ibatch = 1, nbatch_virt
     582             : 
     583         134 :                virt_bo = get_limit(nblk_virt, nbatch_virt, ibatch - 1)
     584             :                bounds_1d = [SUM(mo_blk_size(1:nblk_occ + virt_bo(1) - 1)) + 1, &
     585        1252 :                             SUM(mo_blk_size(1:nblk_occ + virt_bo(2)))]
     586             : 
     587         134 :                CALL dbt_create(mo_template, aj_Ib)
     588             :                CALL dbt_contract(alpha=1.0_dp, tensor_1=aj_X, tensor_2=oI_Y(ido_mo), &
     589             :                                  beta=0.0_dp, tensor_3=aj_Ib, contract_1=[3], &
     590             :                                  notcontract_1=[1, 2], contract_2=[2], notcontract_2=[1], &
     591         134 :                                  map_1=[1, 2], map_2=[3], bounds_3=bounds_1d)
     592             : 
     593             :                !opposite-spin contribution, accumulated in parts(3) and parts(4)
     594             :                CALL calc_os_ovv_contrib(parts(3), parts(4), aj_Ib, lumo_evals, homo_evals, lumo_evals, &
     595         134 :                                         omega_k, c_os, nhomo, nhomo)
     596             : 
     597             :                !same-spin contribution, only if c_ss is not 0
     598             :                !directly compute the difference (aj|Ib) - (bj|Ia)
     599         134 :                IF (ABS(c_ss) > EPSILON(1.0_dp)) THEN
     600         134 :                   bounds_2d(1, 1) = 1
     601         134 :                   bounds_2d(2, 1) = nhomo
     602         402 :                   bounds_2d(:, 2) = bounds_1d
     603             : 
     604         134 :                   CALL dbt_create(aj_Ib, aj_Ib_diff, map1_2d=[1], map2_2d=[2, 3])
     605         134 :                   CALL dbt_copy(aj_Ib, aj_Ib_diff, move_data=.TRUE.)
     606             : 
     607             :                   CALL dbt_contract(alpha=-1.0_dp, tensor_1=oI_Y(ido_mo), tensor_2=ja_X, &
     608             :                                     beta=1.0_dp, tensor_3=aj_Ib_diff, contract_1=[2], &
     609             :                                     notcontract_1=[1], contract_2=[3], notcontract_2=[1, 2], &
     610             :                                     map_1=[1], map_2=[2, 3], &
     611         402 :                                     bounds_2=[nhomo + 1, nhomo + nlumo], bounds_3=bounds_2d)
     612             : 
     613         134 :                   CALL calc_ss_ovv_contrib(parts(3), parts(4), aj_Ib_diff, homo_evals, lumo_evals, omega_k, c_ss)
     614             : 
     615         134 :                   CALL dbt_destroy(aj_Ib_diff)
     616             :                END IF ! c_ss not 0
     617             : 
     618         268 :                CALL dbt_destroy(aj_Ib)
     619             :             END DO
     620             : 
     621         134 :             CALL para_env%sum(parts)
     622         134 :             s1 = parts(1); ds1 = parts(2)
     623         134 :             s2 = parts(3); ds2 = parts(4)
     624             : 
     625             :             !evaluate g and its derivative
     626         134 :             g = eps_I - omega_k + s1 + s2
     627         134 :             dg = -1.0_dp + ds1 + ds2
     628             : 
     629             :             !compute the diff to the new step (Newton-Raphson update -g/dg)
     630         134 :             diff = -g/dg
     631             : 
     632             :             !and the new omega; diff is then scaled by evolt, so printing and the loop test use that value
     633         134 :             omega_k = omega_k + diff
     634         134 :             diff = diff*evolt
     635             : 
     636         134 :             IF (output_unit > 0) THEN
     637             :                WRITE (UNIT=output_unit, FMT="(T21,I18,F32.6)") &
     638          67 :                   iloop, diff
     639          67 :                CALL m_flush(output_unit)
     640             :             END IF
     641             : 
     642         168 :             IF (iloop > max_iter) THEN
     643           0 :                CPWARN("GW2X iteration not converged.")
     644           0 :                EXIT
     645             :             END IF
     646             :          END DO !while loop on eps_iter
     647             : 
     648             :          !store the last omega in donor_state; the shift w.r.t. energy_evals is only printed below
     649          34 :          donor_state%gw2x_evals(ido_mo, 1) = omega_k
     650             : 
     651          60 :          IF (output_unit > 0) THEN
     652             :             WRITE (UNIT=output_unit, FMT="(/T7,A,F11.6,/,T5,A,F11.6)") &
     653          17 :                "Final GW2X shift for this donor MO (eV):", &
     654          34 :                (donor_state%energy_evals(ido_mo, 1) - omega_k)*evolt
     655             :          END IF
     656             : 
     657             :       END DO !ido_mo
     658             : 
     659          26 :       CALL dbt_destroy(aj_X)
     660             : 
     661          26 :       CALL timestop(handle)
     662             : 
     663          52 :    END SUBROUTINE GW2X_rcs_iterations
     664             : 
     665             : ! **************************************************************************************************
     666             : !> \brief Performs the GW2X iterations in the open-shell formalism according to the
     667             : !>        Newton-Raphson method
     668             : !> \param first_domo index of the first core donor MO to consider, for each spin
     669             : !> \param ja_X semi-contracted tensors with j: occupied MO, a: virtual MO, X: RI basis element
     670             : !> \param oI_Y semi-contracted tensors with o: all MOs, I donor core MO, Y: RI basis element
     671             : !> \param mo_template tensor template for fully MO contracted tensor, for each spin combination
     672             : !> \param homo_evals ...
     673             : !> \param lumo_evals ...
     674             : !> \param donor_state ...
     675             : !> \param xas_tdp_control ...
     676             : !> \param qs_env ...
     677             : ! **************************************************************************************************
     678           4 :    SUBROUTINE GW2X_os_iterations(first_domo, ja_X, oI_Y, mo_template, homo_evals, lumo_evals, &
     679             :                                  donor_state, xas_tdp_control, qs_env)
     680             : 
     681             :       INTEGER, INTENT(IN)                                :: first_domo(2)
     682             :       TYPE(dbt_type), DIMENSION(:), INTENT(inout)        :: ja_X, oI_Y
     683             :       TYPE(dbt_type), DIMENSION(:, :), INTENT(inout)     :: mo_template
     684             :       TYPE(cp_1d_r_p_type), DIMENSION(:), INTENT(in)     :: homo_evals, lumo_evals
     685             :       TYPE(donor_state_type), POINTER                    :: donor_state
     686             :       TYPE(xas_tdp_control_type), POINTER                :: xas_tdp_control
     687             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     688             : 
     689             :       CHARACTER(len=*), PARAMETER :: routineN = 'GW2X_os_iterations'
     690             : 
     691             :       INTEGER :: batch_size, bounds_1d(2), bounds_2d(2, 2), handle, i, ibatch, ido_mo, iloop, &
     692             :          ispin, max_iter, nbatch_occ, nbatch_virt, nblk_occ, nblk_virt, nblks(3), ndo_mo, &
     693             :          nhomo(2), nlumo(2), nspins, occ_bo(2), other_spin, output_unit, tmp_sum, virt_bo(2)
     694           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: mo_blk_size
     695             :       REAL(dp)                                           :: c_os, c_ss, dg, diff, ds1, ds2, eps_I, &
     696             :                                                             eps_iter, g, omega_k, parts(4), s1, s2
     697         136 :       TYPE(dbt_type)                                     :: aj_Ib, aj_Ib_diff, ja_Ik, ja_Ik_diff
     698           4 :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:)          :: aj_X
     699             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     700             : 
     701           4 :       CALL timeset(routineN, handle)
     702             : 
     703           4 :       eps_iter = xas_tdp_control%gw2x_eps
     704           4 :       max_iter = xas_tdp_control%max_gw2x_iter
     705           4 :       c_os = xas_tdp_control%c_os
     706           4 :       c_ss = xas_tdp_control%c_ss
     707           4 :       batch_size = xas_tdp_control%batch_size
     708             : 
     709           4 :       nspins = 2
     710           4 :       ndo_mo = donor_state%ndo_mo
     711           4 :       output_unit = cp_logger_get_default_io_unit()
     712             : 
     713          12 :       DO ispin = 1, nspins
     714           8 :          nhomo(ispin) = SIZE(homo_evals(ispin)%array)
     715          12 :          nlumo(ispin) = SIZE(lumo_evals(ispin)%array)
     716             :       END DO
     717             : 
     718           4 :       CALL get_qs_env(qs_env, para_env=para_env)
     719             : 
     720             :       !We use the Newton-Raphson method to find the zero of the function:
     721             :       !g(omega) = eps_I - omega + mp2 terms, dg(omega) = -1 + d/d_omega (mp2 terms)
     722             :       !We simply compute at each iteration: omega_k+1 = omega_k - g(omega_k)/dg(omega_k)
     723             : 
     724          48 :       ALLOCATE (aj_X(2))
     725          12 :       DO ispin = 1, nspins
     726             : 
     727             :          !need transposed tensor of (ja|X) for optimal contraction scheme,
     728             :          !s.t. (aj|X) block is on same processor as (ja|X)) and differences can be taken
     729           8 :          CALL dbt_create(ja_X(ispin), aj_X(ispin))
     730          12 :          CALL dbt_copy(ja_X(ispin), aj_X(ispin), order=[2, 1, 3])
     731             : 
     732             :       END DO ! ispin
     733          12 :       DO ispin = 1, nspins
     734             : 
     735           8 :          other_spin = 3 - ispin
     736             : 
     737             :          !split the MO blocks into batches for memory friendly batched contraction
     738             :          !huge dense tensors never need to be stored. Split MOs for the current spin
     739           8 :          CALL dbt_get_info(ja_X(ispin), nblks_total=nblks)
     740          24 :          ALLOCATE (mo_blk_size(nblks(1)))
     741           8 :          CALL dbt_get_info(ja_X(ispin), blk_size_1=mo_blk_size)
     742             : 
     743           8 :          tmp_sum = 0
     744          16 :          DO i = 1, nblks(1)
     745          16 :             tmp_sum = tmp_sum + mo_blk_size(i)
     746          16 :             IF (tmp_sum == nhomo(ispin)) THEN
     747           8 :                nblk_occ = i
     748           8 :                nblk_virt = nblks(1) - i
     749           8 :                EXIT
     750             :             END IF
     751             :          END DO
     752           8 :          nbatch_occ = MAX(1, nblk_occ/batch_size)
     753           8 :          nbatch_virt = MAX(1, nblk_virt/batch_size)
     754             : 
     755             :          !Loop over donor_states of the current spin
     756          16 :          DO ido_mo = 1, ndo_mo
     757           8 :             IF (output_unit > 0) THEN
     758             :                WRITE (UNIT=output_unit, FMT="(/,T5,A,I2,A,I4,A,/,T5,A)") &
     759           4 :                   "- GW2X correction for donor MO with spin ", ispin, &
     760           4 :                   " and MO index ", donor_state%mo_indices(ido_mo, ispin), ":", &
     761           8 :                   "                         iteration                convergence (eV)"
     762           4 :                CALL m_flush(output_unit)
     763             :             END IF
     764             : 
     765             :             !starting values
     766           8 :             eps_I = homo_evals(ispin)%array(first_domo(ispin) + ido_mo - 1)
     767           8 :             omega_k = eps_I
     768           8 :             iloop = 0
     769           8 :             diff = 2.0_dp*eps_iter
     770             : 
     771          40 :             DO WHILE (ABS(diff) > eps_iter)
     772          32 :                iloop = iloop + 1
     773             : 
     774             :                !Compute the mp2 terms and their first derivative
     775          32 :                parts = 0.0_dp
     776             : 
     777             :                !We do batched contraction for (ja|Ik) and (ja|Ib) to never have to carry the full tensor
     778          64 :                DO ibatch = 1, nbatch_occ
     779             : 
     780             :                   !opposite-spin contribution, i.e. (j_beta a_beta| I_alpha k_alpha) and vice-versa
     781             :                   !do the batching along k because same spin as donor MO
     782          32 :                   occ_bo = get_limit(nblk_occ, nbatch_occ, ibatch - 1)
     783         160 :                   bounds_1d = [SUM(mo_blk_size(1:occ_bo(1) - 1)) + 1, SUM(mo_blk_size(1:occ_bo(2)))]
     784             : 
     785          32 :                   CALL dbt_create(mo_template(other_spin, ispin), ja_Ik)
     786             :                   CALL dbt_contract(alpha=1.0_dp, tensor_1=ja_X(other_spin), &
     787             :                                     tensor_2=oI_Y((ispin - 1)*ndo_mo + ido_mo), &
     788             :                                     beta=0.0_dp, tensor_3=ja_Ik, contract_1=[3], &
     789             :                                     notcontract_1=[1, 2], contract_2=[2], notcontract_2=[1], &
     790          32 :                                     map_1=[1, 2], map_2=[3], bounds_3=bounds_1d)
     791             : 
     792             :                   CALL calc_os_oov_contrib(parts(1), parts(2), ja_Ik, homo_evals(other_spin)%array, &
     793             :                                            lumo_evals(other_spin)%array, homo_evals(ispin)%array, &
     794          32 :                                            omega_k, c_os, nhomo(other_spin))
     795             : 
     796          32 :                   CALL dbt_destroy(ja_Ik)
     797             : 
     798             :                   !same-spin contribution, need to compute (ja|Ik) - (ka|Ij), all with the current spin
     799             :                   !skip if c_ss == 0
     800          64 :                   IF (ABS(c_ss) > EPSILON(1.0_dp)) THEN
     801             : 
     802             :                      !same batching as opposite spin
     803          32 :                      CALL dbt_create(mo_template(ispin, ispin), ja_Ik)
     804             :                      CALL dbt_contract(alpha=1.0_dp, tensor_1=ja_X(ispin), &
     805             :                                        tensor_2=oI_Y((ispin - 1)*ndo_mo + ido_mo), &
     806             :                                        beta=0.0_dp, tensor_3=ja_Ik, contract_1=[3], &
     807             :                                        notcontract_1=[1, 2], contract_2=[2], notcontract_2=[1], &
     808          32 :                                        map_1=[1, 2], map_2=[3], bounds_3=bounds_1d)
     809             : 
     810          96 :                      bounds_2d(:, 2) = bounds_1d
     811          32 :                      bounds_2d(1, 1) = nhomo(ispin) + 1
     812          32 :                      bounds_2d(2, 1) = nhomo(ispin) + nlumo(ispin)
     813             : 
     814             :                      !the tensor difference is directly taken here
     815          32 :                      CALL dbt_create(ja_Ik, ja_Ik_diff, map1_2d=[1], map2_2d=[2, 3])
     816          32 :                      CALL dbt_copy(ja_Ik, ja_Ik_diff, move_data=.TRUE.)
     817             : 
     818             :                      CALL dbt_contract(alpha=-1.0_dp, tensor_1=oI_Y((ispin - 1)*ndo_mo + ido_mo), &
     819             :                                        tensor_2=aj_X(ispin), beta=1.0_dp, tensor_3=ja_Ik_diff, &
     820             :                                        contract_1=[2], notcontract_1=[1], contract_2=[3], notcontract_2=[1, 2], &
     821          96 :                                        map_1=[1], map_2=[2, 3], bounds_2=[1, nhomo(ispin)], bounds_3=bounds_2d)
     822             : 
     823             :                      CALL calc_ss_oov_contrib(parts(1), parts(2), ja_Ik_diff, homo_evals(ispin)%array, &
     824          32 :                                               lumo_evals(ispin)%array, omega_k, c_ss)
     825             : 
     826          32 :                      CALL dbt_destroy(ja_Ik_diff)
     827          32 :                      CALL dbt_destroy(ja_Ik)
     828             :                   END IF !c_ss !!= 0
     829             : 
     830             :                END DO
     831             : 
     832          64 :                DO ibatch = 1, nbatch_virt
     833             : 
     834             :                   !opposite-spin contribution, i.e. (a_beta j_beta| I_alpha b_alpha) and vice-versa
     835             :                   !do the batching along b because same spin as donor MO
     836          32 :                   virt_bo = get_limit(nblk_virt, nbatch_virt, ibatch - 1)
     837             :                   bounds_1d = [SUM(mo_blk_size(1:nblk_occ + virt_bo(1) - 1)) + 1, &
     838         256 :                                SUM(mo_blk_size(1:nblk_occ + virt_bo(2)))]
     839             : 
     840          32 :                   CALL dbt_create(mo_template(other_spin, ispin), aj_Ib)
     841             :                   CALL dbt_contract(alpha=1.0_dp, tensor_1=aj_X(other_spin), &
     842             :                                     tensor_2=oI_Y((ispin - 1)*ndo_mo + ido_mo), &
     843             :                                     beta=0.0_dp, tensor_3=aj_Ib, contract_1=[3], &
     844             :                                     notcontract_1=[1, 2], contract_2=[2], notcontract_2=[1], &
     845          32 :                                     map_1=[1, 2], map_2=[3], bounds_3=bounds_1d)
     846             : 
     847             :                   CALL calc_os_ovv_contrib(parts(3), parts(4), aj_Ib, lumo_evals(other_spin)%array, &
     848             :                                            homo_evals(other_spin)%array, lumo_evals(ispin)%array, &
     849          32 :                                            omega_k, c_os, nhomo(other_spin), nhomo(ispin))
     850             : 
     851          32 :                   CALL dbt_destroy(aj_Ib)
     852             : 
     853             :                   !same-spin contribution, need to compute (aj|Ib) - (bj|Ia), all with the current spin
     854             :                   !skip if c_ss == 0
     855          64 :                   IF (ABS(c_ss) > EPSILON(1.0_dp)) THEN
     856             : 
     857             :                      !same batching as opposite spin
     858          32 :                      CALL dbt_create(mo_template(ispin, ispin), aj_Ib)
     859             :                      CALL dbt_contract(alpha=1.0_dp, tensor_1=aj_X(ispin), &
     860             :                                        tensor_2=oI_Y((ispin - 1)*ndo_mo + ido_mo), &
     861             :                                        beta=0.0_dp, tensor_3=aj_Ib, contract_1=[3], &
     862             :                                        notcontract_1=[1, 2], contract_2=[2], notcontract_2=[1], &
     863          32 :                                        map_1=[1, 2], map_2=[3], bounds_3=bounds_1d)
     864             : 
     865          32 :                      bounds_2d(1, 1) = 1
     866          32 :                      bounds_2d(2, 1) = nhomo(ispin)
     867          96 :                      bounds_2d(:, 2) = bounds_1d
     868             : 
     869          32 :                      CALL dbt_create(aj_Ib, aj_Ib_diff, map1_2d=[1], map2_2d=[2, 3])
     870          32 :                      CALL dbt_copy(aj_Ib, aj_Ib_diff, move_data=.TRUE.)
     871             : 
     872             :                      CALL dbt_contract(alpha=-1.0_dp, tensor_1=oI_Y((ispin - 1)*ndo_mo + ido_mo), &
     873             :                                        tensor_2=ja_X(ispin), beta=1.0_dp, tensor_3=aj_Ib_diff, &
     874             :                                        contract_1=[2], notcontract_1=[1], contract_2=[3], &
     875             :                                        notcontract_2=[1, 2], map_1=[1], map_2=[2, 3], &
     876             :                                        bounds_2=[nhomo(ispin) + 1, nhomo(ispin) + nlumo(ispin)], &
     877          96 :                                        bounds_3=bounds_2d)
     878             : 
     879             :                      CALL calc_ss_ovv_contrib(parts(3), parts(4), aj_Ib_diff, homo_evals(ispin)%array, &
     880          32 :                                               lumo_evals(ispin)%array, omega_k, c_ss)
     881             : 
     882          32 :                      CALL dbt_destroy(aj_Ib_diff)
     883          32 :                      CALL dbt_destroy(aj_Ib)
     884             :                   END IF ! c_ss not 0
     885             : 
     886             :                END DO
     887             : 
     888          32 :                CALL para_env%sum(parts)
     889          32 :                s1 = parts(1); ds1 = parts(2)
     890          32 :                s2 = parts(3); ds2 = parts(4)
     891             : 
     892             :                !evaluate g and its derivative
     893          32 :                g = eps_I - omega_k + s1 + s2
     894          32 :                dg = -1.0_dp + ds1 + ds2
     895             : 
     896             :                !compute the diff to the new step
     897          32 :                diff = -g/dg
     898             : 
     899             :                !and the new omega
     900          32 :                omega_k = omega_k + diff
     901          32 :                diff = diff*evolt
     902             : 
     903          32 :                IF (output_unit > 0) THEN
     904             :                   WRITE (UNIT=output_unit, FMT="(T21,I18,F32.6)") &
     905          16 :                      iloop, diff
     906          16 :                   CALL m_flush(output_unit)
     907             :                END IF
     908             : 
     909          40 :                IF (iloop > max_iter) THEN
     910           0 :                   CPWARN("GW2X iteration not converged.")
     911           0 :                   EXIT
     912             :                END IF
     913             :             END DO !while loop on eps_iter
     914             : 
     915             :             !compute the shift and update donor_state
     916           8 :             donor_state%gw2x_evals(ido_mo, ispin) = omega_k
     917             : 
     918          16 :             IF (output_unit > 0) THEN
     919             :                WRITE (UNIT=output_unit, FMT="(/T7,A,F11.6,/,T5,A,F11.6)") &
     920           4 :                   "Final GW2X shift for this donor MO (eV):", &
     921           8 :                   (donor_state%energy_evals(ido_mo, ispin) - omega_k)*evolt
     922             :             END IF
     923             : 
     924             :          END DO !ido_mo
     925             : 
     926          12 :          DEALLOCATE (mo_blk_size)
     927             :       END DO ! ispin
     928             : 
     929          12 :       DO ispin = 1, nspins
     930          12 :          CALL dbt_destroy(aj_X(ispin))
     931             :       END DO
     932             : 
     933           4 :       CALL timestop(handle)
     934             : 
     935          20 :    END SUBROUTINE GW2X_os_iterations
     936             : 
     937             : ! **************************************************************************************************
     938             : !> \brief Takes the 3-center integrals from the ri_3c_ex tensor and returns a full tensor. Since
     939             : !>        ri_3c_ex is only half filled because of symmetry, we have to add the transpose
     940             : !>        and scale the diagonal blocks by 0.5
     941             : !> \param pq_X the full (desymmetrized) tensor containing the (pq|X) exchange integrals, in a new
     942             : !>        3d distribution and optimized block sizes
     943             : !> \param exat index of the current excited atom; only ri_3c_ex blocks with this third index are kept
     944             : !> \param xas_tdp_env provides the ri_3c_ex tensor and ri_inv_ex (its first extent is used as RI block size)
     945             : !> \param qs_env used to retrieve para_env and natom
     946             : ! **************************************************************************************************
     947          34 :    SUBROUTINE get_full_pqX_from_3c_ex(pq_X, exat, xas_tdp_env, qs_env)
     948             : 
     949             :       TYPE(dbt_type), INTENT(INOUT)                      :: pq_X
     950             :       INTEGER, INTENT(IN)                                :: exat
     951             :       TYPE(xas_tdp_env_type), POINTER                    :: xas_tdp_env
     952             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     953             : 
     954             :       INTEGER                                            :: i, ind(3), natom, nblk_ri, nsgf_x
     955          34 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: orb_blk_size, proc_dist_1, proc_dist_2, &
     956          34 :                                                             proc_dist_3
     957             :       INTEGER, DIMENSION(3)                              :: pdims
     958             :       LOGICAL                                            :: found
     959          34 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: pblock
     960         306 :       TYPE(dbt_distribution_type)                        :: t_dist
     961             :       TYPE(dbt_iterator_type)                            :: iter
     962         102 :       TYPE(dbt_pgrid_type)                               :: t_pgrid
     963         612 :       TYPE(dbt_type)                                     :: pq_X_tmp, work
     964             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     965             : 
     966          34 :       NULLIFY (para_env)
     967             : 
     968             :       !create work tensor with the same dist as ri_3c_ex along dims 1 and 2, but only keep excited atom along RI direction
     969          34 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natom)
     970          34 :       CALL dbt_get_info(xas_tdp_env%ri_3c_ex, pdims=pdims)
     971          34 :       nsgf_x = SIZE(xas_tdp_env%ri_inv_ex, 1)
     972          34 :       nblk_ri = 1
     973             : 
     974          34 :       CALL dbt_pgrid_create(para_env, pdims, t_pgrid)
     975         170 :       ALLOCATE (proc_dist_1(natom), proc_dist_2(natom), orb_blk_size(natom))
     976             :       CALL dbt_get_info(xas_tdp_env%ri_3c_ex, proc_dist_1=proc_dist_1, proc_dist_2=proc_dist_2, &
     977          34 :                         blk_size_1=orb_blk_size)
     978             :       CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=proc_dist_1, nd_dist_2=proc_dist_2, &
     979         136 :                                 nd_dist_3=[(0, i=1, nblk_ri)])
     980             : 
     981             :       CALL dbt_create(work, name="(pq|X)", dist=t_dist, map1_2d=[1], map2_2d=[2, 3], &
     982          68 :                       blk_size_1=orb_blk_size, blk_size_2=orb_blk_size, blk_size_3=[nsgf_x])
     983          34 :       CALL dbt_distribution_destroy(t_dist)
     984             : 
     985             :       !dist of 3c_ex and work match, can simply copy blocks over. Diagonal blocks (ind(1) == ind(2)) with factor 0.5
     986             : 
     987             : !$OMP PARALLEL DEFAULT(NONE) SHARED(xas_tdp_env,exat,work,orb_blk_size,nsgf_x) &
     988          34 : !$OMP PRIVATE(iter,ind,pblock,found)
     989             :       CALL dbt_iterator_start(iter, xas_tdp_env%ri_3c_ex)
     990             :       DO WHILE (dbt_iterator_blocks_left(iter))
     991             :          CALL dbt_iterator_next_block(iter, ind)
     992             :          CALL dbt_get_block(xas_tdp_env%ri_3c_ex, ind, pblock, found)
     993             : 
     994             :          IF (ind(1) == ind(2)) pblock = 0.5_dp*pblock
     995             :          IF (ind(3) /= exat) CYCLE
     996             : 
     997             :          CALL dbt_put_block(work, [ind(1), ind(2), 1], &
     998             :                             [orb_blk_size(ind(1)), orb_blk_size(ind(2)), nsgf_x], pblock)
     999             : 
    1000             :          DEALLOCATE (pblock)
    1001             :       END DO
    1002             :       CALL dbt_iterator_stop(iter)
    1003             : !$OMP END PARALLEL
    1004          34 :       CALL dbt_finalize(work)
    1005             : 
    1006             :       !create (pq|X) based on work, copy over, then sum in the transpose over the first 2 indices (order=[2, 1, 3])
    1007          34 :       CALL dbt_create(work, pq_X_tmp)
    1008          34 :       CALL dbt_copy(work, pq_X_tmp)
    1009          34 :       CALL dbt_copy(work, pq_X_tmp, order=[2, 1, 3], summation=.TRUE., move_data=.TRUE.)
    1010             : 
    1011          34 :       CALL dbt_destroy(work)
    1012             : 
    1013             :       !re-create the pgrid from pdims = 0 and tensor_dims (presumably dbt then chooses the grid dims -- TODO confirm)
    1014          34 :       CALL dbt_pgrid_destroy(t_pgrid)
    1015          34 :       pdims = 0
    1016         136 :       CALL dbt_pgrid_create(para_env, pdims, t_pgrid, tensor_dims=[natom, natom, 1])
    1017             : 
    1018             :       !cyclic distribution across all directions.
    1019          34 :       ALLOCATE (proc_dist_3(nblk_ri))
    1020          34 :       CALL dbt_default_distvec(natom, pdims(1), orb_blk_size, proc_dist_1)
    1021          34 :       CALL dbt_default_distvec(natom, pdims(2), orb_blk_size, proc_dist_2)
    1022          68 :       CALL dbt_default_distvec(nblk_ri, pdims(3), [nsgf_x], proc_dist_3)
    1023             :       CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=proc_dist_1, nd_dist_2=proc_dist_2, &
    1024          34 :                                 nd_dist_3=proc_dist_3)
    1025             : 
    1026             :       CALL dbt_create(pq_X, name="(pq|X)", dist=t_dist, map1_2d=[2, 3], map2_2d=[1], &
    1027          68 :                       blk_size_1=orb_blk_size, blk_size_2=orb_blk_size, blk_size_3=[nsgf_x])
    1028          34 :       CALL dbt_copy(pq_X_tmp, pq_X, move_data=.TRUE.)
    1029             : 
    1030          34 :       CALL dbt_distribution_destroy(t_dist)
    1031          34 :       CALL dbt_pgrid_destroy(t_pgrid)
    1032          34 :       CALL dbt_destroy(pq_X_tmp)
    1033             : 
    1034          68 :    END SUBROUTINE get_full_pqX_from_3c_ex
    1035             : 
    1036             : ! **************************************************************************************************
    1037             : !> \brief Contracts (pq|X) and (rI|Y) from AOs to MOs to (ja|X) and (oI|Y) respectively, where
    1038             : !>        j is an occupied MO, a is a virtual MO and o is a general MO
    1039             : !> \param ja_X partial contraction over occupied MOs j, virtual MOs a: (ja|X), for both spins (alpha-alpha or beta-beta)
    1040             : !> \param oI_Y partial contraction over all MOs o and donor MOs I (can be more than 1 if 2p or open-shell)
    1041             : !> \param ja_Io_template template to be able to build tensors after calling this routine, for each spin combination
    1042             : !> \param mo_coeffs MO coefficients in fm format for each spin, converted here to dbcsr and then to a tensor
    1043             : !> \param nocc number of occupied MOs for each spin (range of the j index)
    1044             : !> \param nvirt number of virtual MOs for each spin (range of the a index)
    1045             : !> \param donor_state provides at_index of the excited atom and ndo_mo, which is used to size oI_Y
    1046             : !> \param xas_tdp_env passed on to get_full_pqX_from_3c_ex and get_oIY_tensors
    1047             : !> \param xas_tdp_control do_uks/do_roks select nspins = 2; also passed on to get_oIY_tensors
    1048             : !> \param qs_env used to retrieve para_env and the standard dbcsr distribution
    1049             : !> \note the multiplication by (X|Y)^-1 is included in the final (oI|Y) tensor. Only integrals with the
    1050             : !>       same spin on one center are non-zero, i.e. (oI|Y) is non zero only if both o and Y have the same spin
    1051             : ! **************************************************************************************************
    1052          30 :    SUBROUTINE contract_AOs_to_MOs(ja_X, oI_Y, ja_Io_template, mo_coeffs, nocc, nvirt, &
    1053             :                                   donor_state, xas_tdp_env, xas_tdp_control, qs_env)
    1054             : 
    1055             :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:), &
    1056             :          INTENT(INOUT)                                   :: ja_X, oI_Y
    1057             :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:, :), &
    1058             :          INTENT(INOUT)                                   :: ja_Io_template
    1059             :       TYPE(cp_fm_type), DIMENSION(:), INTENT(INOUT)      :: mo_coeffs
    1060             :       INTEGER, INTENT(IN)                                :: nocc(2), nvirt(2)
    1061             :       TYPE(donor_state_type), POINTER                    :: donor_state
    1062             :       TYPE(xas_tdp_env_type), POINTER                    :: xas_tdp_env
    1063             :       TYPE(xas_tdp_control_type), POINTER                :: xas_tdp_control
    1064             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1065             : 
    1066             :       CHARACTER(len=*), PARAMETER :: routineN = 'contract_AOs_to_MOs'
    1067             : 
    1068             :       INTEGER                                            :: bo(2), handle, i, ispin, jspin, &
    1069             :                                                             nblk_aos, nblk_mos(2), nblk_occ(2), &
    1070             :                                                             nblk_pqX(3), nblk_ri, nblk_virt(2), &
    1071             :                                                             nspins
    1072             :       INTEGER, DIMENSION(3)                              :: pdims
    1073          30 :       INTEGER, DIMENSION(:), POINTER                     :: ao_blk_size, ao_col_dist, ao_row_dist, &
    1074          60 :                                                             mo_dist_3, ri_blk_size, ri_dist_3
    1075          30 :       INTEGER, DIMENSION(:, :), POINTER                  :: mat_pgrid
    1076          30 :       TYPE(cp_1d_i_p_type), ALLOCATABLE, DIMENSION(:)    :: mo_blk_size, mo_col_dist, mo_row_dist
    1077             :       TYPE(dbcsr_distribution_type)                      :: mat_dist
    1078             :       TYPE(dbcsr_distribution_type), POINTER             :: std_mat_dist
    1079             :       TYPE(dbcsr_type)                                   :: dbcsr_mo_coeffs
    1080         210 :       TYPE(dbt_distribution_type)                        :: t_dist
    1081          90 :       TYPE(dbt_pgrid_type)                               :: t_pgrid
    1082         630 :       TYPE(dbt_type)                                     :: jq_X, pq_X, t_mo_coeffs
    1083             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    1084             : 
    1085          30 :       NULLIFY (ao_blk_size, ao_col_dist, ao_row_dist, mo_dist_3, ri_blk_size, ri_dist_3, mat_pgrid, &
    1086          30 :                para_env, std_mat_dist)
    1087             : 
    1088          30 :       CALL timeset(routineN, handle)
    1089             : 
    1090          30 :       nspins = 1; IF (xas_tdp_control%do_uks .OR. xas_tdp_control%do_roks) nspins = 2
    1091             : 
    1092             :       !There are 2 contractions to do for the first tensor: (pq|X) --> (jq|X) --> (ja|X)
    1093             :       !Because memory is the main concern, we move_data every time at the cost of extra copies (t_mo_coeffs is re-copied)
    1094             : 
    1095             :       !Some quantities need to be stored for both spins, because they are later combined
    1096          30 :       CALL get_qs_env(qs_env, para_env=para_env)
    1097         222 :       ALLOCATE (mo_blk_size(nspins), mo_row_dist(nspins), mo_col_dist(nspins))
    1098         274 :       ALLOCATE (ja_X(nspins))
    1099         312 :       ALLOCATE (oI_Y(nspins*donor_state%ndo_mo))
    1100             : 
    1101          64 :       DO ispin = 1, nspins
    1102             : 
    1103             :          !First, we need a fully populated pq_X (spin-independent, but rebuilt for each spin as it is destroyed below)
    1104          34 :          CALL get_full_pqX_from_3c_ex(pq_X, donor_state%at_index, xas_tdp_env, qs_env)
    1105             : 
    1106             :          !Create the tensor pgrid. AOs and RI independent from spin
    1107          34 :          IF (ispin == 1) THEN
    1108          30 :             CALL dbt_get_info(pq_X, pdims=pdims, nblks_total=nblk_pqX)
    1109          30 :             CALL dbt_pgrid_create(para_env, pdims, t_pgrid)
    1110          30 :             nblk_aos = nblk_pqX(1)
    1111          30 :             nblk_ri = nblk_pqX(3)
    1112             :          END IF
    1113             : 
    1114             :          !Define MO block sizes, at worst, take one block per proc
    1115          34 :          nblk_occ(ispin) = MAX(pdims(1), nocc(ispin)/16)
    1116          34 :          nblk_virt(ispin) = MAX(pdims(2), nvirt(ispin)/16)
    1117          34 :          nblk_mos(ispin) = nblk_occ(ispin) + nblk_virt(ispin)
    1118         102 :          ALLOCATE (mo_blk_size(ispin)%array(nblk_mos(ispin)))
    1119         102 :          DO i = 1, nblk_occ(ispin)
    1120          68 :             bo = get_limit(nocc(ispin), nblk_occ(ispin), i - 1)
    1121         102 :             mo_blk_size(ispin)%array(i) = bo(2) - bo(1) + 1
    1122             :          END DO
    1123          96 :          DO i = 1, nblk_virt(ispin)
    1124          62 :             bo = get_limit(nvirt(ispin), nblk_virt(ispin), i - 1)
    1125          96 :             mo_blk_size(ispin)%array(nblk_occ(ispin) + i) = bo(2) - bo(1) + 1
    1126             :          END DO
    1127             : 
    1128             :          !Convert the fm mo_coeffs into a dbcsr matrix and then a tensor
    1129          34 :          CALL get_qs_env(qs_env, dbcsr_dist=std_mat_dist)
    1130          34 :          CALL dbcsr_distribution_get(std_mat_dist, pgrid=mat_pgrid)
    1131         170 :          ALLOCATE (ao_blk_size(nblk_aos), ri_blk_size(nblk_ri))
    1132          34 :          CALL dbt_get_info(pq_X, blk_size_1=ao_blk_size, blk_size_3=ri_blk_size)
    1133             : 
    1134             :          !we opt for a cyclic dist for the MOs (since they should be rather dense anyways)
    1135         102 :          ALLOCATE (ao_row_dist(nblk_aos), mo_col_dist(ispin)%array(nblk_mos(ispin)))
    1136          34 :          CALL dbt_default_distvec(nblk_aos, SIZE(mat_pgrid, 1), ao_blk_size, ao_row_dist)
    1137             :          CALL dbt_default_distvec(nblk_mos(ispin), SIZE(mat_pgrid, 2), mo_blk_size(ispin)%array, &
    1138          34 :                                   mo_col_dist(ispin)%array)
    1139             :          CALL dbcsr_distribution_new(mat_dist, group=para_env%get_handle(), pgrid=mat_pgrid, &
    1140          34 :                                      row_dist=ao_row_dist, col_dist=mo_col_dist(ispin)%array)
    1141             : 
    1142             :          CALL dbcsr_create(dbcsr_mo_coeffs, name="MO coeffs", matrix_type="N", dist=mat_dist, &
    1143          34 :                            row_blk_size=ao_blk_size, col_blk_size=mo_blk_size(ispin)%array)
    1144          34 :          CALL copy_fm_to_dbcsr(mo_coeffs(ispin), dbcsr_mo_coeffs)
    1145             : 
    1146          34 :          CALL dbt_create(dbcsr_mo_coeffs, t_mo_coeffs)
    1147          34 :          CALL dbt_copy_matrix_to_tensor(dbcsr_mo_coeffs, t_mo_coeffs)
    1148             : 
    1149             :          !prepare the (jq|X) tensor for the first contraction (over occupied MOs)
    1150         136 :          ALLOCATE (mo_row_dist(ispin)%array(nblk_mos(ispin)), ao_col_dist(nblk_aos), ri_dist_3(nblk_ri))
    1151          34 :          CALL dbt_default_distvec(nblk_mos(ispin), pdims(1), mo_blk_size(ispin)%array, mo_row_dist(ispin)%array)
    1152          34 :          CALL dbt_default_distvec(nblk_aos, pdims(2), ao_blk_size, ao_col_dist)
    1153          34 :          CALL dbt_default_distvec(nblk_ri, pdims(3), ri_blk_size, ri_dist_3)
    1154             :          CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=mo_row_dist(ispin)%array, &
    1155          34 :                                    nd_dist_2=ao_col_dist, nd_dist_3=ri_dist_3)
    1156             : 
    1157             :          CALL dbt_create(jq_X, name="(jq|X)", dist=t_dist, map1_2d=[1, 3], map2_2d=[2], &
    1158          34 :                          blk_size_1=mo_blk_size(ispin)%array, blk_size_2=ao_blk_size, blk_size_3=ri_blk_size)
    1159          34 :          CALL dbt_distribution_destroy(t_dist)
    1160             : 
    1161             :          !contract (pq|X) into (jq|X)
    1162             :          CALL dbt_contract(alpha=1.0_dp, tensor_1=pq_X, tensor_2=t_mo_coeffs, &
    1163             :                            beta=0.0_dp, tensor_3=jq_X, contract_1=[1], &
    1164             :                            notcontract_1=[2, 3], contract_2=[1], notcontract_2=[2], &
    1165             :                            map_1=[2, 3], map_2=[1], bounds_3=[1, nocc(ispin)], &!only want occupied MOs for j
    1166         102 :                            move_data=.TRUE.)
    1167             : 
    1168          34 :          CALL dbt_destroy(pq_X)
    1169          34 :          CALL dbt_copy_matrix_to_tensor(dbcsr_mo_coeffs, t_mo_coeffs)
    1170             : 
    1171             :          !prepare (ja|X) tensor for the second contraction (over virtual MOs)
    1172             :          !only virtual-occupied bit of the first 2 indices is occupied + it should be dense
    1173             :          !take blk dist such that blocks are evenly distributed
    1174             :          CALL dbt_default_distvec(nblk_occ(ispin), pdims(1), mo_blk_size(ispin)%array(1:nblk_occ(ispin)), &
    1175          34 :                                   mo_row_dist(ispin)%array(1:nblk_occ(ispin)))
    1176             :          CALL dbt_default_distvec(nblk_virt(ispin), pdims(1), &
    1177             :                                   mo_blk_size(ispin)%array(nblk_occ(ispin) + 1:nblk_mos(ispin)), &
    1178          34 :                                   mo_row_dist(ispin)%array(nblk_occ(ispin) + 1:nblk_mos(ispin)))
    1179             :          CALL dbt_default_distvec(nblk_occ(ispin), pdims(2), mo_blk_size(ispin)%array(1:nblk_occ(ispin)), &
    1180          34 :                                   mo_col_dist(ispin)%array(1:nblk_occ(ispin)))
    1181             :          CALL dbt_default_distvec(nblk_virt(ispin), pdims(2), &
    1182             :                                   mo_blk_size(ispin)%array(nblk_occ(ispin) + 1:nblk_mos(ispin)), &
    1183          34 :                                   mo_col_dist(ispin)%array(nblk_occ(ispin) + 1:nblk_mos(ispin)))
    1184             :          CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=mo_row_dist(ispin)%array, &
    1185          34 :                                    nd_dist_2=mo_col_dist(ispin)%array, nd_dist_3=ri_dist_3)
    1186             : 
    1187             :          CALL dbt_create(ja_X(ispin), name="(ja|X)", dist=t_dist, map1_2d=[1, 2], map2_2d=[3], &
    1188             :                          blk_size_1=mo_blk_size(ispin)%array, blk_size_2=mo_blk_size(ispin)%array, &
    1189          34 :                          blk_size_3=ri_blk_size)
    1190          34 :          CALL dbt_distribution_destroy(t_dist)
    1191             : 
    1192             :          !contract (jq|X) into (ja|X), keeping only the virtual MO range nocc+1:nocc+nvirt for a
    1193             :          CALL dbt_contract(alpha=1.0_dp, tensor_1=jq_X, tensor_2=t_mo_coeffs, &
    1194             :                            beta=0.0_dp, tensor_3=ja_X(ispin), contract_1=[2], &
    1195             :                            notcontract_1=[1, 3], contract_2=[1], notcontract_2=[2], &
    1196             :                            map_1=[1, 3], map_2=[2], move_data=.TRUE., &
    1197         102 :                            bounds_3=[nocc(ispin) + 1, nocc(ispin) + nvirt(ispin)])
    1198             : 
    1199          34 :          CALL dbt_destroy(jq_X)
    1200          34 :          CALL dbt_copy_matrix_to_tensor(dbcsr_mo_coeffs, t_mo_coeffs)
    1201             : 
    1202             :          !Finally, get the oI_Y tensors
    1203             :          CALL get_oIY_tensors(oI_Y, ispin, ao_blk_size, mo_blk_size(ispin)%array, ri_blk_size, &
    1204          34 :                               t_mo_coeffs, donor_state, xas_tdp_env, xas_tdp_control, qs_env)
    1205             : 
    1206             :          !intermediate clean-up
    1207          34 :          CALL dbt_destroy(t_mo_coeffs)
    1208          34 :          CALL dbcsr_distribution_release(mat_dist)
    1209          34 :          CALL dbcsr_release(dbcsr_mo_coeffs)
    1210          64 :          DEALLOCATE (ao_col_dist, ri_dist_3, ri_blk_size, ao_blk_size, ao_row_dist)
    1211             : 
    1212             :       END DO !ispin
    1213             : 
    1214             :       !create an empty tensor template for the fully contracted (ja|Io) MO integrals, for all spin
    1215             :       !configurations: alpha-alpha|alpha-alpha, alpha-alpha|beta-beta, etc.
    1216         376 :       ALLOCATE (ja_Io_template(nspins, nspins))
    1217          64 :       DO ispin = 1, nspins
    1218         106 :          DO jspin = 1, nspins
    1219         126 :             ALLOCATE (mo_dist_3(nblk_mos(jspin)))
    1220             :             CALL dbt_default_distvec(nblk_occ(jspin), pdims(3), mo_blk_size(jspin)%array(1:nblk_occ(jspin)), &
    1221          42 :                                      mo_dist_3(1:nblk_occ(jspin)))
    1222             :             CALL dbt_default_distvec(nblk_virt(jspin), pdims(3), &
    1223             :                                      mo_blk_size(jspin)%array(nblk_occ(jspin) + 1:nblk_mos(jspin)), &
    1224          42 :                                      mo_dist_3(nblk_occ(jspin) + 1:nblk_mos(jspin)))
    1225             :             CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=mo_row_dist(ispin)%array, &
    1226          42 :                                       nd_dist_2=mo_col_dist(ispin)%array, nd_dist_3=mo_dist_3)
    1227             : 
    1228             :             CALL dbt_create(ja_Io_template(ispin, jspin), name="(ja|Io)", dist=t_dist, map1_2d=[1, 2], &
    1229             :                             map2_2d=[3], blk_size_1=mo_blk_size(ispin)%array, &
    1230          42 :                             blk_size_2=mo_blk_size(ispin)%array, blk_size_3=mo_blk_size(jspin)%array)
    1231          42 :             CALL dbt_distribution_destroy(t_dist)
    1232          76 :             DEALLOCATE (mo_dist_3)
    1233             :          END DO
    1234             :       END DO
    1235             : 
    1236             :       !clean-up
    1237          30 :       CALL dbt_pgrid_destroy(t_pgrid)
    1238          64 :       DO ispin = 1, nspins
    1239          34 :          DEALLOCATE (mo_blk_size(ispin)%array)
    1240          34 :          DEALLOCATE (mo_col_dist(ispin)%array)
    1241          64 :          DEALLOCATE (mo_row_dist(ispin)%array)
    1242             :       END DO
    1243             : 
    1244          30 :       CALL timestop(handle)
    1245             : 
    1246         150 :    END SUBROUTINE contract_AOs_to_MOs
    1247             : 
    1248             : ! **************************************************************************************************
    1249             : !> \brief Contracts the (oI|Y) tensors, for each donor MO
    1250             : !> \param oI_Y the contracted tensor. It is assumed to be allocated outside of this routine
    1251             : !> \param ispin ...
    1252             : !> \param ao_blk_size ...
    1253             : !> \param mo_blk_size ...
    1254             : !> \param ri_blk_size ...
    1255             : !> \param t_mo_coeffs ...
    1256             : !> \param donor_state ...
    1257             : !> \param xas_tdp_env ...
    1258             : !> \param xas_tdp_control ...
    1259             : !> \param qs_env ...
    1260             : ! **************************************************************************************************
    1261          34 :    SUBROUTINE get_oIY_tensors(oI_Y, ispin, ao_blk_size, mo_blk_size, ri_blk_size, t_mo_coeffs, &
    1262             :                               donor_state, xas_tdp_env, xas_tdp_control, qs_env)
    1263             : 
    1264             :       TYPE(dbt_type), ALLOCATABLE, DIMENSION(:), &
    1265             :          INTENT(INOUT)                                   :: oI_Y
    1266             :       INTEGER, INTENT(IN)                                :: ispin
    1267             :       INTEGER, DIMENSION(:), POINTER                     :: ao_blk_size, mo_blk_size, ri_blk_size
    1268             :       TYPE(dbt_type), INTENT(inout)                      :: t_mo_coeffs
    1269             :       TYPE(donor_state_type), POINTER                    :: donor_state
    1270             :       TYPE(xas_tdp_env_type), POINTER                    :: xas_tdp_env
    1271             :       TYPE(xas_tdp_control_type), POINTER                :: xas_tdp_control
    1272             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1273             : 
    1274             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'get_oIY_tensors'
    1275             : 
    1276             :       INTEGER                                            :: bo(2), handle, i, ido_mo, ind(2), natom, &
    1277             :                                                             nblk_aos, nblk_mos, nblk_ri, ndo_mo, &
    1278             :                                                             pdims_2d(2), proc_id
    1279          34 :       INTEGER, DIMENSION(:), POINTER                     :: ao_row_dist, mo_row_dist, ri_col_dist
    1280          34 :       INTEGER, DIMENSION(:, :), POINTER                  :: mat_pgrid
    1281             :       LOGICAL                                            :: found
    1282          34 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: pblock
    1283             :       TYPE(dbcsr_distribution_type), POINTER             :: std_mat_dist
    1284          34 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: pI_Y
    1285         306 :       TYPE(dbt_distribution_type)                        :: t_dist
    1286             :       TYPE(dbt_iterator_type)                            :: iter
    1287         102 :       TYPE(dbt_pgrid_type)                               :: t_pgrid
    1288         578 :       TYPE(dbt_type)                                     :: t_pI_Y, t_work
    1289             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    1290             : 
    1291          34 :       CALL timeset(routineN, handle)
    1292             : 
    1293          34 :       CALL get_qs_env(qs_env, natom=natom, para_env=para_env, dbcsr_dist=std_mat_dist)
    1294          34 :       ndo_mo = donor_state%ndo_mo
    1295          34 :       nblk_aos = SIZE(ao_blk_size)
    1296          34 :       nblk_mos = SIZE(mo_blk_size)
    1297          34 :       nblk_ri = SIZE(ri_blk_size)
    1298             : 
    1299             :       !We first contract (pq|X) over q into I using kernel routines (goes over all MOs and spins)
    1300          34 :       CALL contract2_AO_to_doMO(pI_Y, "EXCHANGE", donor_state, xas_tdp_env, xas_tdp_control, qs_env)
    1301             : 
    1302             :       !multiply by (X|Y)^-1
    1303          34 :       CALL ri_all_blocks_mm(pI_Y, xas_tdp_env%ri_inv_ex)
    1304             : 
    1305             :       !get standard 2d matrix proc grid
    1306          34 :       CALL dbcsr_distribution_get(std_mat_dist, pgrid=mat_pgrid)
    1307             : 
    1308             :       !Loop over donor MOs of this spin
    1309          76 :       DO ido_mo = (ispin - 1)*ndo_mo + 1, ispin*ndo_mo
    1310             : 
    1311             :          !cast the matrix into a tensor
    1312          42 :          CALL dbt_create(pI_Y(ido_mo)%matrix, t_work)
    1313          42 :          CALL dbt_copy_matrix_to_tensor(pI_Y(ido_mo)%matrix, t_work)
    1314             : 
    1315             :          !find col proc_id of the only populated column of t_work
    1316         126 :          ALLOCATE (ri_col_dist(natom))
    1317          42 :          CALL dbt_get_info(t_work, proc_dist_2=ri_col_dist)
    1318          42 :          proc_id = ri_col_dist(donor_state%at_index)
    1319          42 :          DEALLOCATE (ri_col_dist)
    1320             : 
    1321             :          !prepare (oI_Y) tensor and (pI|Y) tensor in proper dist and blk sizes
    1322          42 :          pdims_2d(1) = SIZE(mat_pgrid, 1); pdims_2d(2) = SIZE(mat_pgrid, 2)
    1323          42 :          CALL dbt_pgrid_create(para_env, pdims_2d, t_pgrid)
    1324             : 
    1325         294 :          ALLOCATE (ri_col_dist(nblk_ri), ao_row_dist(nblk_aos), mo_row_dist(nblk_mos))
    1326          42 :          CALL dbt_get_info(t_work, proc_dist_1=ao_row_dist)
    1327          84 :          ri_col_dist = proc_id
    1328             : 
    1329          42 :          CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=ao_row_dist, nd_dist_2=ri_col_dist)
    1330             :          CALL dbt_create(t_pI_Y, name="(pI|Y)", dist=t_dist, map1_2d=[1], map2_2d=[2], &
    1331          42 :                          blk_size_1=ao_blk_size, blk_size_2=ri_blk_size)
    1332          42 :          CALL dbt_distribution_destroy(t_dist)
    1333             : 
    1334             :          !copy block by block, dist match
    1335             : 
    1336             : !$OMP PARALLEL DEFAULT(NONE) SHARED(t_work,t_pI_Y,nblk_ri,ri_blk_size,ao_blk_size) &
    1337          42 : !$OMP PRIVATE(iter,ind,pblock,found,bo)
    1338             :          CALL dbt_iterator_start(iter, t_work)
    1339             :          DO WHILE (dbt_iterator_blocks_left(iter))
    1340             :             CALL dbt_iterator_next_block(iter, ind)
    1341             :             CALL dbt_get_block(t_work, ind, pblock, found)
    1342             : 
    1343             :             DO i = 1, nblk_ri
    1344             :                bo(1) = SUM(ri_blk_size(1:i - 1)) + 1
    1345             :                bo(2) = bo(1) + ri_blk_size(i) - 1
    1346             :                CALL dbt_put_block(t_pI_Y, [ind(1), i], [ao_blk_size(ind(1)), ri_blk_size(i)], &
    1347             :                                   pblock(:, bo(1):bo(2)))
    1348             :             END DO
    1349             : 
    1350             :             DEALLOCATE (pblock)
    1351             :          END DO
    1352             :          CALL dbt_iterator_stop(iter)
    1353             : !$OMP END PARALLEL
    1354          42 :          CALL dbt_finalize(t_pI_Y)
    1355             : 
    1356             :          !get optimal pgrid  for (oI|Y)
    1357          42 :          CALL dbt_pgrid_destroy(t_pgrid)
    1358          42 :          pdims_2d = 0
    1359         126 :          CALL dbt_pgrid_create(para_env, pdims_2d, t_pgrid, tensor_dims=[nblk_mos, nblk_ri])
    1360             : 
    1361          42 :          CALL dbt_default_distvec(nblk_aos, pdims_2d(1), ao_blk_size, ao_row_dist)
    1362          42 :          CALL dbt_default_distvec(nblk_mos, pdims_2d(1), mo_blk_size, mo_row_dist)
    1363          42 :          CALL dbt_default_distvec(nblk_ri, pdims_2d(2), ri_blk_size, ri_col_dist)
    1364             : 
    1365             :          !transfer pI_Y to the correct pgrid
    1366          42 :          CALL dbt_destroy(t_work)
    1367          42 :          CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=ao_row_dist, nd_dist_2=ri_col_dist)
    1368             :          CALL dbt_create(t_work, name="t_pI_Y", dist=t_dist, map1_2d=[1], map2_2d=[2], &
    1369          42 :                          blk_size_1=ao_blk_size, blk_size_2=ri_blk_size)
    1370          42 :          CALL dbt_copy(t_pI_Y, t_work, move_data=.TRUE.)
    1371          42 :          CALL dbt_distribution_destroy(t_dist)
    1372             : 
    1373             :          !create (oI|Y)
    1374          42 :          CALL dbt_distribution_new(t_dist, t_pgrid, nd_dist_1=mo_row_dist, nd_dist_2=ri_col_dist)
    1375             :          CALL dbt_create(oI_Y(ido_mo), name="(oI|Y)", dist=t_dist, map1_2d=[1], map2_2d=[2], &
    1376          42 :                          blk_size_1=mo_blk_size, blk_size_2=ri_blk_size)
    1377          42 :          CALL dbt_distribution_destroy(t_dist)
    1378             : 
    1379             :          !contract (pI|Y) into (oI|Y)
    1380             :          CALL dbt_contract(alpha=1.0_dp, tensor_1=t_work, tensor_2=t_mo_coeffs, &
    1381             :                            beta=0.0_dp, tensor_3=oI_Y(ido_mo), contract_1=[1], &
    1382             :                            notcontract_1=[2], contract_2=[1], notcontract_2=[2], &
    1383          42 :                            map_1=[2], map_2=[1]) !no bound, all MOs needed
    1384             : 
    1385             :          !intermediate clean-up
    1386          42 :          CALL dbt_destroy(t_work)
    1387          42 :          CALL dbt_destroy(t_pI_Y)
    1388          42 :          CALL dbt_pgrid_destroy(t_pgrid)
    1389          76 :          DEALLOCATE (ri_col_dist, ao_row_dist, mo_row_dist)
    1390             : 
    1391             :       END DO !ido_mo
    1392             : 
    1393             :       !final clean-up
    1394          34 :       CALL dbcsr_deallocate_matrix_set(pI_Y)
    1395             : 
    1396          34 :       CALL timestop(handle)
    1397             : 
    1398          68 :    END SUBROUTINE get_oIY_tensors
    1399             : 
    1400             : ! **************************************************************************************************
    1401             : !> \brief Computes the same spin, occupied-occupied-virtual MO contribution to the electron propagator:
    1402             : !>        0.5 * sum_ajk |<Ia||jk>|^2/(omega + eps_a - eps_j - eps_k) and its 1st derivative wrt omega:
    1403             : !>        -0.5 * sum_ajk |<Ia||jk>|^2/(omega + eps_a - eps_j - eps_k)**2
    1404             : !> \param contrib running sum, incremented in place by 0.5*c_ss*t**2/denom for each tensor element t
    1405             : !> \param dev the first derivative, incremented in place by -0.5*c_ss*t**2/denom**2 per element
    1406             : !> \param ja_Ik_diff contains the (ja|Ik) - (ka|Ij) tensor, index order (j, a, k), a shifted by nocc
    1407             : !> \param occ_evals occupied MO energies (eps_j and eps_k), SIZE(occ_evals) is taken as nocc
    1408             : !> \param virt_evals virtual MO energies (eps_a), addressed with the 2nd tensor index minus nocc
    1409             : !> \param omega the omega value entering the denominator omega + eps_a - eps_j - eps_k
    1410             : !> \param c_ss scaling factor multiplying each squared tensor element
    1411             : !> \note since this is same-spin, there is only one possibility for occ_evals and virt_evals
    1412             : ! **************************************************************************************************
    1413         174 :    SUBROUTINE calc_ss_oov_contrib(contrib, dev, ja_Ik_diff, occ_evals, virt_evals, omega, c_ss)
    1414             : 
    1415             :       REAL(dp), INTENT(inout)                            :: contrib, dev
    1416             :       TYPE(dbt_type), INTENT(inout)                      :: ja_Ik_diff
    1417             :       REAL(dp), DIMENSION(:), INTENT(IN)                 :: occ_evals, virt_evals
    1418             :       REAL(dp), INTENT(in)                               :: omega, c_ss
    1419             : 
    1420             :       CHARACTER(len=*), PARAMETER :: routineN = 'calc_ss_oov_contrib'
    1421             : 
    1422             :       INTEGER                                            :: a, boff(3), bsize(3), handle, idx1, &
    1423             :                                                             idx2, idx3, ind(3), j, k, nocc
    1424             :       LOGICAL                                            :: found
    1425             :       REAL(dp)                                           :: denom, tmp
    1426         174 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: tensor_blk
    1427             :       TYPE(dbt_iterator_type)                            :: iter
    1428             : 
    1429         174 :       CALL timeset(routineN, handle)
    1430             : 
    1431             :       !<Ia||jk> = <Ia|jk> - <Ia|kj> = (Ij|ak) - (Ik|aj) = (ka|Ij) - (ja|Ik)
    1432             :       !Note: the same spin contribution only involves spin-orbitals that are all of the same spin
    1433             : 
    1434         174 :       nocc = SIZE(occ_evals, 1)
    1435             : 
    1436             :       !Iterate over the blocks of ja_Ik_diff and accumulate the sums (OMP reduction on contrib, dev)
    1437             : 
    1438             : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:contrib,dev) &
    1439             : !$OMP SHARED(ja_Ik_diff,occ_evals,virt_evals,omega,c_ss,nocc) &
    1440         174 : !$OMP PRIVATE(iter,ind,boff,bsize,tensor_blk,found,idx1,idx2,idx3,j,A,k,denom,tmp)
    1441             :       CALL dbt_iterator_start(iter, ja_Ik_diff)
    1442             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1443             :          CALL dbt_iterator_next_block(iter, ind, blk_offset=boff, blk_size=bsize)
    1444             :          CALL dbt_get_block(ja_Ik_diff, ind, tensor_blk, found)
    1445             : 
    1446             :          IF (found) THEN
    1447             : 
    1448             :             DO idx3 = 1, bsize(3)
    1449             :                DO idx2 = 1, bsize(2)
    1450             :                   DO idx1 = 1, bsize(1)
    1451             : 
    1452             :                      !get proper MO indices (virtual index shifted by nocc to address virt_evals)
    1453             :                      j = boff(1) + idx1 - 1
    1454             :                      a = boff(2) + idx2 - 1 - nocc
    1455             :                      k = boff(3) + idx3 - 1
    1456             : 
    1457             :                      !the denominator
    1458             :                      denom = omega + virt_evals(a) - occ_evals(j) - occ_evals(k)
    1459             : 
    1460             :                      !the same spin contribution
    1461             :                      tmp = c_ss*tensor_blk(idx1, idx2, idx3)**2
    1462             : 
    1463             :                      contrib = contrib + 0.5_dp*tmp/denom
    1464             :                      dev = dev - 0.5_dp*tmp/denom**2
    1465             : 
    1466             :                   END DO
    1467             :                END DO
    1468             :             END DO
    1469             :          END IF
    1470             :          DEALLOCATE (tensor_blk)
    1471             :       END DO
    1472             :       CALL dbt_iterator_stop(iter)
    1473             : !$OMP END PARALLEL
    1474             : 
    1475         174 :       CALL timestop(handle)
    1476             : 
    1477         348 :    END SUBROUTINE calc_ss_oov_contrib
    1478             : 
    1479             : ! **************************************************************************************************
    1480             : !> \brief Computes the opposite spin, occupied-occupied-virtual MO contribution to the electron propagator:
    1481             : !>        0.5 * sum_ajk |<Ia||jk>|^2/(omega + eps_a - eps_j - eps_k) and its 1st derivative wrt omega:
    1482             : !>        -0.5 * sum_ajk |<Ia||jk>|^2/(omega + eps_a - eps_j - eps_k)**2
    1483             : !> \param contrib running sum, incremented in place by c_os*t**2/denom for each tensor element t
    1484             : !> \param dev the first derivative, incremented in place by -c_os*t**2/denom**2 per element
    1485             : !> \param ja_Ik the (ja|Ik) tensor, index order (j, a, k), with the a index shifted by a_offset
    1486             : !> \param j_evals occupied evals for j MO
    1487             : !> \param a_evals virtual evals for a MO
    1488             : !> \param k_evals occupied evals for k MO
    1489             : !> \param omega the omega value entering the denominator omega + eps_a - eps_j - eps_k
    1490             : !> \param c_os scaling factor multiplying each squared tensor element
    1491             : !> \param a_offset the number of occupied MOs for the same spin as a MOs
    1492             : !> \note since this is opposite-spin, evals might be different for different spins
    1493             : ! **************************************************************************************************
    1494         174 :    SUBROUTINE calc_os_oov_contrib(contrib, dev, ja_Ik, j_evals, a_evals, k_evals, omega, c_os, a_offset)
    1495             : 
    1496             :       REAL(dp), INTENT(inout)                            :: contrib, dev
    1497             :       TYPE(dbt_type), INTENT(inout)                      :: ja_Ik
    1498             :       REAL(dp), DIMENSION(:), INTENT(IN)                 :: j_evals, a_evals, k_evals
    1499             :       REAL(dp), INTENT(in)                               :: omega, c_os
    1500             :       INTEGER, INTENT(IN)                                :: a_offset
    1501             : 
    1502             :       CHARACTER(len=*), PARAMETER :: routineN = 'calc_os_oov_contrib'
    1503             : 
    1504             :       INTEGER                                            :: a, boff(3), bsize(3), handle, idx1, &
    1505             :                                                             idx2, idx3, ind(3), j, k
    1506             :       LOGICAL                                            :: found
    1507             :       REAL(dp)                                           :: denom, tmp
    1508         174 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: ja_Ik_blk
    1509             :       TYPE(dbt_iterator_type)                            :: iter
    1510             : 
    1511         174 :       CALL timeset(routineN, handle)
    1512             : 
    1513             :       !<Ia||jk> = <Ia|jk> - <Ia|kj> = (Ij|ak) - (Ik|aj) = (ka|Ij) - (ja|Ik)
    1514             :       !Note: the opposite spin contribution comes in 2 parts, once (ka|Ij) and once (ja|Ik) only,
    1515             :       !      where both spin-orbitals on one center have the same spin, but it is the opposite of
    1516             :       !      the spin on the other center. Because it is eventually summed, can consider only one
    1517             :       !      of the 2 terms, but with a factor 2
    1518             : 
    1519             :       !Iterate over the blocks of the tensor and sum (OMP reduction on contrib, dev)
    1520             : 
    1521             : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:contrib,dev) &
    1522             : !$OMP SHARED(ja_Ik,j_evals,a_evals,k_evals,omega,c_os,a_offset) &
    1523         174 : !$OMP PRIVATE(iter,ind,boff,bsize,ja_Ik_blk,found,idx1,idx2,idx3,j,A,k,denom,tmp)
    1524             :       CALL dbt_iterator_start(iter, ja_Ik)
    1525             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1526             :          CALL dbt_iterator_next_block(iter, ind, blk_offset=boff, blk_size=bsize)
    1527             :          CALL dbt_get_block(ja_Ik, ind, ja_Ik_blk, found)
    1528             : 
    1529             :          IF (found) THEN
    1530             : 
    1531             :             DO idx3 = 1, bsize(3)
    1532             :                DO idx2 = 1, bsize(2)
    1533             :                   DO idx1 = 1, bsize(1)
    1534             : 
    1535             :                      !get proper MO indices (virtual index shifted by a_offset to address a_evals)
    1536             :                      j = boff(1) + idx1 - 1
    1537             :                      a = boff(2) + idx2 - 1 - a_offset
    1538             :                      k = boff(3) + idx3 - 1
    1539             : 
    1540             :                      !the denominator
    1541             :                      denom = omega + a_evals(a) - j_evals(j) - k_evals(k)
    1542             : 
    1543             :                      !the opposite spin contribution
    1544             :                      tmp = c_os*ja_Ik_blk(idx1, idx2, idx3)**2
    1545             : 
    1546             :                      !take factor 2 into account (2 x 0.5 = 1)
    1547             :                      contrib = contrib + tmp/denom
    1548             :                      dev = dev - tmp/denom**2
    1549             : 
    1550             :                   END DO
    1551             :                END DO
    1552             :             END DO
    1553             :          END IF
    1554             :          DEALLOCATE (ja_Ik_blk)
    1555             :       END DO
    1556             :       CALL dbt_iterator_stop(iter)
    1557             : !$OMP END PARALLEL
    1558             : 
    1559         174 :       CALL timestop(handle)
    1560             : 
    1561         348 :    END SUBROUTINE calc_os_oov_contrib
    1562             : 
    1563             : ! **************************************************************************************************
    1564             : !> \brief Computes the same-spin occupied-virtual-virtual MO contribution to the electron propagator:
    1565             : !>        0.5 * sum_abj |<Ij||ab>|^2/(omega + eps_j - eps_a - eps_b) as well as its first derivative:
    1566             : !>        -0.5 * sum_abj |<Ij||ab>|^2/(omega + eps_j - eps_a - eps_b)**2
    1567             : !> \param contrib running sum, incremented in place by 0.5*c_ss*t**2/denom for each tensor element t
    1568             : !> \param dev the first derivative, incremented in place by -0.5*c_ss*t**2/denom**2 per element
    1569             : !> \param aj_Ib_diff contains the (aj|Ib) - (bj|Ia) tensor, index order (a, j, b), a and b shifted by nocc
    1570             : !> \param occ_evals occupied MO energies (eps_j), SIZE(occ_evals) is taken as nocc
    1571             : !> \param virt_evals virtual MO energies (eps_a and eps_b), addressed with tensor index minus nocc
    1572             : !> \param omega the omega value entering the denominator omega + eps_j - eps_a - eps_b
    1573             : !> \param c_ss scaling factor multiplying each squared tensor element
    1574             : !> \note since this is same-spin, there is only one possibility for occ_evals and virt_evals
    1575             : ! **************************************************************************************************
    1576         166 :    SUBROUTINE calc_ss_ovv_contrib(contrib, dev, aj_Ib_diff, occ_evals, virt_evals, omega, c_ss)
    1577             : 
    1578             :       REAL(dp), INTENT(inout)                            :: contrib, dev
    1579             :       TYPE(dbt_type), INTENT(inout)                      :: aj_Ib_diff
    1580             :       REAL(dp), DIMENSION(:), INTENT(IN)                 :: occ_evals, virt_evals
    1581             :       REAL(dp), INTENT(in)                               :: omega, c_ss
    1582             : 
    1583             :       CHARACTER(len=*), PARAMETER :: routineN = 'calc_ss_ovv_contrib'
    1584             : 
    1585             :       INTEGER                                            :: a, b, boff(3), bsize(3), handle, idx1, &
    1586             :                                                             idx2, idx3, ind(3), j, nocc
    1587             :       LOGICAL                                            :: found
    1588             :       REAL(dp)                                           :: denom, tmp
    1589         166 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: tensor_blk
    1590             :       TYPE(dbt_iterator_type)                            :: iter
    1591             : 
    1592         166 :       CALL timeset(routineN, handle)
    1593             : 
    1594             :       !<Ij||ab> = <Ij|ab> - <Ij|ba> = (Ia|jb) - (Ib|ja) = (jb|Ia) - (ja|Ib)
    1595             :       !Notes: only non-zero contribution if all MOs have the same spin
    1596             : 
    1597         166 :       nocc = SIZE(occ_evals, 1)
    1598             : 
    1599             :       !Iterate over the blocks of aj_Ib_diff and accumulate the sums (OMP reduction on contrib, dev)
    1600             : 
    1601             : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:contrib,dev) &
    1602             : !$OMP SHARED(aj_Ib_diff,occ_evals,virt_evals,omega,c_ss,nocc) &
    1603         166 : !$OMP PRIVATE(iter,ind,boff,bsize,tensor_blk,found,idx1,idx2,idx3,j,A,b,denom,tmp)
    1604             :       CALL dbt_iterator_start(iter, aj_Ib_diff)
    1605             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1606             :          CALL dbt_iterator_next_block(iter, ind, blk_offset=boff, blk_size=bsize)
    1607             :          CALL dbt_get_block(aj_Ib_diff, ind, tensor_blk, found)
    1608             : 
    1609             :          IF (found) THEN
    1610             : 
    1611             :             DO idx3 = 1, bsize(3)
    1612             :                DO idx2 = 1, bsize(2)
    1613             :                   DO idx1 = 1, bsize(1)
    1614             : 
    1615             :                      !get proper MO indices (virtual indices shifted by nocc to address virt_evals)
    1616             :                      a = boff(1) + idx1 - 1 - nocc
    1617             :                      j = boff(2) + idx2 - 1
    1618             :                      b = boff(3) + idx3 - 1 - nocc
    1619             : 
    1620             :                      !the common denominator
    1621             :                      denom = omega + occ_evals(j) - virt_evals(a) - virt_evals(b)
    1622             : 
    1623             :                      !the same spin contribution
    1624             :                      tmp = c_ss*tensor_blk(idx1, idx2, idx3)**2
    1625             : 
    1626             :                      contrib = contrib + 0.5_dp*tmp/denom
    1627             :                      dev = dev - 0.5_dp*tmp/denom**2
    1628             : 
    1629             :                   END DO
    1630             :                END DO
    1631             :             END DO
    1632             :          END IF
    1633             :          DEALLOCATE (tensor_blk)
    1634             :       END DO
    1635             :       CALL dbt_iterator_stop(iter)
    1636             : !$OMP END PARALLEL
    1637             : 
    1638         166 :       CALL timestop(handle)
    1639             : 
    1640         332 :    END SUBROUTINE calc_ss_ovv_contrib
    1641             : 
    1642             : ! **************************************************************************************************
    1643             : !> \brief Computes the opposite-spin occupied-virtual-virtual MO contribution to the electron propagator:
    1644             : !>        0.5 * sum_abj |<Ij||ab>|^2/(omega + eps_j - eps_a - eps_b) as well as its first derivative:
    1645             : !>        -0.5 * sum_abj |<Ij||ab>|^2/(omega + eps_j - eps_a - eps_b)**2
    1646             : !> \param contrib running sum, incremented in place by c_os*t**2/denom for each tensor element t
    1647             : !> \param dev the first derivative, incremented in place by -c_os*t**2/denom**2 per element
    1648             : !> \param aj_Ib the (aj|Ib) tensor, index order (a, j, b), a shifted by a_offset, b by b_offset
    1649             : !> \param a_evals virtual evals for a MO
    1650             : !> \param j_evals occupied evals for j MO
    1651             : !> \param b_evals virtual evals for b MO
    1652             : !> \param omega the omega value entering the denominator omega + eps_j - eps_a - eps_b
    1653             : !> \param c_os scaling factor multiplying each squared tensor element
    1654             : !> \param a_offset number of occupied MOs for the same spin as a MO
    1655             : !> \param b_offset number of occupied MOs for the same spin as b MO
    1656             : !> \note since this is opposite-spin, evals might be different for different spins
    1657             : ! **************************************************************************************************
    1658         166 :    SUBROUTINE calc_os_ovv_contrib(contrib, dev, aj_Ib, a_evals, j_evals, b_evals, omega, c_os, &
    1659             :                                   a_offset, b_offset)
    1660             : 
    1661             :       REAL(dp), INTENT(inout)                            :: contrib, dev
    1662             :       TYPE(dbt_type), INTENT(inout)                      :: aj_Ib
    1663             :       REAL(dp), DIMENSION(:), INTENT(IN)                 :: a_evals, j_evals, b_evals
    1664             :       REAL(dp), INTENT(in)                               :: omega, c_os
    1665             :       INTEGER, INTENT(IN)                                :: a_offset, b_offset
    1666             : 
    1667             :       CHARACTER(len=*), PARAMETER :: routineN = 'calc_os_ovv_contrib'
    1668             : 
    1669             :       INTEGER                                            :: a, b, boff(3), bsize(3), handle, idx1, &
    1670             :                                                             idx2, idx3, ind(3), j
    1671             :       LOGICAL                                            :: found
    1672             :       REAL(dp)                                           :: denom, tmp
    1673         166 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: aj_Ib_blk
    1674             :       TYPE(dbt_iterator_type)                            :: iter
    1675             : 
    1676         166 :       CALL timeset(routineN, handle)
    1677             : 
    1678             :       !<Ij||ab> = <Ij|ab> - <Ij|ba> = (Ia|jb) - (Ib|ja) = (jb|Ia) - (ja|Ib)
    1679             :       !Notes: only 2 distinct contributions, once from (jb|Ia) and once from (ja|Ib) only, when the 2
    1680             :       !       MOs on one center have one spin and the 2 MOs on the other center have another spin
    1681             :       !       In the end, the sum is such that can take one of those with a factor 2
    1682             : 
    1683             : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:contrib,dev) &
    1684             : !$OMP SHARED(aj_Ib,a_evals,j_evals,b_evals,omega,c_os,a_offset,b_offset) &
    1685         166 : !$OMP PRIVATE(iter,ind,boff,bsize,aj_Ib_blk,found,idx1,idx2,idx3,j,A,b,denom,tmp)
    1686             :       CALL dbt_iterator_start(iter, aj_Ib)
    1687             :       DO WHILE (dbt_iterator_blocks_left(iter))
    1688             :          CALL dbt_iterator_next_block(iter, ind, blk_offset=boff, blk_size=bsize)
    1689             :          CALL dbt_get_block(aj_Ib, ind, aj_Ib_blk, found)
    1690             : 
    1691             :          IF (found) THEN
    1692             : 
    1693             :             DO idx3 = 1, bsize(3)
    1694             :                DO idx2 = 1, bsize(2)
    1695             :                   DO idx1 = 1, bsize(1)
    1696             : 
    1697             :                      !get proper MO indices (virtual indices shifted by a_offset/b_offset to address evals)
    1698             :                      a = boff(1) + idx1 - 1 - a_offset
    1699             :                      j = boff(2) + idx2 - 1
    1700             :                      b = boff(3) + idx3 - 1 - b_offset
    1701             : 
    1702             :                      !the denominator
    1703             :                      denom = omega + j_evals(j) - a_evals(a) - b_evals(b)
    1704             : 
    1705             :                      !the opposite-spin contribution. Factor 2 taken into account (2 x 0.5 = 1)
    1706             :                      tmp = c_os*(aj_Ib_blk(idx1, idx2, idx3))**2
    1707             : 
    1708             :                      contrib = contrib + tmp/denom
    1709             :                      dev = dev - tmp/denom**2
    1710             : 
    1711             :                   END DO
    1712             :                END DO
    1713             :             END DO
    1714             :          END IF
    1715             :          DEALLOCATE (aj_Ib_blk)
    1716             :       END DO
    1717             :       CALL dbt_iterator_stop(iter)
    1718             : !$OMP END PARALLEL
    1719             : 
    1720         166 :       CALL timestop(handle)
    1721             : 
    1722         332 :    END SUBROUTINE calc_os_ovv_contrib
    1723             : 
    1724             : ! **************************************************************************************************
    1725             : !> \brief We try to compute the spin-orbit splitting via perturbation theory. We keep it
    1726             : !>        cheap by only including the degenerate states (2p, 3d, 3p, etc.).
    1727             : !> \param soc_shifts the SOC corrected orbital shifts to apply to original energies, for both spins
    1728             : !> \param donor_state ...
    1729             : !> \param xas_tdp_env ...
    1730             : !> \param xas_tdp_control ...
    1731             : !> \param qs_env ...
    1732             : ! **************************************************************************************************
    1733           4 :    SUBROUTINE get_soc_splitting(soc_shifts, donor_state, xas_tdp_env, xas_tdp_control, qs_env)
    1734             : 
    1735             :       REAL(dp), ALLOCATABLE, DIMENSION(:, :), &
    1736             :          INTENT(out)                                     :: soc_shifts
    1737             :       TYPE(donor_state_type), POINTER                    :: donor_state
    1738             :       TYPE(xas_tdp_env_type), POINTER                    :: xas_tdp_env
    1739             :       TYPE(xas_tdp_control_type), POINTER                :: xas_tdp_control
    1740             :       TYPE(qs_environment_type), POINTER                 :: qs_env
    1741             : 
    1742             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'get_soc_splitting'
    1743             : 
    1744           4 :       COMPLEX(dp), ALLOCATABLE, DIMENSION(:, :)          :: evecs, hami
    1745             :       INTEGER                                            :: beta_spin, handle, ialpha, ibeta, &
    1746             :                                                             ido_mo, ispin, nao, ndo_mo, ndo_so, &
    1747             :                                                             nspins
    1748             :       REAL(dp)                                           :: alpha_tot_contrib, beta_tot_contrib
    1749           4 :       REAL(dp), ALLOCATABLE, DIMENSION(:)                :: evals
    1750           4 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: tmp_shifts
    1751             :       TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
    1752             :       TYPE(cp_cfm_type)                                  :: hami_cfm
    1753             :       TYPE(cp_fm_struct_type), POINTER                   :: ao_domo_struct, domo_domo_struct, &
    1754             :                                                             doso_doso_struct
    1755             :       TYPE(cp_fm_type)                                   :: alpha_gs_coeffs, ao_domo_work, &
    1756             :                                                             beta_gs_coeffs, domo_domo_work, &
    1757             :                                                             img_fm, real_fm
    1758           4 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks
    1759             :       TYPE(dbcsr_type), POINTER                          :: orb_soc_x, orb_soc_y, orb_soc_z
    1760             :       TYPE(mp_para_env_type), POINTER                    :: para_env
    1761             : 
    1762           4 :       NULLIFY (matrix_ks, para_env, blacs_env, ao_domo_struct, domo_domo_struct, &
    1763           4 :                doso_doso_struct, orb_soc_x, orb_soc_y, orb_soc_z)
    1764             : 
    1765           4 :       CALL timeset(routineN, handle)
    1766             : 
    1767             :       ! Idea: we compute the SOC matrix in the space of the degenerate spin-orbitals, add it to
    1768             :       !       the KS matrix in the same basis, diagonalize the whole thing and get the corrected energies
    1769             :       !       for SOC
    1770             : 
    1771           4 :       CALL get_qs_env(qs_env, matrix_ks=matrix_ks, para_env=para_env, blacs_env=blacs_env)
    1772             : 
    1773           4 :       orb_soc_x => xas_tdp_env%orb_soc(1)%matrix
    1774           4 :       orb_soc_y => xas_tdp_env%orb_soc(2)%matrix
    1775           4 :       orb_soc_z => xas_tdp_env%orb_soc(3)%matrix
    1776             : 
    1777             :       ! Whether it is open-shell or not, we have 2*ndo_mo spin-orbitals
    1778           4 :       nspins = 2
    1779           4 :       ndo_mo = donor_state%ndo_mo
    1780           4 :       ndo_so = nspins*ndo_mo
    1781           4 :       CALL dbcsr_get_info(matrix_ks(1)%matrix, nfullrows_total=nao)
    1782             : 
    1783             :       ! Build the fm infrastructure
    1784             :       CALL cp_fm_struct_create(ao_domo_struct, context=blacs_env, para_env=para_env, &
    1785           4 :                                nrow_global=nao, ncol_global=ndo_mo)
    1786             :       CALL cp_fm_struct_create(domo_domo_struct, context=blacs_env, para_env=para_env, &
    1787           4 :                                nrow_global=ndo_mo, ncol_global=ndo_mo)
    1788             :       CALL cp_fm_struct_create(doso_doso_struct, context=blacs_env, para_env=para_env, &
    1789           4 :                                nrow_global=ndo_so, ncol_global=ndo_so)
    1790             : 
    1791           4 :       CALL cp_fm_create(alpha_gs_coeffs, ao_domo_struct)
    1792           4 :       CALL cp_fm_create(beta_gs_coeffs, ao_domo_struct)
    1793           4 :       CALL cp_fm_create(ao_domo_work, ao_domo_struct)
    1794           4 :       CALL cp_fm_create(domo_domo_work, domo_domo_struct)
    1795           4 :       CALL cp_fm_create(real_fm, doso_doso_struct)
    1796           4 :       CALL cp_fm_create(img_fm, doso_doso_struct)
    1797             : 
    1798             :       ! Put the gs_coeffs in the correct format.
    1799           4 :       IF (xas_tdp_control%do_uks) THEN
    1800             : 
    1801             :          CALL cp_fm_to_fm_submat(msource=donor_state%gs_coeffs, mtarget=alpha_gs_coeffs, nrow=nao, &
    1802           0 :                                  ncol=ndo_mo, s_firstrow=1, s_firstcol=1, t_firstrow=1, t_firstcol=1)
    1803             :          CALL cp_fm_to_fm_submat(msource=donor_state%gs_coeffs, mtarget=beta_gs_coeffs, nrow=nao, &
    1804           0 :                                  ncol=ndo_mo, s_firstrow=1, s_firstcol=ndo_mo + 1, t_firstrow=1, t_firstcol=1)
    1805             : 
    1806             :       ELSE
    1807             : 
    1808           4 :          CALL cp_fm_to_fm(donor_state%gs_coeffs, alpha_gs_coeffs)
    1809           4 :          CALL cp_fm_to_fm(donor_state%gs_coeffs, beta_gs_coeffs)
    1810             :       END IF
    1811             : 
    1812             :       ! Compute the KS matrix in this basis, add it to the real part of the final matrix
    1813             :       !alpha-alpha block in upper left quadrant
    1814           4 :       CALL cp_dbcsr_sm_fm_multiply(matrix_ks(1)%matrix, alpha_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1815             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, 1.0_dp, alpha_gs_coeffs, ao_domo_work, 0.0_dp, &
    1816           4 :                          domo_domo_work)
    1817             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=real_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1818           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=1, t_firstcol=1)
    1819             : 
    1820             :       !beta-beta block in lower right quadrant
    1821           4 :       beta_spin = 1; IF (xas_tdp_control%do_uks .OR. xas_tdp_control%do_roks) beta_spin = 2
    1822           4 :       CALL cp_dbcsr_sm_fm_multiply(matrix_ks(beta_spin)%matrix, beta_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1823             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, 1.0_dp, beta_gs_coeffs, ao_domo_work, 0.0_dp, &
    1824           4 :                          domo_domo_work)
    1825             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=real_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1826           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=ndo_mo + 1, t_firstcol=ndo_mo + 1)
    1827             : 
    1828             :       ! Compute the SOC matrix elements and add them to the real or imaginary part of the matrix
    1829             :       ! alpha-alpha block, only Hz not zero, purely imaginary, addition
    1830           4 :       CALL cp_dbcsr_sm_fm_multiply(orb_soc_z, alpha_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1831             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, 1.0_dp, alpha_gs_coeffs, ao_domo_work, 0.0_dp, &
    1832           4 :                          domo_domo_work)
    1833             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=img_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1834           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=1, t_firstcol=1)
    1835             : 
    1836             :       ! beta-beta block, only Hz not zero, purely imaginary, subtraction
    1837           4 :       CALL cp_dbcsr_sm_fm_multiply(orb_soc_z, beta_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1838             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, -1.0_dp, beta_gs_coeffs, ao_domo_work, 0.0_dp, &
    1839           4 :                          domo_domo_work)
    1840             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=img_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1841           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=ndo_mo + 1, t_firstcol=ndo_mo + 1)
    1842             : 
    1843             :       ! alpha-beta block, two non-zero terms in Hx and Hy
    1844             :       ! Hx term, purely imaginary, addition
    1845           4 :       CALL cp_dbcsr_sm_fm_multiply(orb_soc_x, beta_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1846             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, 1.0_dp, alpha_gs_coeffs, ao_domo_work, 0.0_dp, &
    1847           4 :                          domo_domo_work)
    1848             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=img_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1849           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=1, t_firstcol=ndo_mo + 1)
    1850             :       ! Hy term, purely real, addition
    1851           4 :       CALL cp_dbcsr_sm_fm_multiply(orb_soc_y, beta_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1852             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, 1.0_dp, alpha_gs_coeffs, ao_domo_work, 0.0_dp, &
    1853           4 :                          domo_domo_work)
    1854             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=real_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1855           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=1, t_firstcol=ndo_mo + 1)
    1856             : 
    1857             :       ! beta-alpha block, two non-zero terms in Hx and Hy
    1858             :       ! Hx term, purely imaginary, addition
    1859           4 :       CALL cp_dbcsr_sm_fm_multiply(orb_soc_x, alpha_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1860             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, 1.0_dp, beta_gs_coeffs, ao_domo_work, 0.0_dp, &
    1861           4 :                          domo_domo_work)
    1862             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=img_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1863           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=ndo_mo + 1, t_firstcol=1)
    1864             :       ! Hy term, purely real, subtraction
    1865           4 :       CALL cp_dbcsr_sm_fm_multiply(orb_soc_y, alpha_gs_coeffs, ao_domo_work, ncol=ndo_mo)
    1866             :       CALL parallel_gemm('T', 'N', ndo_mo, ndo_mo, nao, -1.0_dp, beta_gs_coeffs, ao_domo_work, 0.0_dp, &
    1867           4 :                          domo_domo_work)
    1868             :       CALL cp_fm_to_fm_submat(msource=domo_domo_work, mtarget=real_fm, nrow=ndo_mo, ncol=ndo_mo, &
    1869           4 :                               s_firstrow=1, s_firstcol=1, t_firstrow=ndo_mo + 1, t_firstcol=1)
    1870             : 
    1871             :       ! Cast everything in complex fm format
    1872           4 :       CALL cp_cfm_create(hami_cfm, doso_doso_struct)
    1873           4 :       CALL cp_fm_to_cfm(real_fm, img_fm, hami_cfm)
    1874             : 
    1875             :       ! And diagonalize. Since the matrix is tiny (ndo_so x ndo_so, e.g. 6x6), diagonalize locally
    1876          32 :       ALLOCATE (evals(ndo_so), evecs(ndo_so, ndo_so), hami(ndo_so, ndo_so))
    1877           4 :       CALL cp_cfm_get_submatrix(hami_cfm, hami)
    1878           4 :       CALL complex_diag(hami, evecs, evals)
    1879             : 
    1880             :       !The SOC corrected KS eigenvalues
    1881          12 :       ALLOCATE (tmp_shifts(ndo_mo, 2))
    1882             : 
    1883           4 :       ialpha = 1; ibeta = 1; 
    1884          28 :       DO ido_mo = 1, ndo_so
    1885             :          ! Need to find out whether the eigenvalue corresponds to an alpha or beta spin-orbital
    1886          96 :          alpha_tot_contrib = REAL(DOT_PRODUCT(evecs(1:ndo_mo, ido_mo), evecs(1:ndo_mo, ido_mo)))
    1887          96 :          beta_tot_contrib = REAL(DOT_PRODUCT(evecs(ndo_mo + 1:ndo_so, ido_mo), evecs(ndo_mo + 1:ndo_so, ido_mo)))
    1888             : 
    1889          28 :          IF (alpha_tot_contrib > beta_tot_contrib) THEN
    1890          12 :             tmp_shifts(ialpha, 1) = evals(ido_mo)
    1891          12 :             ialpha = ialpha + 1
    1892             :          ELSE
    1893          12 :             tmp_shifts(ibeta, 2) = evals(ido_mo)
    1894          12 :             ibeta = ibeta + 1
    1895             :          END IF
    1896             :       END DO
    1897             : 
    1898             :       !compute shift from KS evals
    1899          16 :       ALLOCATE (soc_shifts(ndo_mo, SIZE(donor_state%energy_evals, 2)))
    1900           8 :       DO ispin = 1, SIZE(donor_state%energy_evals, 2)
    1901          20 :          soc_shifts(:, ispin) = tmp_shifts(:, ispin) - donor_state%energy_evals(:, ispin)
    1902             :       END DO
    1903             : 
    1904             :       ! clean-up
    1905           4 :       CALL cp_fm_release(alpha_gs_coeffs)
    1906           4 :       CALL cp_fm_release(beta_gs_coeffs)
    1907           4 :       CALL cp_fm_release(ao_domo_work)
    1908           4 :       CALL cp_fm_release(domo_domo_work)
    1909           4 :       CALL cp_fm_release(real_fm)
    1910           4 :       CALL cp_fm_release(img_fm)
    1911             : 
    1912           4 :       CALL cp_cfm_release(hami_cfm)
    1913             : 
    1914           4 :       CALL cp_fm_struct_release(ao_domo_struct)
    1915           4 :       CALL cp_fm_struct_release(domo_domo_struct)
    1916           4 :       CALL cp_fm_struct_release(doso_doso_struct)
    1917             : 
    1918           4 :       CALL timestop(handle)
    1919             : 
    1920          20 :    END SUBROUTINE get_soc_splitting
    1921             : 
    1922             : END MODULE xas_tdp_correction

Generated by: LCOV version 1.15