LCOV - code coverage report
Current view: top level - src - iterate_matrix.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 826 876 94.3 %
Date: 2024-11-21 06:45:46 Functions: 17 19 89.5 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : ! **************************************************************************************************
       8             : !> \brief Routines useful for iterative matrix calculations
       9             : !> \par History
      10             : !>       2010.10 created [Joost VandeVondele]
      11             : !> \author Joost VandeVondele
      12             : ! **************************************************************************************************
      13             : MODULE iterate_matrix
      14             :    USE arnoldi_api,                     ONLY: arnoldi_data_type,&
      15             :                                               arnoldi_extremal
      16             :    USE bibliography,                    ONLY: Richters2018,&
      17             :                                               cite_reference
      18             :    USE cp_dbcsr_api,                    ONLY: &
      19             :         dbcsr_add, dbcsr_add_on_diag, dbcsr_copy, dbcsr_create, dbcsr_desymmetrize, &
      20             :         dbcsr_distribution_get, dbcsr_distribution_type, dbcsr_filter, dbcsr_frobenius_norm, &
      21             :         dbcsr_gershgorin_norm, dbcsr_get_diag, dbcsr_get_info, dbcsr_get_matrix_type, &
      22             :         dbcsr_get_occupation, dbcsr_multiply, dbcsr_norm, dbcsr_norm_maxabsnorm, dbcsr_p_type, &
      23             :         dbcsr_release, dbcsr_scale, dbcsr_set, dbcsr_set_diag, dbcsr_trace, dbcsr_transposed, &
      24             :         dbcsr_type, dbcsr_type_no_symmetry
      25             :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      26             :                                               cp_logger_get_default_unit_nr,&
      27             :                                               cp_logger_type
      28             :    USE input_constants,                 ONLY: ls_scf_submatrix_sign_direct,&
      29             :                                               ls_scf_submatrix_sign_direct_muadj,&
      30             :                                               ls_scf_submatrix_sign_direct_muadj_lowmem,&
      31             :                                               ls_scf_submatrix_sign_ns
      32             :    USE kinds,                           ONLY: dp,&
      33             :                                               int_8
      34             :    USE machine,                         ONLY: m_flush,&
      35             :                                               m_walltime
      36             :    USE mathconstants,                   ONLY: ifac
      37             :    USE mathlib,                         ONLY: abnormal_value
      38             :    USE message_passing,                 ONLY: mp_comm_type
      39             :    USE submatrix_dissection,            ONLY: submatrix_dissection_type
      40             : #include "./base/base_uses.f90"
      41             : 
      42             :    IMPLICIT NONE
      43             : 
      44             :    PRIVATE
      45             : 
      46             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'iterate_matrix'
      47             : 
      48             :    TYPE :: eigbuf ! buffer pairing a rank-1 and a rank-2 allocatable real array
      49             :       REAL(KIND=dp), DIMENSION(:), ALLOCATABLE    :: eigvals ! presumably eigenvalues - usage not visible in this chunk
      50             :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: eigvecs ! presumably the matching eigenvectors - TODO confirm layout
      51             :    END TYPE eigbuf
      52             : 
      53             :    INTERFACE purify_mcweeny ! generic name, exported via the PUBLIC list of this module
      54             :       MODULE PROCEDURE purify_mcweeny_orth, purify_mcweeny_nonorth ! the two specific procedures it resolves to
      55             :    END INTERFACE
      56             : 
      57             :    PUBLIC :: invert_Hotelling, matrix_sign_Newton_Schulz, matrix_sqrt_Newton_Schulz, &
      58             :              matrix_sqrt_proot, matrix_sign_proot, matrix_sign_submatrix, matrix_exponential, &
      59             :              matrix_sign_submatrix_mu_adjust, purify_mcweeny, invert_Taylor, determinant
      60             : 
      61             : CONTAINS
      62             : 
      63             : ! *****************************************************************************
      64             : !> \brief Computes the determinant of a symmetric positive definite matrix
      65             : !>        using the trace of the matrix logarithm via Mercator series:
      66             : !>         det(A) = det(S)det(I+X)det(S), where S=diag(sqrt(Aii),..,sqrt(Ann))
      67             : !>         det(I+X) = Exp(Trace(Ln(I+X)))
      68             : !>         Ln(I+X) = X - X^2/2 + X^3/3 - X^4/4 + ..
      69             : !>        The series is evaluated only if the Frobenius norm of X is less than 1.
      70             : !>        If it is not less than one we compute (recursively) the determinant of
      71             : !>        the square root of (I+X) and square it.
      72             : !> \param matrix matrix whose determinant is computed; also the template for the work matrices
      73             : !> \param det on exit, the determinant
      74             : !> \param threshold filter_eps for the matrix products and max-abs-norm cutoff ending the series
      75             : !> \par History
      76             : !>       2015.04 created [Rustam Z Khaliullin]
      77             : !> \author Rustam Z. Khaliullin
      78             : ! **************************************************************************************************
      79         132 :    RECURSIVE SUBROUTINE determinant(matrix, det, threshold)
      80             : 
      81             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix
      82             :       REAL(KIND=dp), INTENT(INOUT)                       :: det
      83             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
      84             : 
      85             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'determinant'
      86             : 
      87             :       INTEGER                                            :: handle, i, max_iter_lanczos, nsize, &
      88             :                                                             order_lanczos, sign_iter, unit_nr
      89             :       INTEGER(KIND=int_8)                                :: flop1
      90             :       INTEGER, SAVE                                      :: recursion_depth = 0 ! persists across calls; only labels the printed output of recursive calls
      91             :       REAL(KIND=dp)                                      :: det0, eps_lanczos, frobnorm, maxnorm, &
      92             :                                                             occ_matrix, t1, t2, trace
      93         132 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: diagonal
      94             :       TYPE(cp_logger_type), POINTER                      :: logger
      95             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
      96             : 
      97         132 :       CALL timeset(routineN, handle)
      98             : 
      99             :       ! get a useful output_unit: only the source rank prints, all other ranks get unit_nr = -1
     100         132 :       logger => cp_get_default_logger()
     101         132 :       IF (logger%para_env%is_source()) THEN
     102          66 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     103             :       ELSE
     104          66 :          unit_nr = -1
     105             :       END IF
     106             : 
     107             :       ! Note: tmp1, tmp2 and tmp3 are all created from the input template with
     108             :       ! dbcsr_type_no_symmetry, so none of them carries a symmetry constraint.
     109             :       ! NOTE(review): an older version of this note warned about unintended results
     110             :       ! with anti-symmetric matrices - confirm whether that caveat still applies
     111             :       CALL dbcsr_create(tmp1, template=matrix, &
     112         132 :                         matrix_type=dbcsr_type_no_symmetry)
     113             :       CALL dbcsr_create(tmp2, template=matrix, &
     114         132 :                         matrix_type=dbcsr_type_no_symmetry)
     115             :       CALL dbcsr_create(tmp3, template=matrix, &
     116         132 :                         matrix_type=dbcsr_type_no_symmetry)
     117             : 
     118             :       ! compute the product of the diagonal elements (the diagonal is summed over the matrix group first)
     119             :       BLOCK
     120             :          TYPE(mp_comm_type) :: group
     121             :          INTEGER :: group_handle
     122         132 :          CALL dbcsr_get_info(matrix, nfullrows_total=nsize, group=group_handle)
     123         132 :          CALL group%set_handle(group_handle)
     124         396 :          ALLOCATE (diagonal(nsize))
     125         132 :          CALL dbcsr_get_diag(matrix, diagonal)
     126         132 :          CALL group%sum(diagonal)
     127        2308 :          det = PRODUCT(diagonal)
     128             :       END BLOCK
     129             : 
     130             :       ! create diagonal SQRTI matrix: tmp1 = diag(1/sqrt(A_ii))
     131        2176 :       diagonal(:) = 1.0_dp/(SQRT(diagonal(:)))
     132             :       !ROLL CALL dbcsr_copy(tmp1,matrix)
     133         132 :       CALL dbcsr_desymmetrize(matrix, tmp1)
     134         132 :       CALL dbcsr_set(tmp1, 0.0_dp)
     135         132 :       CALL dbcsr_set_diag(tmp1, diagonal)
     136         132 :       CALL dbcsr_filter(tmp1, threshold)
     137         132 :       DEALLOCATE (diagonal)
     138             : 
     139             :       ! normalize the main diagonal: tmp3 = A*SQRTI, then tmp2 = SQRTI*A*SQRTI
     140             :       ! (off-diagonal elements are scaled along; the resulting norm is checked below)
     141             :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     142             :                           matrix, &
     143             :                           tmp1, &
     144             :                           0.0_dp, tmp3, &
     145         132 :                           filter_eps=threshold)
     146             :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     147             :                           tmp1, &
     148             :                           tmp3, &
     149             :                           0.0_dp, tmp2, &
     150         132 :                           filter_eps=threshold)
     151             : 
     152             :       ! subtract the unit diagonal to create matrix X = SQRTI*A*SQRTI - I
     153         132 :       CALL dbcsr_add_on_diag(tmp2, -1.0_dp)
     154         132 :       frobnorm = dbcsr_frobenius_norm(tmp2)
     155         132 :       IF (unit_nr > 0) THEN
     156          66 :          IF (recursion_depth .EQ. 0) THEN
     157          41 :             WRITE (unit_nr, '()')
     158             :          ELSE
     159             :             WRITE (unit_nr, '(T6,A28,1X,I15)') &
     160          25 :                "Recursive iteration:", recursion_depth
     161             :          END IF
     162             :          WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     163          66 :             "Frobenius norm:", frobnorm
     164          66 :          CALL m_flush(unit_nr)
     165             :       END IF
     166             : 
     167         132 :       IF (frobnorm .GE. 1.0_dp) THEN
     168             : 
     169          50 :          CALL dbcsr_add_on_diag(tmp2, 1.0_dp) ! tmp2 is I+X again
     170             :          ! these controls should be provided as input
     171          50 :          order_lanczos = 3
     172          50 :          eps_lanczos = 1.0E-4_dp
     173          50 :          max_iter_lanczos = 40
     174             :          CALL matrix_sqrt_Newton_Schulz( &
     175             :             tmp3, & ! output sqrt
     176             :             tmp1, & ! output sqrti
     177             :             tmp2, & ! input original
     178             :             threshold=threshold, &
     179             :             order=order_lanczos, &
     180             :             eps_lanczos=eps_lanczos, &
     181          50 :             max_iter_lanczos=max_iter_lanczos)
     182          50 :          recursion_depth = recursion_depth + 1
     183          50 :          CALL determinant(tmp3, det0, threshold)
     184          50 :          recursion_depth = recursion_depth - 1
     185          50 :          det = det*det0*det0 ! det(I+X) = det(sqrt(I+X))**2
     186             : 
     187             :       ELSE
     188             : 
     189             :          ! create accumulator: tmp1 holds the current power of X (starts at X^1)
     190          82 :          CALL dbcsr_copy(tmp1, tmp2)
     191             :          ! re-create to make use of symmetry
     192             :          !ROLL CALL dbcsr_create(tmp3,template=matrix)
     193             : 
     194          82 :          IF (unit_nr > 0) WRITE (unit_nr, *)
     195             : 
     196             :          ! initialize the sign of the term (the first evaluated term, X^2/2, enters with a minus sign)
     197          82 :          sign_iter = -1
     198        1078 :          DO i = 1, 100 ! NOTE(review): if 100 terms are not enough the loop just ends, no error is raised
     199             : 
     200        1078 :             t1 = m_walltime()
     201             : 
     202             :             ! multiply X^i by X
     203             :             ! note that the first iteration evaluates X^2
     204             :             ! because the trace of X^1 is zero by construction
     205             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, &
     206             :                                 0.0_dp, tmp3, &
     207             :                                 filter_eps=threshold, &
     208        1078 :                                 flop=flop1)
     209        1078 :             CALL dbcsr_copy(tmp1, tmp3)
     210             : 
     211             :             ! get trace of X^(i+1) and weight it with the Mercator coefficient (-1)^i/(i+1)
     212        1078 :             CALL dbcsr_trace(tmp1, trace)
     213        1078 :             trace = trace*sign_iter/(1.0_dp*(i + 1))
     214        1078 :             sign_iter = -sign_iter
     215             : 
     216             :             ! update the determinant: multiply by the exponential of this series term
     217        1078 :             det = det*EXP(trace)
     218             : 
     219        1078 :             occ_matrix = dbcsr_get_occupation(tmp1)
     220             :             CALL dbcsr_norm(tmp1, &
     221        1078 :                             dbcsr_norm_maxabsnorm, norm_scalar=maxnorm)
     222             : 
     223        1078 :             t2 = m_walltime()
     224             : 
     225        1078 :             IF (unit_nr > 0) THEN
     226             :                WRITE (unit_nr, '(T6,A,1X,I3,1X,F7.5,F16.10,F10.3,F11.3)') &
     227         539 :                   "Determinant iter", i, occ_matrix, &
     228         539 :                   det, t2 - t1, &
     229        1078 :                   flop1/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     230         539 :                CALL m_flush(unit_nr)
     231             :             END IF
     232             : 
     233             :             ! exit once the max-abs norm of the current power of X drops below threshold
     234        2156 :             IF (maxnorm < threshold) EXIT
     235             : 
     236             :          END DO ! end iterations
     237             : 
     238          82 :          IF (unit_nr > 0) THEN
     239          41 :             WRITE (unit_nr, '()')
     240          41 :             CALL m_flush(unit_nr)
     241             :          END IF
     242             : 
     243             :       END IF ! decide to do sqrt or not
     244             : 
     245         132 :       IF (unit_nr > 0) THEN
     246          66 :          IF (recursion_depth .EQ. 0) THEN
     247             :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     248          41 :                "Final determinant:", det
     249          41 :             WRITE (unit_nr, '()')
     250             :          ELSE
     251             :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     252          25 :                "Recursive determinant:", det
     253             :          END IF
     254          66 :          CALL m_flush(unit_nr)
     255             :       END IF
     256             : 
     257         132 :       CALL dbcsr_release(tmp1)
     258         132 :       CALL dbcsr_release(tmp2)
     259         132 :       CALL dbcsr_release(tmp3)
     260             : 
     261         132 :       CALL timestop(handle)
     262             : 
     263         132 :    END SUBROUTINE determinant
     264             : 
     265             : ! **************************************************************************************************
     266             : !> \brief invert a symmetric positive definite diagonally dominant matrix by a Taylor series around its diagonal
     267             : !> \param matrix_inverse on exit, the approximate inverse; also the template for the work matrices
     268             : !> \param matrix matrix to invert
     269             : !> \param threshold convergence threshold based on the max abs norm
     270             : !> \param use_inv_as_guess logical whether input can be used as guess for inverse (not implemented: .TRUE. aborts)
     271             : !> \param norm_convergence if present, replaces threshold as convergence criterion (compared to the max abs norm here)
     272             : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filtered
     273             : !> \param accelerator_order only 0 (the default) is accepted, any other value aborts
     274             : !> \param max_iter_lanczos copied to a local variable that is not used further in this routine
     275             : !> \param eps_lanczos copied to a local variable that is not used further in this routine
     276             : !> \param silent if .TRUE., no output is written
     277             : !> \par History
     278             : !>       2010.10 created [Joost VandeVondele]
     279             : !>       2011.10 guess option added [Rustam Z Khaliullin]
     280             : !> \author Joost VandeVondele
     281             : ! **************************************************************************************************
     282          26 :    SUBROUTINE invert_Taylor(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     283             :                             norm_convergence, filter_eps, accelerator_order, &
     284             :                             max_iter_lanczos, eps_lanczos, silent)
     285             : 
     286             :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     287             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     288             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     289             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     290             :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     291             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     292             :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     293             : 
     294             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Taylor'
     295             : 
     296             :       INTEGER                                            :: accelerator_type, handle, i, &
     297             :                                                             my_max_iter_lanczos, nrows, unit_nr
     298             :       INTEGER(KIND=int_8)                                :: flop2
     299             :       LOGICAL                                            :: converged, use_inv_guess
     300             :       REAL(KIND=dp)                                      :: coeff, convergence, maxnorm_matrix, &
     301             :                                                             my_eps_lanczos, occ_matrix, t1, t2
     302          26 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: p_diagonal
     303             :       TYPE(cp_logger_type), POINTER                      :: logger
     304             :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2, tmp3_sym
     305             : 
     306          26 :       CALL timeset(routineN, handle)
     307             : 
     308          26 :       logger => cp_get_default_logger()
     309          26 :       IF (logger%para_env%is_source()) THEN
     310          13 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     311             :       ELSE
     312          13 :          unit_nr = -1
     313             :       END IF
     314          26 :       IF (PRESENT(silent)) THEN
     315          26 :          IF (silent) unit_nr = -1
     316             :       END IF
     317             : 
     318          26 :       convergence = threshold
     319          26 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     320             : 
     321          26 :       accelerator_type = 0
     322          26 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     323           0 :       IF (accelerator_type .GT. 1) accelerator_type = 1 ! clamped here, but only 0 passes the SELECT CASE below
     324             : 
     325          26 :       use_inv_guess = .FALSE.
     326          26 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     327             : 
     328          26 :       my_max_iter_lanczos = 64
     329          26 :       my_eps_lanczos = 1.0E-3_dp
     330          26 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     331          26 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     332             : 
     333          26 :       CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     334          26 :       CALL dbcsr_create(tmp2, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     335          26 :       CALL dbcsr_create(tmp3_sym, template=matrix_inverse)
     336             : 
     337          26 :       CALL dbcsr_get_info(matrix, nfullrows_total=nrows)
     338          78 :       ALLOCATE (p_diagonal(nrows))
     339             : 
     340             :       ! generate the initial guess: the inverse of the main diagonal
     341          26 :       IF (.NOT. use_inv_guess) THEN
     342             : 
     343          26 :          SELECT CASE (accelerator_type)
     344             :          CASE (0)
     345             :             ! use tmp1 to hold off-diagonal elements (copy of matrix with its diagonal set to zero)
     346          26 :             CALL dbcsr_desymmetrize(matrix, tmp1)
     347         858 :             p_diagonal(:) = 0.0_dp
     348          26 :             CALL dbcsr_set_diag(tmp1, p_diagonal)
     349             :             !CALL dbcsr_print(tmp1)
     350             :             ! invert the main diagonal (entries that are exactly zero are left as zero)
     351          26 :             CALL dbcsr_get_diag(matrix, p_diagonal)
     352         858 :             DO i = 1, nrows
     353         858 :                IF (p_diagonal(i) .NE. 0.0_dp) THEN
     354         416 :                   p_diagonal(i) = 1.0_dp/p_diagonal(i)
     355             :                END IF
     356             :             END DO
     357          26 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     358          26 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     359          26 :             CALL dbcsr_set_diag(matrix_inverse, p_diagonal)
     360             :          CASE DEFAULT
     361          26 :             CPABORT("Illegal accelerator order")
     362             :          END SELECT
     363             : 
     364             :       ELSE
     365             : 
     366           0 :          CPABORT("Guess is NYI")
     367             : 
     368             :       END IF
     369             : 
     370             :       CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, &
     371          26 :                           0.0_dp, tmp2, filter_eps=filter_eps) ! tmp2 = offdiag(matrix) * diagonal inverse, fixed for all iterations
     372             : 
     373          26 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     374             : 
     375             :       ! start the timer for the first iteration (no scaling of the guess is done here)
     376          26 :       t1 = m_walltime()
     377             : 
     378             :       ! done with the initial guess, start iterations
     379          26 :       converged = .FALSE.
     380          26 :       CALL dbcsr_desymmetrize(matrix_inverse, tmp1)
     381          26 :       coeff = 1.0_dp
     382         284 :       DO i = 1, 100
     383             : 
     384             :          ! coeff = +/- 1, alternating sign of the series term (starts at -1)
     385         284 :          coeff = -1.0_dp*coeff
     386             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, &
     387             :                              tmp3_sym, &
     388         284 :                              flop=flop2, filter_eps=filter_eps)
     389             :          !flop=flop2)
     390         284 :          CALL dbcsr_add(matrix_inverse, tmp3_sym, 1.0_dp, coeff)
     391         284 :          CALL dbcsr_release(tmp1)
     392         284 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     393         284 :          CALL dbcsr_desymmetrize(tmp3_sym, tmp1)
     394             : 
     395             :          ! max abs norm of the term just added, for the convergence check
     396             :          CALL dbcsr_norm(tmp3_sym, &
     397         284 :                          dbcsr_norm_maxabsnorm, norm_scalar=maxnorm_matrix)
     398             : 
     399         284 :          t2 = m_walltime()
     400         284 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     401             : 
     402         284 :          IF (unit_nr > 0) THEN
     403         142 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Taylor iter", i, occ_matrix, &
     404         142 :                maxnorm_matrix, t2 - t1, &
     405         284 :                flop2/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     406         142 :             CALL m_flush(unit_nr)
     407             :          END IF
     408             : 
     409         284 :          IF (maxnorm_matrix < convergence) THEN
     410             :             converged = .TRUE.
     411             :             EXIT
     412             :          END IF
     413             : 
     414         258 :          t1 = m_walltime()
     415             : 
     416             :       END DO
     417             : 
     418             :       ! last convergence check: max abs norm of the residual matrix*matrix_inverse - I
     419             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_inverse, 0.0_dp, tmp1, &
     420          26 :                           filter_eps=filter_eps)
     421          26 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     422             :       !frob_matrix =  dbcsr_frobenius_norm(tmp1)
     423          26 :       CALL dbcsr_norm(tmp1, dbcsr_norm_maxabsnorm, norm_scalar=maxnorm_matrix)
     424          26 :       IF (unit_nr > 0) THEN
     425          13 :          WRITE (unit_nr, '(T6,A,E12.5)') "Final Taylor error", maxnorm_matrix
     426          13 :          WRITE (unit_nr, '()')
     427          13 :          CALL m_flush(unit_nr)
     428             :       END IF
     429          26 :       IF (maxnorm_matrix > convergence) THEN
     430           0 :          converged = .FALSE.
     431           0 :          IF (unit_nr > 0) THEN
     432           0 :             WRITE (unit_nr, *) 'Final convergence check failed'
     433             :          END IF
     434             :       END IF
     435             : 
     436          26 :       IF (.NOT. converged) THEN
     437           0 :          CPABORT("Taylor inversion did not converge")
     438             :       END IF
     439             : 
     440          26 :       CALL dbcsr_release(tmp1)
     441          26 :       CALL dbcsr_release(tmp2)
     442          26 :       CALL dbcsr_release(tmp3_sym)
     443             : 
     444          26 :       DEALLOCATE (p_diagonal)
     445             : 
     446          26 :       CALL timestop(handle)
     447             : 
     448          52 :    END SUBROUTINE invert_Taylor
     449             : 
     450             : ! **************************************************************************************************
     451             : !> \brief invert a symmetric positive definite matrix by Hotelling's method
     452             : !>        explicit symmetrization makes this code not suitable for other matrix types
     453             : !>        Currently a bit messy with the options, to be cleaned up soon
     454             : !> \param matrix_inverse ...
     455             : !> \param matrix ...
     456             : !> \param threshold convergence threshold based on the max abs
     457             : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     458             : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     459             : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filtered
     460             : !> \param accelerator_order ...
     461             : !> \param max_iter_lanczos ...
     462             : !> \param eps_lanczos ...
     463             : !> \param silent ...
     464             : !> \par History
     465             : !>       2010.10 created [Joost VandeVondele]
     466             : !>       2011.10 guess option added [Rustam Z Khaliullin]
     467             : !> \author Joost VandeVondele
     468             : ! **************************************************************************************************
     469        2032 :    SUBROUTINE invert_Hotelling(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     470             :                                norm_convergence, filter_eps, accelerator_order, &
     471             :                                max_iter_lanczos, eps_lanczos, silent)
     472             :       ! Iterates matrix_inverse <- 2*X - X*matrix*X (X = matrix_inverse); aborts after 100 unconverged sweeps
     473             :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     474             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     475             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     476             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     477             :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     478             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     479             :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     480             : 
     481             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Hotelling'
     482             : 
     483             :       INTEGER                                            :: accelerator_type, handle, i, &
     484             :                                                             my_max_iter_lanczos, unit_nr
     485             :       INTEGER(KIND=int_8)                                :: flop1, flop2
     486             :       LOGICAL                                            :: arnoldi_converged, converged, &
     487             :                                                             use_inv_guess
     488             :       REAL(KIND=dp) :: convergence, frob_matrix, gershgorin_norm, max_ev, maxnorm_matrix, min_ev, &
     489             :          my_eps_lanczos, my_filter_eps, occ_matrix, scalingf, t1, t2
     490             :       TYPE(cp_logger_type), POINTER                      :: logger
     491             :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2
     492             : 
     493             :       !TYPE(arnoldi_data_type)                            :: my_arnoldi
     494             :       !TYPE(dbcsr_p_type), DIMENSION(1)                   :: mymat
     495             : 
     496        2032 :       CALL timeset(routineN, handle)
     497             :       ! only the source rank gets a real output unit; other ranks and silent mode use unit_nr = -1
     498        2032 :       logger => cp_get_default_logger()
     499        2032 :       IF (logger%para_env%is_source()) THEN
     500        1016 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     501             :       ELSE
     502        1016 :          unit_nr = -1
     503             :       END IF
     504        2032 :       IF (PRESENT(silent)) THEN
     505        2014 :          IF (silent) unit_nr = -1
     506             :       END IF
     507             :       ! convergence and my_filter_eps both default to threshold unless the optional arguments override them
     508        2032 :       convergence = threshold
     509        2032 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     510             :       ! accelerator order defaults to 1; larger requested values are clamped to 1
     511        2032 :       accelerator_type = 1
     512        2032 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     513        1436 :       IF (accelerator_type .GT. 1) accelerator_type = 1
     514             : 
     515        2032 :       use_inv_guess = .FALSE.
     516        2032 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     517             : 
     518        2032 :       my_max_iter_lanczos = 64
     519        2032 :       my_eps_lanczos = 1.0E-3_dp
     520        2032 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     521        2032 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     522             : 
     523        2032 :       my_filter_eps = threshold
     524        2032 :       IF (PRESENT(filter_eps)) my_filter_eps = filter_eps
     525             : 
     526             :       ! generate the initial guess
     527        2032 :       IF (.NOT. use_inv_guess) THEN
     528             : 
     529           0 :          SELECT CASE (accelerator_type)
     530             :          CASE (0)
     531           0 :             gershgorin_norm = dbcsr_gershgorin_norm(matrix)
     532           0 :             frob_matrix = dbcsr_frobenius_norm(matrix)
     533           0 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     534           0 :             CALL dbcsr_add_on_diag(matrix_inverse, 1/MIN(gershgorin_norm, frob_matrix))
     535             :          CASE (1)
     536             :             ! initialize matrix to unity and use arnoldi (below) to scale it into the convergence range
     537        1558 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     538        1558 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     539             :          CASE DEFAULT
     540        1558 :             CPABORT("Illegal accelerator order")
     541             :          END SELECT
     542             : 
     543             :          ! everything commutes, therefore all our products will be symmetric
     544        1558 :          CALL dbcsr_create(tmp1, template=matrix_inverse)
     545             : 
     546             :       ELSE
     547             : 
     548             :          ! It is unlikely that our guess will commute with the matrix, therefore the first product will
     549             :          ! be non symmetric
     550         474 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     551             : 
     552             :       END IF
     553             : 
     554        2032 :       CALL dbcsr_create(tmp2, template=matrix_inverse)
     555             : 
     556        2032 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     557             : 
     558             :       ! scale the approximate inverse to be within the convergence radius
     559        2032 :       t1 = m_walltime()
     560             : 
     561             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     562        2032 :                           0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     563             : 
     564        2032 :       IF (accelerator_type == 1) THEN
     565             : 
     566             :          ! scale the matrix to get into the convergence range
     567             :          CALL arnoldi_extremal(tmp1, max_eV, min_eV, threshold=my_eps_lanczos, &
     568        2032 :                                max_iter=my_max_iter_lanczos, converged=arnoldi_converged)
     569             :          !mymat(1)%matrix => tmp1
     570             :          !CALL setup_arnoldi_data(my_arnoldi, mymat, max_iter=30, threshold=1.0E-3_dp, selection_crit=1, &
     571             :          !                        nval_request=2, nrestarts=2, generalized_ev=.FALSE., iram=.TRUE.)
     572             :          !CALL arnoldi_ev(mymat, my_arnoldi)
     573             :          !max_eV = REAL(get_selected_ritz_val(my_arnoldi, 2), dp)
     574             :          !min_eV = REAL(get_selected_ritz_val(my_arnoldi, 1), dp)
     575             :          !CALL deallocate_arnoldi_data(my_arnoldi)
     576             : 
     577        2032 :          IF (unit_nr > 0) THEN
     578         768 :             WRITE (unit_nr, *)
     579         768 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", my_eps_lanczos
     580         768 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_eV, min_eV
     581         768 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_eV/MAX(min_eV, EPSILON(min_eV))
     582             :          END IF
     583             : 
     584             :          ! 2.0 would be the exact scaling; 1.9 is used instead to stay safely inside the convergence radius
     585        2032 :          scalingf = 1.9_dp/(max_eV + min_eV)
     586        2032 :          CALL dbcsr_scale(tmp1, scalingf)
     587        2032 :          CALL dbcsr_scale(matrix_inverse, scalingf)
     588        2032 :          min_ev = min_ev*scalingf
     589             : 
     590             :       END IF
     591             : 
     592             :       ! done with the initial guess, start iterations
     593        2032 :       converged = .FALSE.
     594        9000 :       DO i = 1, 100
     595             : 
     596             :          ! tmp1 = S^-1 S
     597             : 
     598             :          ! for the convergence check: max-abs norm of (tmp1 - 1); the diagonal shift is undone right after
     599        9000 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     600             :          CALL dbcsr_norm(tmp1, &
     601        9000 :                          dbcsr_norm_maxabsnorm, norm_scalar=maxnorm_matrix)
     602        9000 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     603             : 
     604             :          ! tmp2 = S^-1 S S^-1
     605             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, 0.0_dp, tmp2, &
     606        9000 :                              flop=flop2, filter_eps=my_filter_eps)
     607             :          ! S^-1_{n+1} = 2 S^-1 - S^-1 S S^-1
     608        9000 :          CALL dbcsr_add(matrix_inverse, tmp2, 2.0_dp, -1.0_dp)
     609             : 
     610        9000 :          CALL dbcsr_filter(matrix_inverse, my_filter_eps)
     611        9000 :          t2 = m_walltime()
     612        9000 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     613             : 
     614             :          ! scalar form of the update tracks min_ev; with norm_convergence present |min_ev - 1| is the measure
     615        9000 :          IF (accelerator_type == 1) THEN
     616        9000 :             min_ev = min_ev*(2.0_dp - min_ev)
     617        9000 :             IF (PRESENT(norm_convergence)) maxnorm_matrix = ABS(min_eV - 1.0_dp)
     618             :          END IF
     619             : 
     620        9000 :          IF (unit_nr > 0) THEN
     621        3718 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Hotelling iter", i, occ_matrix, &
     622        3718 :                maxnorm_matrix, t2 - t1, &
     623        7436 :                (flop1 + flop2)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     624        3718 :             CALL m_flush(unit_nr)
     625             :          END IF
     626             : 
     627        9000 :          IF (maxnorm_matrix < convergence) THEN
     628             :             converged = .TRUE.
     629             :             EXIT
     630             :          END IF
     631             : 
     632             :          ! scale the matrix for improved convergence
     633        6968 :          IF (accelerator_type == 1) THEN
     634        6968 :             min_ev = min_ev*2.0_dp/(min_ev + 1.0_dp)
     635        6968 :             CALL dbcsr_scale(matrix_inverse, 2.0_dp/(min_ev + 1.0_dp))
     636             :          END IF
     637             : 
     638        6968 :          t1 = m_walltime()
     639             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     640        6968 :                              0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     641             : 
     642             :       END DO
     643             : 
     644        2032 :       IF (.NOT. converged) THEN
     645           0 :          CPABORT("Hotelling inversion did not converge")
     646             :       END IF
     647             : 
     648             :       ! a result stored without symmetry is symmetrized as 0.5*(X + X^T)
     649        2032 :       IF (dbcsr_get_matrix_type(matrix_inverse) == dbcsr_type_no_symmetry) THEN
     650         100 :          CALL dbcsr_transposed(tmp2, matrix_inverse)
     651        2132 :          CALL dbcsr_add(matrix_inverse, tmp2, 0.5_dp, 0.5_dp)
     652             :       END IF
     653             : 
     654        2032 :       IF (unit_nr > 0) THEN
     655             : !           WRITE(unit_nr,'(T6,A,1X,I3,1X,F10.8,E12.3)') "Final Hotelling ",i,occ_matrix,&
     656             : !              !frob_matrix/frob_matrix_base
     657             : !              maxnorm_matrix
     658         768 :          WRITE (unit_nr, '()')
     659         768 :          CALL m_flush(unit_nr)
     660             :       END IF
     661             : 
     662        2032 :       CALL dbcsr_release(tmp1)
     663        2032 :       CALL dbcsr_release(tmp2)
     664             : 
     665        2032 :       CALL timestop(handle)
     666             : 
     667        2032 :    END SUBROUTINE invert_Hotelling
     668             : 
     669             : ! **************************************************************************************************
     670             : !> \brief compute the sign of a matrix using Newton-Schulz iterations
     671             : !> \param matrix_sign on exit the computed sign matrix; also serves as template for all work matrices
     672             : !> \param matrix input matrix; it is copied into matrix_sign before iterating
     673             : !> \param threshold filtering threshold for all products; also sets the convergence criterion
     674             : !> \param sign_order order of the iteration: 2 (default), 3, 4, 5, -5, 6 or 7; anything else aborts
     675             : !> \param iounit output unit; defaults to the default logger unit on the source rank (-1 elsewhere)
     676             : !> \par History
     677             : !>       2010.10 created [Joost VandeVondele]
     678             : !>       2019.05 extended to order beyond 2 [Robert Schade]
     679             : !> \author Joost VandeVondele, Robert Schade
     680             : ! **************************************************************************************************
     681        1058 :    SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order, iounit)
     682             : 
     683             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
     684             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     685             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order, iounit
     686             : 
     687             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_Newton_Schulz'
     688             : 
     689             :       INTEGER                                            :: count, handle, i, order, unit_nr
     690             :       INTEGER(KIND=int_8)                                :: flops
     691             :       REAL(KIND=dp)                                      :: a0, a1, a2, a3, a4, a5, floptot, &
     692             :                                                             frob_matrix, frob_matrix_base, &
     693             :                                                             gersh_matrix, occ_matrix, prefactor, &
     694             :                                                             t1, t2
     695             :       TYPE(cp_logger_type), POINTER                      :: logger
     696             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3, tmp4
     697             : 
     698        1058 :       CALL timeset(routineN, handle)
     699             :       ! a caller-supplied unit wins; otherwise only the source rank writes to the default logger unit
     700        1058 :       IF (PRESENT(iounit)) THEN
     701        1058 :          unit_nr = iounit
     702             :       ELSE
     703           0 :          logger => cp_get_default_logger()
     704           0 :          IF (logger%para_env%is_source()) THEN
     705           0 :             unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     706             :          ELSE
     707           0 :             unit_nr = -1
     708             :          END IF
     709             :       END IF
     710             :       ! default to the second-order iteration
     711        1058 :       IF (PRESENT(sign_order)) THEN
     712        1058 :          order = sign_order
     713             :       ELSE
     714             :          order = 2
     715             :       END IF
     716             :       ! work matrices; tmp3 (|order| >= 4) and tmp4 (|order| > 4) are created only when needed
     717        1058 :       CALL dbcsr_create(tmp1, template=matrix_sign)
     718             : 
     719        1058 :       CALL dbcsr_create(tmp2, template=matrix_sign)
     720        1058 :       IF (ABS(order) .GE. 4) THEN
     721           8 :          CALL dbcsr_create(tmp3, template=matrix_sign)
     722             :       END IF
     723           8 :       IF (ABS(order) .GT. 4) THEN
     724           6 :          CALL dbcsr_create(tmp4, template=matrix_sign)
     725             :       END IF
     726             : 
     727        1058 :       CALL dbcsr_copy(matrix_sign, matrix)
     728        1058 :       CALL dbcsr_filter(matrix_sign, threshold)
     729             : 
     730             :       ! scale the matrix to get into the convergence range
     731        1058 :       frob_matrix = dbcsr_frobenius_norm(matrix_sign)
     732        1058 :       gersh_matrix = dbcsr_gershgorin_norm(matrix_sign)
     733        1058 :       CALL dbcsr_scale(matrix_sign, 1.0_dp/MIN(frob_matrix, gersh_matrix))
     734             : 
     735        1058 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     736             : 
     737        1058 :       count = 0
     738       13202 :       DO i = 1, 100
     739       13202 :          floptot = 0.0_dp
     740       13202 :          t1 = m_walltime()
     741             :          ! tmp1 = X * X
     742             :          CALL dbcsr_multiply("N", "N", -1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     743       13202 :                              filter_eps=threshold, flop=flops)
     744       13202 :          floptot = floptot + flops
     745             : 
     746             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
     747       13202 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     748       13202 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     749       13202 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
     750             : 
     751             :          ! f(y) approx 1/sqrt(1-y)
     752             :          ! f(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     753             :          ! f2(y)=1+y/2=1/2*(2+y)
     754             :          ! f3(y)=1+y/2+3/8*y^2=3/8*(8/3+4/3*y+y^2)
     755             :          ! f4(y)=1+y/2+3/8*y^2+5/16*y^3=5/16*(16/5+8/5*y+6/5*y^2+y^3)
     756             :          ! f5(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4=35/128*(128/35+128/70*y+48/35*y^2+8/7*y^3+y^4)
     757             :          !      z(y)=(y+a_0)*y+a_1
     758             :          ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     759             :          !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     760             :          !    a_0=1/14
     761             :          !    a_1=23819/13720
     762             :          !    a_2=1269/980-2a_1=-3734/1715
     763             :          !    a_3=832591127/188238400
     764             :          ! f6(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5
     765             :          !      =63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     766             :          ! f7(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     767             :          !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     768             :          ! z(y)=(y+a_0)*y+a_1
     769             :          ! w(y)=(y+a_2)*z(y)+a_3
     770             :          ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     771             :          ! a_0= 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422
     772             :          ! a_1= 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568
     773             :          ! a_2=-1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968
     774             :          ! a_3= 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453
     775             :          ! a_4=-3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667
     776             :          ! a_5= 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578
     777             :          !
     778             :          ! y=1-X*X
     779             : 
     780             :          ! tmp1 = I-x*x
     781             :          IF (order .EQ. 2) THEN
     782       13156 :             prefactor = 0.5_dp
     783             : 
     784             :             ! update the above to 3*I-X*X
     785       13156 :             CALL dbcsr_add_on_diag(tmp1, +2.0_dp)
     786       13156 :             occ_matrix = dbcsr_get_occupation(matrix_sign)
     787             :          ELSE IF (order .EQ. 3) THEN
     788             :             ! with one multiplication
     789             :             ! tmp1=y
     790          12 :             CALL dbcsr_copy(tmp2, tmp1)
     791          12 :             CALL dbcsr_scale(tmp1, 4.0_dp/3.0_dp)
     792          12 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp/3.0_dp)
     793             : 
     794             :             ! tmp1 = 8/3 + 4/3*y + y^2 (y^2 is accumulated into tmp1; tmp2 still holds y)
     795             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp2, 1.0_dp, tmp1, &
     796          12 :                                 filter_eps=threshold, flop=flops)
     797          12 :             floptot = floptot + flops
     798          12 :             prefactor = 3.0_dp/8.0_dp
     799             : 
     800             :          ELSE IF (order .EQ. 4) THEN
     801             :             ! with two multiplications
     802             :             ! tmp1=y
     803          10 :             CALL dbcsr_copy(tmp3, tmp1)
     804          10 :             CALL dbcsr_scale(tmp1, 8.0_dp/5.0_dp)
     805          10 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp/5.0_dp)
     806             : 
     807             :             ! tmp2 = y^2
     808             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     809          10 :                                 filter_eps=threshold, flop=flops)
     810          10 :             floptot = floptot + flops
     811             : 
     812          10 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp/5.0_dp)
     813             : 
     814             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     815          10 :                                 filter_eps=threshold, flop=flops)
     816          10 :             floptot = floptot + flops
     817             : 
     818          10 :             prefactor = 5.0_dp/16.0_dp
     819             :          ELSE IF (order .EQ. -5) THEN
     820             :             ! with three multiplications
     821             :             ! tmp1=y
     822           0 :             CALL dbcsr_copy(tmp3, tmp1)
     823           0 :             CALL dbcsr_scale(tmp1, 128.0_dp/70.0_dp)
     824           0 :             CALL dbcsr_add_on_diag(tmp1, 128.0_dp/35.0_dp)
     825             : 
     826             :             ! tmp2 = y^2
     827             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     828           0 :                                 filter_eps=threshold, flop=flops)
     829           0 :             floptot = floptot + flops
     830             : 
     831           0 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 48.0_dp/35.0_dp)
     832             : 
     833             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     834           0 :                                 filter_eps=threshold, flop=flops)
     835           0 :             floptot = floptot + flops
     836             : 
     837           0 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 8.0_dp/7.0_dp)
     838             : 
     839             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 1.0_dp, tmp1, &
     840           0 :                                 filter_eps=threshold, flop=flops)
     841           0 :             floptot = floptot + flops
     842             : 
     843           0 :             prefactor = 35.0_dp/128.0_dp
     844             :          ELSE IF (order .EQ. 5) THEN
     845             :             ! with two multiplications
     846             :             !      z(y)=(y+a_0)*y+a_1
     847             :             ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     848             :             !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     849             :             !    a_0=1/14
     850             :             !    a_1=23819/13720
     851             :             !    a_2=1269/980-2a_1=-3734/1715
     852             :             !    a_3=832591127/188238400
     853           8 :             a0 = 1.0_dp/14.0_dp
     854           8 :             a1 = 23819.0_dp/13720.0_dp
     855           8 :             a2 = -3734.0_dp/1715.0_dp
     856           8 :             a3 = 832591127.0_dp/188238400.0_dp
     857             : 
     858             :             ! tmp1=y
     859             :             ! tmp3=z
     860           8 :             CALL dbcsr_copy(tmp3, tmp1)
     861           8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     862             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     863           8 :                                 filter_eps=threshold, flop=flops)
     864           8 :             floptot = floptot + flops
     865           8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     866             : 
     867           8 :             CALL dbcsr_add_on_diag(tmp1, a2)
     868           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
     869             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, tmp3, &
     870           8 :                                 filter_eps=threshold, flop=flops)
     871           8 :             floptot = floptot + flops
     872           8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     873           8 :             CALL dbcsr_copy(tmp1, tmp3)
     874             : 
     875           8 :             prefactor = 35.0_dp/128.0_dp
     876             :          ELSE IF (order .EQ. 6) THEN
     877             :             ! with four multiplications
     878             :             ! f6(y)=63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     879             :             ! tmp1=y
     880           8 :             CALL dbcsr_copy(tmp3, tmp1)
     881           8 :             CALL dbcsr_scale(tmp1, 128.0_dp/63.0_dp)
     882           8 :             CALL dbcsr_add_on_diag(tmp1, 256.0_dp/63.0_dp)
     883             : 
     884             :             ! tmp2 = y^2
     885             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     886           8 :                                 filter_eps=threshold, flop=flops)
     887           8 :             floptot = floptot + flops
     888             : 
     889           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 32.0_dp/21.0_dp)
     890             : 
     891             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     892           8 :                                 filter_eps=threshold, flop=flops)
     893           8 :             floptot = floptot + flops
     894             : 
     895           8 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 80.0_dp/63.0_dp)
     896             : 
     897             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 0.0_dp, tmp2, &
     898           8 :                                 filter_eps=threshold, flop=flops)
     899           8 :             floptot = floptot + flops
     900             : 
     901           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 10.0_dp/9.0_dp)
     902             : 
     903             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     904           8 :                                 filter_eps=threshold, flop=flops)
     905           8 :             floptot = floptot + flops
     906             : 
     907           8 :             prefactor = 63.0_dp/256.0_dp
     908             :          ELSE IF (order .EQ. 7) THEN
     909             :             ! with three multiplications
     910             : 
     911           8 :             a0 = 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422_dp
     912           8 :             a1 = 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568_dp
     913           8 :             a2 = -1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968_dp
     914           8 :             a3 = 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453_dp
     915           8 :             a4 = -3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667_dp
     916           8 :             a5 = 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578_dp
     917             :             !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     918             :             ! z(y)=(y+a_0)*y+a_1
     919             :             ! w(y)=(y+a_2)*z(y)+a_3
     920             :             ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     921             : 
     922             :             ! tmp1=y
     923             :             ! tmp3=z
     924           8 :             CALL dbcsr_copy(tmp3, tmp1)
     925           8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     926             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     927           8 :                                 filter_eps=threshold, flop=flops)
     928           8 :             floptot = floptot + flops
     929           8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     930             : 
     931             :             ! tmp4=w
     932           8 :             CALL dbcsr_copy(tmp4, tmp1)
     933           8 :             CALL dbcsr_add_on_diag(tmp4, a2)
     934             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp2, 0.0_dp, tmp3, &
     935           8 :                                 filter_eps=threshold, flop=flops)
     936           8 :             floptot = floptot + flops
     937           8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     938             : 
     939           8 :             CALL dbcsr_add(tmp2, tmp3, 1.0_dp, 1.0_dp)
     940           8 :             CALL dbcsr_add_on_diag(tmp2, a4)
     941             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
     942           8 :                                 filter_eps=threshold, flop=flops)
     943           8 :             floptot = floptot + flops
     944           8 :             CALL dbcsr_add_on_diag(tmp1, a5)
     945             : 
     946           8 :             prefactor = 231.0_dp/1024.0_dp
     947             :          ELSE
     948           0 :             CPABORT("requested order is not implemented.")
     949             :          END IF
     950             : 
     951             :          ! tmp2 = prefactor * X * tmp1
     952             :          CALL dbcsr_multiply("N", "N", prefactor, matrix_sign, tmp1, 0.0_dp, tmp2, &
     953       13202 :                              filter_eps=threshold, flop=flops)
     954       13202 :          floptot = floptot + flops
     955             : 
     956             :          ! done iterating
     957             :          ! CALL dbcsr_filter(tmp2,threshold)
     958       13202 :          CALL dbcsr_copy(matrix_sign, tmp2)
     959       13202 :          t2 = m_walltime()
     960             : 
     961       13202 :          occ_matrix = dbcsr_get_occupation(matrix_sign)
     962             : 
     963       13202 :          IF (unit_nr > 0) THEN
     964           0 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sign iter ", i, occ_matrix, &
     965           0 :                frob_matrix/frob_matrix_base, t2 - t1, &
     966           0 :                floptot/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     967           0 :             CALL m_flush(unit_nr)
     968             :          END IF
     969             : 
     970             :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
     971       13202 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) EXIT
     972             : 
     973             :       END DO
     974             : 
     975             :       ! final diagnostic only (not really needed): the residual of X*X - I is printed, not acted upon
     976             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     977        1058 :                           filter_eps=threshold)
     978        1058 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     979        1058 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     980        1058 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
     981        1058 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
     982        1058 :       IF (unit_nr > 0) THEN
     983           0 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sign iter", i, occ_matrix, &
     984           0 :             frob_matrix/frob_matrix_base
     985           0 :          WRITE (unit_nr, '()')
     986           0 :          CALL m_flush(unit_nr)
     987             :       END IF
     988             : 
     989        1058 :       CALL dbcsr_release(tmp1)
     990        1058 :       CALL dbcsr_release(tmp2)
     991        1058 :       IF (ABS(order) .GE. 4) THEN
     992           8 :          CALL dbcsr_release(tmp3)
     993             :       END IF
     994           8 :       IF (ABS(order) .GT. 4) THEN
     995           6 :          CALL dbcsr_release(tmp4)
     996             :       END IF
     997             : 
     998        1058 :       CALL timestop(handle)
     999             : 
    1000        1058 :    END SUBROUTINE matrix_sign_Newton_Schulz
    1001             : 
    ! **************************************************************************************************
!> \brief Compute the sign of a matrix using the general algorithm for the p-th root of
!>        Richters et al., Commun. Comput. Phys., 25 (2019), pp. 564-585.
!>        The result is formed as matrix * (matrix*matrix)^(-1/2), with the inverse square
!>        root taken from matrix_sqrt_proot.
!> \param matrix_sign on exit: product of matrix with the inverse square root of matrix*matrix;
!>                    it also serves as template for the work matrix
!> \param matrix matrix whose sign is requested
!> \param threshold filter_eps of every sparse multiplication, also handed to matrix_sqrt_proot
!> \param sign_order order handed to matrix_sqrt_proot (2 if absent)
!> \par History
!>       2019.03 created [Robert Schade]
!> \author Robert Schade
! **************************************************************************************************
   SUBROUTINE matrix_sign_proot(matrix_sign, matrix, threshold, sign_order)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      INTEGER, INTENT(IN), OPTIONAL                      :: sign_order

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sign_proot'

      INTEGER                                            :: handle, order, unit_nr
      LOGICAL                                            :: converged, symmetrize
      REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, occ_matrix
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: matrix2, matrix_sqrt, matrix_sqrt_inv, &
                                                            tmp1

      CALL cite_reference(Richters2018)

      CALL timeset(routineN, handle)

      ! only the source rank prints
      logger => cp_get_default_logger()
      IF (logger%para_env%is_source()) THEN
         unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
      ELSE
         unit_nr = -1
      END IF

      IF (PRESENT(sign_order)) THEN
         order = sign_order
      ELSE
         order = 2
      END IF

      ! matrix2 = matrix * matrix
      CALL dbcsr_create(matrix2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix, 0.0_dp, matrix2, &
                          filter_eps=threshold)

      IF (unit_nr > 0) WRITE (unit_nr, *)

      CALL dbcsr_create(matrix_sqrt, template=matrix2)
      CALL dbcsr_create(matrix_sqrt_inv, template=matrix2)
      IF (unit_nr > 0) WRITE (unit_nr, *) "Threshold=", threshold

      ! (inverse) square root of matrix2; 0.01 and 100 are the Lanczos-related arguments
      ! of matrix_sqrt_proot (NOTE(review): exact meaning defined by that routine)
      symmetrize = .FALSE.
      CALL matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix2, threshold, order, &
                             0.01_dp, 100, symmetrize, converged)

      ! do not silently continue with an unconverged inverse square root
      ! NOTE(review): assumes matrix_sqrt_proot always defines converged - confirm
      IF (.NOT. converged .AND. unit_nr > 0) &
         WRITE (unit_nr, '(T6,A)') "WARNING: matrix_sqrt_proot did not report convergence"

      ! matrix_sign = matrix * matrix2^(-1/2)
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_sqrt_inv, 0.0_dp, matrix_sign, &
                          filter_eps=threshold)

      ! this check is not really needed: report || sign*sign - 1 || / || sign*sign ||
      CALL dbcsr_create(tmp1, template=matrix_sign)
      CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
                          filter_eps=threshold)
      frob_matrix_base = dbcsr_frobenius_norm(tmp1)
      CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
      frob_matrix = dbcsr_frobenius_norm(tmp1)
      occ_matrix = dbcsr_get_occupation(matrix_sign)
      IF (unit_nr > 0) THEN
         WRITE (unit_nr, '(T6,A,F10.8,E12.3)') "Final proot sign iter", occ_matrix, &
            frob_matrix/frob_matrix_base
         WRITE (unit_nr, '()')
         CALL m_flush(unit_nr)
      END IF

      CALL dbcsr_release(tmp1)
      CALL dbcsr_release(matrix2)
      CALL dbcsr_release(matrix_sqrt)
      CALL dbcsr_release(matrix_sqrt_inv)

      CALL timestop(handle)

   END SUBROUTINE matrix_sign_proot
    1094             : 
    ! **************************************************************************************************
!> \brief Compute the sign of a dense matrix using Newton-Schulz iterations
!> \param matrix_sign on exit the iterated sign matrix; must have the shape of matrix
!> \param matrix square input matrix (its first extent is used as the dimension)
!> \param matrix_id optional identifier that is only printed in the convergence message
!> \param threshold iteration stops once (frob_matrix/frob_matrix_base)**2 < threshold
!> \param sign_order order of the iteration, 2 or 3 (2 if absent); other values abort
!> \author Michael Lass, Robert Schade
! **************************************************************************************************
   SUBROUTINE dense_matrix_sign_Newton_Schulz(matrix_sign, matrix, matrix_id, threshold, sign_order)

      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: matrix_sign
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: matrix
      INTEGER, INTENT(IN), OPTIONAL                      :: matrix_id
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      INTEGER, INTENT(IN), OPTIONAL                      :: sign_order

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dense_matrix_sign_Newton_Schulz'
      INTEGER, PARAMETER                                 :: max_iter = 100

      INTEGER                                            :: handle, i, j, order, sz, unit_nr
      LOGICAL                                            :: converged
      REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, &
                                                            gersh_matrix, norm_bound, prefactor, &
                                                            scaling_factor
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp1, tmp2
      REAL(KIND=dp), DIMENSION(1)                        :: work
      REAL(KIND=dp), EXTERNAL                            :: dlange
      TYPE(cp_logger_type), POINTER                      :: logger

      CALL timeset(routineN, handle)

      ! an absent optional argument must never be referenced: resolve it once
      IF (PRESENT(sign_order)) THEN
         order = sign_order
      ELSE
         order = 2
      END IF

      ! print output on all ranks
      logger => cp_get_default_logger()
      unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)

      ! scale the matrix to get into the convergence range, using the smaller of the
      ! Frobenius norm and the one-norm (maximum absolute column sum) returned by dlange
      sz = SIZE(matrix, 1)
      frob_matrix = dlange('F', sz, sz, matrix, sz, work)
      gersh_matrix = dlange('1', sz, sz, matrix, sz, work)
      norm_bound = MIN(frob_matrix, gersh_matrix)
      ! a zero (or NaN) norm would make the scaling factor infinite (or NaN)
      IF (.NOT. norm_bound > 0.0_dp) &
         CPABORT("dense_matrix_sign_Newton_Schulz: matrix norm is not positive")
      scaling_factor = 1.0_dp/norm_bound
      matrix_sign = matrix*scaling_factor
      ALLOCATE (tmp1(sz, sz))
      ALLOCATE (tmp2(sz, sz))

      converged = .FALSE.
      DO i = 1, max_iter
         ! tmp1 = -X*X
         CALL dgemm('N', 'N', sz, sz, sz, -1.0_dp, matrix_sign, sz, matrix_sign, sz, 0.0_dp, tmp1, sz)

         ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
         frob_matrix_base = dlange('F', sz, sz, tmp1, sz, work)
         DO j = 1, sz
            tmp1(j, j) = tmp1(j, j) + 1.0_dp
         END DO
         ! tmp1 = 1 - X*X from here on
         frob_matrix = dlange('F', sz, sz, tmp1, sz, work)

         SELECT CASE (order)
         CASE (2)
            prefactor = 0.5_dp
            ! update the above to 3*I-X*X
            DO j = 1, sz
               tmp1(j, j) = tmp1(j, j) + 2.0_dp
            END DO
         CASE (3)
            ! with Y = 1 - X*X build tmp1 = Y*Y + 4/3*Y + 8/3*I
            tmp2(:, :) = tmp1
            tmp1 = tmp1*4.0_dp/3.0_dp
            DO j = 1, sz
               tmp1(j, j) = tmp1(j, j) + 8.0_dp/3.0_dp
            END DO
            CALL dgemm('N', 'N', sz, sz, sz, 1.0_dp, tmp2, sz, tmp2, sz, 1.0_dp, tmp1, sz)
            prefactor = 3.0_dp/8.0_dp
         CASE DEFAULT
            CPABORT("requested order is not implemented.")
         END SELECT

         ! X <- prefactor * X * tmp1
         CALL dgemm('N', 'N', sz, sz, sz, prefactor, matrix_sign, sz, tmp1, sz, 0.0_dp, tmp2, sz)
         matrix_sign = tmp2

         ! frob_matrix/frob_matrix_base < SQRT(threshold)
         IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) THEN
            IF (unit_nr > 0) THEN
               IF (PRESENT(matrix_id)) THEN
                  WRITE (unit_nr, '(T6,A,1X,I6,1X,A,1X,I3,E12.3)') &
                     "Submatrix", matrix_id, "final NS sign iter", i, frob_matrix/frob_matrix_base
               ELSE
                  WRITE (unit_nr, '(T6,A,1X,I3,E12.3)') &
                     "Submatrix final NS sign iter", i, frob_matrix/frob_matrix_base
               END IF
               CALL m_flush(unit_nr)
            END IF
            converged = .TRUE.
            EXIT
         END IF
      END DO

      IF (.NOT. converged) &
         CPABORT("dense_matrix_sign_Newton_Schulz did not converge within 100 iterations")

      DEALLOCATE (tmp1)
      DEALLOCATE (tmp2)

      CALL timestop(handle)

   END SUBROUTINE dense_matrix_sign_Newton_Schulz
    1189             : 
    ! **************************************************************************************************
!> \brief Perform eigendecomposition of a dense matrix. The input is symmetrized as
!>        0.5*(sm + sm^T) and then diagonalized with LAPACK dsyevd.
!> \param sm square input matrix of dimension N x N (not modified)
!> \param N dimension of sm
!> \param eigvals allocated here with N entries; eigenvalues returned by dsyevd
!> \param eigvecs allocated here as N x N; eigenvectors returned by dsyevd (one per column)
!> \par History
!>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
!> \author Michael Lass, Robert Schade
! **************************************************************************************************
   SUBROUTINE eigdecomp(sm, N, eigvals, eigvecs)
      INTEGER, INTENT(IN)                                :: N
      REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
         INTENT(OUT)                                     :: eigvals
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: eigvecs

      INTEGER                                            :: info, liwork, lwork
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iwork
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work

      ALLOCATE (eigvecs(N, N))
      ALLOCATE (eigvals(N))

      ! symmetrize sm; dsyevd overwrites this copy with the eigenvectors
      eigvecs(:, :) = 0.5_dp*(sm + TRANSPOSE(sm))

      ! probe optimal sizes for WORK and IWORK (lwork = liwork = -1 requests a query only)
      lwork = -1
      liwork = -1
      ALLOCATE (work(1))
      ALLOCATE (iwork(1))
      CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, work, lwork, iwork, liwork, info)
      ! the reported sizes are only meaningful if the query itself succeeded
      IF (info .NE. 0) CPABORT("dsyevd workspace query did not succeed")
      lwork = INT(work(1))
      liwork = iwork(1)
      DEALLOCATE (iwork)
      DEALLOCATE (work)

      ! calculate eigenvalues and eigenvectors
      ALLOCATE (work(lwork))
      ALLOCATE (iwork(liwork))
      CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, work, lwork, iwork, liwork, info)
      DEALLOCATE (iwork)
      DEALLOCATE (work)
      IF (info .NE. 0) CPABORT("dsyevd did not succeed")

   END SUBROUTINE eigdecomp
    1240             : 
    ! **************************************************************************************************
!> \brief Calculate the sign matrix from eigenvalues and eigenvectors of a matrix:
!>        sm_sign = eigvecs * diag(s) * eigvecs^T, where s(i) is +1, -1 or 0 depending on
!>        whether eigvals(i) - mu_correction is positive, negative or zero.
!> \param sm_sign N x N result; completely overwritten
!> \param eigvals the N eigenvalues
!> \param eigvecs N x N matrix holding the eigenvector of eigvals(i) in column i
!> \param N dimension of the matrices
!> \param mu_correction shift subtracted from every eigenvalue before its sign is taken
!> \par History
!>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
!> \author Michael Lass, Robert Schade
! **************************************************************************************************
   SUBROUTINE sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, mu_correction)
      INTEGER, INTENT(IN)                                :: N
      REAL(KIND=dp), INTENT(IN)                          :: eigvecs(N, N), eigvals(N)
      REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
      REAL(KIND=dp), INTENT(IN)                          :: mu_correction

      INTEGER                                            :: i
      REAL(KIND=dp)                                      :: modified_eigval, sign_val
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp

      ! heap allocation: an N x N automatic array can exhaust the (per-thread) stack
      ALLOCATE (tmp(N, N))

      ! tmp = eigvecs * diag(s): multiplying by a diagonal matrix from the right just scales
      ! column i by s(i), so no full matrix-matrix product is required for this step
      DO i = 1, N
         modified_eigval = eigvals(i) - mu_correction
         IF (modified_eigval > 0.0_dp) THEN
            sign_val = 1.0_dp
         ELSE IF (modified_eigval < 0.0_dp) THEN
            sign_val = -1.0_dp
         ELSE
            sign_val = 0.0_dp
         END IF
         tmp(:, i) = sign_val*eigvecs(:, i)
      END DO

      ! Create matrix with eigenvalues in {-1,0,1} and eigenvectors of sm:
      ! sm_sign = tmp * eigvecs.T  (beta = 0, so the previous content of sm_sign is not used)
      CALL dgemm('N', 'T', N, N, N, 1.0_dp, tmp, N, eigvecs, N, 0.0_dp, sm_sign, N)

      DEALLOCATE (tmp)
   END SUBROUTINE sign_from_eigdecomp
    1278             : 
    ! **************************************************************************************************
!> \brief Compute partial trace of a matrix from its eigenvalues and eigenvectors.
!>        Every eigenvalue is shifted by mu_correction and mapped to +1, -1 or 0 according
!>        to its sign; for each row i in firstcol..lastcol the quantity
!>        0.5 - 0.5*SUM_j eigvecs(i,j)*mapped(j)*eigvecs(i,j) is accumulated.
!> \param eigvals eigenvalues (indexed from 1 to SIZE(eigvals))
!> \param eigvecs eigenvectors; element (i,j) is component i of eigenvector j
!> \param firstcol first index i that contributes to the trace
!> \param lastcol last index i that contributes to the trace
!> \param mu_correction shift subtracted from every eigenvalue before its sign is taken
!> \return accumulated partial trace
!> \par History
!>       2020.05 Created [Michael Lass]
!> \author Michael Lass
! **************************************************************************************************
   FUNCTION trace_from_eigdecomp(eigvals, eigvecs, firstcol, lastcol, mu_correction) RESULT(trace)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
         INTENT(IN)                                      :: eigvals
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(IN)                                      :: eigvecs
      INTEGER, INTENT(IN)                                :: firstcol, lastcol
      REAL(KIND=dp), INTENT(IN)                          :: mu_correction
      REAL(KIND=dp)                                      :: trace

      INTEGER                                            :: icol, k, n_eig
      REAL(KIND=dp)                                      :: row_sum
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: shifted, sgn

      n_eig = SIZE(eigvals)
      ALLOCATE (shifted(n_eig), sgn(n_eig))

      ! sign of the shifted spectrum: +1 above, -1 below, 0 exactly at the shift
      shifted(:) = eigvals(1:n_eig) - mu_correction
      sgn(:) = MERGE(1.0_dp, MERGE(-1.0_dp, 0.0_dp, shifted < 0.0_dp), shifted > 0.0_dp)

      trace = 0.0_dp
      DO icol = firstcol, lastcol
         ! diagonal element (icol, icol) of eigvecs * diag(sgn) * eigvecs^T
         row_sum = 0.0_dp
         DO k = 1, n_eig
            row_sum = row_sum + (eigvecs(icol, k)*sgn(k)*eigvecs(icol, k))
         END DO
         trace = trace - 0.5_dp*row_sum + 0.5_dp
      END DO

      DEALLOCATE (shifted, sgn)
   END FUNCTION trace_from_eigdecomp
    1327             : 
    ! **************************************************************************************************
!> \brief Calculate the sign matrix by direct calculation of all eigenvalues and eigenvectors:
!>        eigdecomp provides the decomposition, sign_from_eigdecomp (with zero shift)
!>        assembles the sign matrix from it.
!> \param sm_sign N x N sign matrix as produced by sign_from_eigdecomp
!> \param sm N x N input matrix
!> \param N dimension of sm and sm_sign
!> \par History
!>       2020.02 Created [Michael Lass, Robert Schade]
!>       2020.05 Extracted eigdecomp and sign_from_eigdecomp [Michael Lass]
!> \author Michael Lass, Robert Schade
! **************************************************************************************************
   SUBROUTINE dense_matrix_sign_direct(sm_sign, sm, N)
      INTEGER, INTENT(IN)                                :: N
      REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
      REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)

      REAL(KIND=dp), PARAMETER                           :: no_shift = 0.0_dp

      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: spectrum
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: basis

      ! step 1: full eigendecomposition (allocates spectrum and basis)
      CALL eigdecomp(sm=sm, N=N, eigvals=spectrum, eigvecs=basis)

      ! step 2: assemble the sign matrix without shifting the eigenvalues
      CALL sign_from_eigdecomp(sm_sign=sm_sign, eigvals=spectrum, eigvecs=basis, &
                               N=N, mu_correction=no_shift)

      DEALLOCATE (spectrum)
      DEALLOCATE (basis)
   END SUBROUTINE dense_matrix_sign_direct
    1351             : 
    ! **************************************************************************************************
!> \brief Submatrix method: every rank builds the dense submatrices assigned to it, computes
!>        their sign with the selected dense solver and hands the result columns back to the
!>        dissection object, which finally assembles matrix_sign.
!> \param matrix_sign receives the result via dissection%communicate_results
!> \param matrix input matrix that is dissected into submatrices
!> \param threshold convergence threshold handed to the Newton-Schulz solver
!> \param sign_order order handed to the Newton-Schulz solver (2 if absent)
!> \param submatrix_sign_method ls_scf_submatrix_sign_ns selects Newton-Schulz iterations, the
!>        ls_scf_submatrix_sign_direct* values select the eigendecomposition; others abort
!> \par History
!>       2019.03 created [Robert Schade]
!>       2019.06 impl. submatrix method [Michael Lass]
!> \author Robert Schade, Michael Lass
! **************************************************************************************************
   SUBROUTINE matrix_sign_submatrix(matrix_sign, matrix, threshold, sign_order, submatrix_sign_method)

      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
      INTEGER, INTENT(IN)                                :: submatrix_sign_method

      CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix'

      INTEGER                                            :: handle, i, myrank, order, sm_size, &
                                                            unit_nr
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_distribution_type)                      :: dist
      TYPE(submatrix_dissection_type)                    :: dissection

      CALL timeset(routineN, handle)

      ! print output on all ranks
      logger => cp_get_default_logger()
      unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)

      IF (PRESENT(sign_order)) THEN
         order = sign_order
      ELSE
         order = 2
      END IF

      ! only the rank of this process within the matrix distribution is needed
      CALL dbcsr_get_info(matrix=matrix, distribution=dist)
      CALL dbcsr_distribution_get(dist=dist, mynode=myrank)

      CALL dissection%init(matrix)
      CALL dissection%get_sm_ids_for_rank(myrank, my_sms)

      !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
      !$OMP          PRIVATE(sm, sm_sign, sm_size) &
      !$OMP          SHARED(dissection, myrank, my_sms, order, submatrix_sign_method, threshold, unit_nr)
      !$OMP DO SCHEDULE(GUIDED)
      DO i = 1, SIZE(my_sms)
         IF (unit_nr > 0) &
            WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "processing submatrix", my_sms(i)
         ! sm is sized by generate_submatrix; sm_sign gets the same square shape
         CALL dissection%generate_submatrix(my_sms(i), sm)
         sm_size = SIZE(sm, 1)
         ALLOCATE (sm_sign(sm_size, sm_size))
         SELECT CASE (submatrix_sign_method)
         CASE (ls_scf_submatrix_sign_ns)
            CALL dense_matrix_sign_Newton_Schulz(sm_sign, sm, my_sms(i), threshold, order)
         CASE (ls_scf_submatrix_sign_direct, ls_scf_submatrix_sign_direct_muadj, ls_scf_submatrix_sign_direct_muadj_lowmem)
            CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
         CASE DEFAULT
            CPABORT("Unknown submatrix sign method.")
         END SELECT
         CALL dissection%copy_resultcol(my_sms(i), sm_sign)
         DEALLOCATE (sm, sm_sign)
      END DO
      !$OMP END DO
      !$OMP END PARALLEL

      CALL dissection%communicate_results(matrix_sign)
      CALL dissection%final

      CALL timestop(handle)

   END SUBROUTINE matrix_sign_submatrix
    1428             : 
    1429             : ! **************************************************************************************************
    1430             : !> \brief Submatrix method with internal adjustment of chemical potential
    1431             : !> \param matrix_sign ...
    1432             : !> \param matrix ...
    1433             : !> \param mu ...
    1434             : !> \param nelectron ...
    1435             : !> \param threshold ...
    1436             : !> \param variant ...
    1437             : !> \par History
    1438             : !>       2020.05 Created [Michael Lass]
    1439             : !> \author Robert Schade, Michael Lass
    1440             : ! **************************************************************************************************
    1441           4 :    SUBROUTINE matrix_sign_submatrix_mu_adjust(matrix_sign, matrix, mu, nelectron, threshold, variant)
    1442             : 
    1443             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1444             :       REAL(KIND=dp), INTENT(INOUT)                       :: mu
    1445             :       INTEGER, INTENT(IN)                                :: nelectron
    1446             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1447             :       INTEGER, INTENT(IN)                                :: variant
    1448             : 
    1449             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix_mu_adjust'
    1450             :       REAL(KIND=dp), PARAMETER                           :: initial_increment = 0.01_dp
    1451             : 
    1452             :       INTEGER                                            :: group_handle, handle, i, j, myrank, &
    1453             :                                                             nblkcols, sm_firstcol, sm_lastcol, &
    1454             :                                                             sm_size, unit_nr
    1455           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1456             :       LOGICAL                                            :: has_mu_high, has_mu_low
    1457             :       REAL(KIND=dp)                                      :: increment, mu_high, mu_low, new_mu, trace
    1458           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign, tmp
    1459             :       TYPE(cp_logger_type), POINTER                      :: logger
    1460             :       TYPE(dbcsr_distribution_type)                      :: dist
    1461           4 :       TYPE(eigbuf), ALLOCATABLE, DIMENSION(:)            :: eigbufs
    1462             :       TYPE(mp_comm_type)                                 :: group
    1463           4 :       TYPE(submatrix_dissection_type)                    :: dissection
    1464             : 
    1465           4 :       CALL timeset(routineN, handle)
    1466             : 
    1467             :       ! print output on all ranks
    1468           4 :       logger => cp_get_default_logger()
    1469           4 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1470             : 
    1471           4 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist, group=group_handle)
    1472           4 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1473             : 
    1474           4 :       CALL group%set_handle(group_handle)
    1475             : 
    1476           4 :       CALL dissection%init(matrix)
    1477           4 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1478             : 
    1479          12 :       ALLOCATE (eigbufs(SIZE(my_sms)))
    1480             : 
    1481             :       trace = 0.0_dp
    1482             : 
    1483             :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1484             :       !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j, tmp) &
    1485             :       !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, threshold, variant) &
    1486           4 :       !$OMP          REDUCTION(+:trace)
    1487             :       !$OMP DO SCHEDULE(GUIDED)
    1488             :       DO i = 1, SIZE(my_sms)
    1489             :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1490             :          sm_size = SIZE(sm, 1)
    1491             :          WRITE (unit_nr, *) "Rank", myrank, "processing submatrix", my_sms(i), "size", sm_size
    1492             : 
    1493             :          CALL dissection%get_relevant_sm_columns(my_sms(i), sm_firstcol, sm_lastcol)
    1494             : 
    1495             :          IF (variant .EQ. ls_scf_submatrix_sign_direct_muadj) THEN
    1496             :             ! Store all eigenvectors in buffer. We will use it to compute sm_sign at the end.
    1497             :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=eigbufs(i)%eigvecs)
    1498             :          ELSE
    1499             :             ! Only store eigenvectors that are required for mu adjustment.
    1500             :             ! Calculate sm_sign right away in the hope that mu is already correct.
    1501             :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=tmp)
    1502             :             ALLOCATE (eigbufs(i)%eigvecs(sm_firstcol:sm_lastcol, 1:sm_size))
    1503             :             eigbufs(i)%eigvecs(:, :) = tmp(sm_firstcol:sm_lastcol, 1:sm_size)
    1504             : 
    1505             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1506             :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, tmp, sm_size, 0.0_dp)
    1507             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1508             :             DEALLOCATE (sm_sign, tmp)
    1509             :          END IF
    1510             : 
    1511             :          DEALLOCATE (sm)
    1512             :          trace = trace + trace_from_eigdecomp(eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_firstcol, sm_lastcol, 0.0_dp)
    1513             :       END DO
    1514             :       !$OMP END DO
    1515             :       !$OMP END PARALLEL
    1516             : 
    1517           4 :       has_mu_low = .FALSE.
    1518           4 :       has_mu_high = .FALSE.
    1519           4 :       increment = initial_increment
    1520           4 :       new_mu = mu
    1521          72 :       DO i = 1, 30
    1522          72 :          CALL group%sum(trace)
    1523          72 :          IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,1X,F13.9,1X,F15.9)') &
    1524          72 :             "Density matrix:  mu, trace error: ", new_mu, trace - nelectron
    1525          72 :          IF (ABS(trace - nelectron) < 0.5_dp) EXIT
    1526          68 :          IF (trace < nelectron) THEN
    1527           8 :             mu_low = new_mu
    1528           8 :             new_mu = new_mu + increment
    1529           8 :             has_mu_low = .TRUE.
    1530           8 :             increment = increment*2
    1531             :          ELSE
    1532          60 :             mu_high = new_mu
    1533          60 :             new_mu = new_mu - increment
    1534          60 :             has_mu_high = .TRUE.
    1535          60 :             increment = increment*2
    1536             :          END IF
    1537             : 
    1538          68 :          IF (has_mu_low .AND. has_mu_high) THEN
    1539          20 :             new_mu = (mu_low + mu_high)/2
    1540          20 :             IF (ABS(mu_high - mu_low) < threshold) EXIT
    1541             :          END IF
    1542             : 
    1543             :          trace = 0
    1544             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1545             :          !$OMP          PRIVATE(i, sm_sign, tmp, sm_size, sm_firstcol, sm_lastcol) &
    1546             :          !$OMP          SHARED(dissection, my_sms, unit_nr, eigbufs, mu, new_mu, nelectron) &
    1547          72 :          !$OMP          REDUCTION(+:trace)
    1548             :          !$OMP DO SCHEDULE(GUIDED)
    1549             :          DO j = 1, SIZE(my_sms)
    1550             :             sm_size = SIZE(eigbufs(j)%eigvals)
    1551             :             CALL dissection%get_relevant_sm_columns(my_sms(j), sm_firstcol, sm_lastcol)
    1552             :             trace = trace + trace_from_eigdecomp(eigbufs(j)%eigvals, eigbufs(j)%eigvecs, sm_firstcol, sm_lastcol, new_mu - mu)
    1553             :          END DO
    1554             :          !$OMP END DO
    1555             :          !$OMP END PARALLEL
    1556             :       END DO
    1557             : 
    1558             :       ! Finalize sign matrix from eigendecompositions if we kept all eigenvectors
    1559           4 :       IF (variant .EQ. ls_scf_submatrix_sign_direct_muadj) THEN
    1560             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1561             :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1562           2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1563             :          !$OMP DO SCHEDULE(GUIDED)
    1564             :          DO i = 1, SIZE(my_sms)
    1565             :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "finalizing submatrix", my_sms(i)
    1566             :             sm_size = SIZE(eigbufs(i)%eigvals)
    1567             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1568             :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_size, new_mu - mu)
    1569             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1570             :             DEALLOCATE (sm_sign)
    1571             :          END DO
    1572             :          !$OMP END DO
    1573             :          !$OMP END PARALLEL
    1574             :       END IF
    1575             : 
    1576           6 :       DEALLOCATE (eigbufs)
    1577             : 
    1578             :       ! If we only stored parts of the eigenvectors and mu has changed, we need to recompute sm_sign
    1579           4 :       IF ((variant .EQ. ls_scf_submatrix_sign_direct_muadj_lowmem) .AND. (mu .NE. new_mu)) THEN
    1580             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1581             :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1582           2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1583             :          !$OMP DO SCHEDULE(GUIDED)
    1584             :          DO i = 1, SIZE(my_sms)
    1585             :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "reprocessing submatrix", my_sms(i)
    1586             :             CALL dissection%generate_submatrix(my_sms(i), sm)
    1587             :             sm_size = SIZE(sm, 1)
    1588             :             DO j = 1, sm_size
    1589             :                sm(j, j) = sm(j, j) + mu - new_mu
    1590             :             END DO
    1591             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1592             :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1593             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1594             :             DEALLOCATE (sm, sm_sign)
    1595             :          END DO
    1596             :          !$OMP END DO
    1597             :          !$OMP END PARALLEL
    1598             :       END IF
    1599             : 
    1600           4 :       mu = new_mu
    1601             : 
    1602           4 :       CALL dissection%communicate_results(matrix_sign)
    1603           4 :       CALL dissection%final
    1604             : 
    1605           4 :       CALL timestop(handle)
    1606             : 
    1607          12 :    END SUBROUTINE matrix_sign_submatrix_mu_adjust
    1608             : 
    1609             : ! **************************************************************************************************
    1610             : !> \brief compute the sqrt and the inverse sqrt of a matrix via coupled Newton-Schulz iterations (sign function based)
    1611             : !>        the order of the algorithm should be 2..5, 3 or 5 is recommended; order 0 is also accepted (see \param order)
    1612             : !> \param matrix_sqrt on exit the approximate square root of matrix (iterate Yk, rescaled by 1/SQRT(scaling))
    1613             : !> \param matrix_sqrt_inv on exit the approximate inverse square root of matrix (iterate Zk, started from the identity)
    1614             : !> \param matrix input matrix; copied to matrix_sqrt and used as template for the work matrices
    1615             : !> \param threshold filter threshold for the multiplications; the iteration stops once conv**2 < threshold
    1616             : !> \param order order of the update polynomial: 2..5, or 0 (order 2 update with norm-based scaling); other values trigger CPABORT
    1617             : !> \param eps_lanczos threshold of the extremal eigenvalue estimate (arnoldi_extremal) used to scale the matrix
    1618             : !> \param max_iter_lanczos maximum number of iterations of the extremal eigenvalue estimate
    1619             : !> \param symmetrize if present and .FALSE. the final symmetrization 0.5*(M+M^T) is skipped (default .TRUE.)
    1620             : !> \param converged if present, .TRUE. if the convergence criterion was met within 100 iterations, .FALSE. otherwise
    1621             : !> \param iounit output unit; if absent the default logger unit is used on the source rank and no output elsewhere
    1622             : !> \par History
    1623             : !>       2010.10 created [Joost VandeVondele]
    1624             : !> \author Joost VandeVondele
    1625             : ! **************************************************************************************************
    1626       14412 :    SUBROUTINE matrix_sqrt_Newton_Schulz(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1627             :                                         eps_lanczos, max_iter_lanczos, symmetrize, converged, iounit)
    1628             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1629             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1630             :       INTEGER, INTENT(IN)                                :: order
    1631             :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1632             :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1633             :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1634             :       INTEGER, INTENT(IN), OPTIONAL                      :: iounit
    1635             : 
    1636             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sqrt_Newton_Schulz'
    1637             : 
    1638             :       INTEGER                                            :: handle, i, unit_nr
    1639             :       INTEGER(KIND=int_8)                                :: flop1, flop2, flop3, flop4, flop5
    1640             :       LOGICAL                                            :: arnoldi_converged, tsym
    1641             :       REAL(KIND=dp)                                      :: a, b, c, conv, d, frob_matrix, &
    1642             :                                                             frob_matrix_base, gershgorin_norm, &
    1643             :                                                             max_ev, min_ev, oa, ob, oc, &
    1644             :                                                             occ_matrix, od, scaling, t1, t2
    1645             :       TYPE(cp_logger_type), POINTER                      :: logger
    1646             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
    1647             : 
    1648       14412 :       CALL timeset(routineN, handle)
    1649             : 
    1650       14412 :       IF (PRESENT(iounit)) THEN
    1651       13060 :          unit_nr = iounit
    1652             :       ELSE
    1653        1352 :          logger => cp_get_default_logger()
    1654        1352 :          IF (logger%para_env%is_source()) THEN
    1655         676 :             unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1656             :          ELSE
    1657         676 :             unit_nr = -1
    1658             :          END IF
    1659             :       END IF
    1660             : 
    1661       14412 :       IF (PRESENT(converged)) converged = .FALSE.
    1662       14412 :       IF (PRESENT(symmetrize)) THEN
    1663           0 :          tsym = symmetrize
    1664             :       ELSE
    1665             :          tsym = .TRUE.
    1666             :       END IF
    1667             : 
    1668             :       ! for stability symmetry can not be assumed: all work matrices are created without symmetry
    1669       14412 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1670       14412 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1671       14412 :       IF (order .GE. 4) THEN
    1672          20 :          CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1673             :       END IF
    1674             : 
    1675       14412 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1676       14412 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1677       14412 :       CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1678       14412 :       CALL dbcsr_copy(matrix_sqrt, matrix)
    1679             : 
    1680             :       ! choose the scaling factor that brings the matrix into the convergence range
    1681       14412 :       IF (order == 0) THEN
    1682             : 
    1683           0 :          gershgorin_norm = dbcsr_gershgorin_norm(matrix_sqrt)
    1684           0 :          frob_matrix = dbcsr_frobenius_norm(matrix_sqrt)
    1685           0 :          scaling = 1.0_dp/MIN(frob_matrix, gershgorin_norm)
    1686             : 
    1687             :       ELSE
    1688             : 
    1689             :          ! estimate the extremal eigenvalues; the scaling below is derived from them
    1690             :          CALL arnoldi_extremal(matrix_sqrt, max_ev, min_ev, threshold=eps_lanczos, &
    1691       14412 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1692       14412 :          IF (unit_nr > 0) THEN
    1693         676 :             WRITE (unit_nr, *)
    1694         676 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1695         676 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1696         676 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1697             :          END IF
    1698             :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1699             :          ! and adjust the scaling to be on the safe side
    1700       14412 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1701             : 
    1702             :       END IF
    1703             : 
    1704       14412 :       CALL dbcsr_scale(matrix_sqrt, scaling)
    1705       14412 :       CALL dbcsr_filter(matrix_sqrt, threshold)
    1706       14412 :       IF (unit_nr > 0) THEN
    1707         676 :          WRITE (unit_nr, *)
    1708         676 :          WRITE (unit_nr, *) "Order=", order
    1709             :       END IF
    1710             : 
    1711       67928 :       DO i = 1, 100
    1712             : 
    1713       67928 :          t1 = m_walltime()
    1714             : 
    1715             :          ! tmp1 = Zk * Yk - I (frob_matrix_base is the norm of Zk * Yk before the identity is subtracted)
    1716             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    1717       67928 :                              filter_eps=threshold, flop=flop1)
    1718       67928 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1719       67928 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1720             : 
    1721             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
    1722       67928 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    1723             : 
    1724       67928 :          flop4 = 0; flop5 = 0
    1725          36 :          SELECT CASE (order)
    1726             :          CASE (0, 2)
    1727             :             ! update the above to 0.5*(3*I-Zk*Yk)
    1728          36 :             CALL dbcsr_add_on_diag(tmp1, -2.0_dp)
    1729          36 :             CALL dbcsr_scale(tmp1, -0.5_dp)
    1730             :          CASE (3)
    1731             :             ! tmp2 = tmp1 ** 2
    1732             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1733       67832 :                                 filter_eps=threshold, flop=flop4)
    1734             :             ! tmp1 = 1/8 * (8*I-4*tmp1+3*tmp1**2)
    1735       67832 :             CALL dbcsr_add(tmp1, tmp2, -4.0_dp, 3.0_dp)
    1736       67832 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp)
    1737       67832 :             CALL dbcsr_scale(tmp1, 0.125_dp)
    1738             :          CASE (4) ! as expensive as case(5), so little need to use it
    1739             :             ! tmp2 = tmp1 ** 2
    1740             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1741          32 :                                 filter_eps=threshold, flop=flop4)
    1742             :             ! tmp3 = tmp2 * tmp1 = tmp1 ** 3, then tmp1 = 1/16 * (16*I-8*tmp1+6*tmp1**2-5*tmp1**3)
    1743             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp1, 0.0_dp, tmp3, &
    1744          32 :                                 filter_eps=threshold, flop=flop5)
    1745          32 :             CALL dbcsr_scale(tmp1, -8.0_dp)
    1746          32 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp)
    1747          32 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp)
    1748          32 :             CALL dbcsr_add(tmp1, tmp3, 1.0_dp, -5.0_dp)
    1749          32 :             CALL dbcsr_scale(tmp1, 1/16.0_dp)
    1750             :          CASE (5)
    1751             :             ! Knuth's reformulation to evaluate the polynomial of 4th degree in 2 multiplications
    1752             :             ! p = y4+A*y3+B*y2+C*y+D
    1753             :             ! z := y * (y+a); P := (z+y+b) * (z+c) + d.
    1754             :             ! a=(A-1)/2 ; b=B*(a+1)-C-a*(a+1)*(a+1)
    1755             :             ! c=B-b-a*(a+1)
    1756             :             ! d=D-bc
    1757          28 :             oa = -40.0_dp/35.0_dp
    1758          28 :             ob = 48.0_dp/35.0_dp
    1759          28 :             oc = -64.0_dp/35.0_dp
    1760          28 :             od = 128.0_dp/35.0_dp
    1761          28 :             a = (oa - 1)/2
    1762          28 :             b = ob*(a + 1) - oc - a*(a + 1)**2
    1763          28 :             c = ob - b - a*(a + 1)
    1764          28 :             d = od - b*c
    1765             :             ! tmp2 = tmp1 ** 2 + a * tmp1
    1766             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
    1767          28 :                                 filter_eps=threshold, flop=flop4)
    1768          28 :             CALL dbcsr_add(tmp2, tmp1, 1.0_dp, a)
    1769             :             ! tmp3 = tmp2 + tmp1 + b
    1770          28 :             CALL dbcsr_copy(tmp3, tmp2)
    1771          28 :             CALL dbcsr_add(tmp3, tmp1, 1.0_dp, 1.0_dp)
    1772          28 :             CALL dbcsr_add_on_diag(tmp3, b)
    1773             :             ! tmp2 = tmp2 + c
    1774          28 :             CALL dbcsr_add_on_diag(tmp2, c)
    1775             :             ! tmp1 = tmp2 * tmp3
    1776             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
    1777          28 :                                 filter_eps=threshold, flop=flop5)
    1778             :             ! tmp1 = tmp1 + d
    1779          28 :             CALL dbcsr_add_on_diag(tmp1, d)
    1780             :             ! final scale
    1781          28 :             CALL dbcsr_scale(tmp1, 35.0_dp/128.0_dp)
    1782             :          CASE DEFAULT
    1783       67928 :             CPABORT("Illegal order value")
    1784             :          END SELECT
    1785             : 
    1786             :          ! tmp2 = Yk * tmp1 = Y(k+1)
    1787             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt, tmp1, 0.0_dp, tmp2, &
    1788       67928 :                              filter_eps=threshold, flop=flop2)
    1789             :          ! CALL dbcsr_filter(tmp2,threshold)
    1790       67928 :          CALL dbcsr_copy(matrix_sqrt, tmp2)
    1791             : 
    1792             :          ! tmp2 = tmp1 * Zk = Z(k+1)
    1793             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_sqrt_inv, 0.0_dp, tmp2, &
    1794       67928 :                              filter_eps=threshold, flop=flop3)
    1795             :          ! CALL dbcsr_filter(tmp2,threshold)
    1796       67928 :          CALL dbcsr_copy(matrix_sqrt_inv, tmp2)
    1797             : 
    1798       67928 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    1799             : 
    1800             :          ! end of the timed part of this iteration
    1801       67928 :          t2 = m_walltime()
    1802             : 
    1803       67928 :          conv = frob_matrix/frob_matrix_base
    1804             : 
    1805       67928 :          IF (unit_nr > 0) THEN
    1806        3818 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sqrt iter ", i, occ_matrix, &
    1807        3818 :                conv, t2 - t1, &
    1808        7636 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    1809        3818 :             CALL m_flush(unit_nr)
    1810             :          END IF
    1811             : 
    1812       67928 :          IF (abnormal_value(conv)) &
    1813           0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    1814             : 
    1815             :          ! converged once conv < SQRT(threshold), tested here without taking the square root
    1816       67928 :          IF ((conv*conv) < threshold) THEN
    1817       14412 :             IF (PRESENT(converged)) converged = .TRUE.
    1818             :             EXIT
    1819             :          END IF
    1820             : 
    1821             :       END DO
    1822             : 
    1823             :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    1824       14412 :       IF (tsym) THEN
    1825       14412 :          IF (unit_nr > 0) THEN
    1826         676 :             WRITE (unit_nr, '(T6,A20)') "Symmetrizing Results"
    1827             :          END IF
    1828       14412 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    1829       14412 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    1830       14412 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    1831       14412 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    1832             :       END IF
    1833             : 
    1834             :       ! this check is not really needed: the final residual of Z*Y - I is only used for the output below
    1835             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    1836       14412 :                           filter_eps=threshold)
    1837       14412 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1838       14412 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1839       14412 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
    1840       14412 :       occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    1841       14412 :       IF (unit_nr > 0) THEN
    1842         676 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sqrt iter ", i, occ_matrix, &
    1843        1352 :             frob_matrix/frob_matrix_base
    1844         676 :          WRITE (unit_nr, '()')
    1845         676 :          CALL m_flush(unit_nr)
    1846             :       END IF
    1847             : 
    1848             :       ! scale to proper end results, i.e. undo the initial scaling of the input matrix
    1849       14412 :       CALL dbcsr_scale(matrix_sqrt, 1/SQRT(scaling))
    1850       14412 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    1851             : 
    1852       14412 :       CALL dbcsr_release(tmp1)
    1853       14412 :       CALL dbcsr_release(tmp2)
    1854       14412 :       IF (order .GE. 4) THEN
    1855          20 :          CALL dbcsr_release(tmp3)
    1856             :       END IF
    1857             : 
    1858       14412 :       CALL timestop(handle)
    1859             : 
    1860       14412 :    END SUBROUTINE matrix_sqrt_Newton_Schulz
    1861             : 
    1862             : ! **************************************************************************************************
    1863             : !> \brief compute the sqrt of a matrix via the general algorithm for the p-th root of Richters et al.
    1864             : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1865             : !> \param matrix_sqrt ...
    1866             : !> \param matrix_sqrt_inv ...
    1867             : !> \param matrix ...
    1868             : !> \param threshold ...
    1869             : !> \param order ...
    1870             : !> \param eps_lanczos ...
    1871             : !> \param max_iter_lanczos ...
    1872             : !> \param symmetrize ...
    1873             : !> \param converged ...
    1874             : !> \par History
    1875             : !>       2019.04 created [Robert Schade]
    1876             : !> \author Robert Schade
    1877             : ! **************************************************************************************************
    1878          48 :    SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1879             :                                 eps_lanczos, max_iter_lanczos, symmetrize, converged)
    1880             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1881             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1882             :       INTEGER, INTENT(IN)                                :: order
    1883             :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1884             :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1885             :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1886             : 
    1887             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sqrt_proot'
    1888             : 
    1889             :       INTEGER                                            :: choose, handle, i, ii, j, unit_nr
    1890             :       INTEGER(KIND=int_8)                                :: f, flop1, flop2, flop3, flop4, flop5
    1891             :       LOGICAL                                            :: arnoldi_converged, test, tsym
    1892             :       REAL(KIND=dp)                                      :: conv, frob_matrix, frob_matrix_base, &
    1893             :                                                             max_ev, min_ev, occ_matrix, scaling, &
    1894             :                                                             t1, t2
    1895             :       TYPE(cp_logger_type), POINTER                      :: logger
    1896             :       TYPE(dbcsr_type)                                   :: BK2A, matrixS, Rmat, tmp1, tmp2, tmp3
    1897             : 
    1898          16 :       CALL cite_reference(Richters2018)
    1899             : 
    1900          16 :       test = .FALSE.
    1901             : 
    1902          16 :       CALL timeset(routineN, handle)
    1903             : 
    1904          16 :       logger => cp_get_default_logger()
    1905          16 :       IF (logger%para_env%is_source()) THEN
    1906           8 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1907             :       ELSE
    1908           8 :          unit_nr = -1
    1909             :       END IF
    1910             : 
    1911          16 :       IF (PRESENT(converged)) converged = .FALSE.
    1912          16 :       IF (PRESENT(symmetrize)) THEN
    1913          16 :          tsym = symmetrize
    1914             :       ELSE
    1915             :          tsym = .TRUE.
    1916             :       END IF
    1917             : 
    1918             :       ! for stability symmetry can not be assumed
    1919          16 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1920          16 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1921          16 :       CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1922          16 :       CALL dbcsr_create(Rmat, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1923          16 :       CALL dbcsr_create(matrixS, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1924             : 
    1925          16 :       CALL dbcsr_copy(matrixS, matrix)
    1926             :       IF (1 .EQ. 1) THEN
    1927             :          ! scale the matrix to get into the convergence range
    1928             :          CALL arnoldi_extremal(matrixS, max_ev, min_ev, threshold=eps_lanczos, &
    1929          16 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1930          16 :          IF (unit_nr > 0) THEN
    1931           8 :             WRITE (unit_nr, *)
    1932           8 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1933           8 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1934           8 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1935             :          END IF
    1936             :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1937             :          ! and adjust the scaling to be on the safe side
    1938          16 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1939          16 :          CALL dbcsr_scale(matrixS, scaling)
    1940          16 :          CALL dbcsr_filter(matrixS, threshold)
    1941             :       ELSE
    1942             :          scaling = 1.0_dp
    1943             :       END IF
    1944             : 
    1945          16 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1946          16 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1947             :       !CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1948             : 
    1949          16 :       IF (unit_nr > 0) THEN
    1950           8 :          WRITE (unit_nr, *)
    1951           8 :          WRITE (unit_nr, *) "Order=", order
    1952             :       END IF
    1953             : 
    1954          86 :       DO i = 1, 100
    1955             : 
    1956          86 :          t1 = m_walltime()
    1957             :          IF (1 .EQ. 1) THEN
    1958             :             !build R=1-A B_K^2
    1959             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp1, &
    1960          86 :                                 filter_eps=threshold, flop=flop1)
    1961             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrixS, tmp1, 0.0_dp, Rmat, &
    1962          86 :                                 filter_eps=threshold, flop=flop2)
    1963          86 :             CALL dbcsr_scale(Rmat, -1.0_dp)
    1964          86 :             CALL dbcsr_add_on_diag(Rmat, 1.0_dp)
    1965             : 
    1966          86 :             flop4 = 0; flop5 = 0
    1967          86 :             CALL dbcsr_set(tmp1, 0.0_dp)
    1968          86 :             CALL dbcsr_add_on_diag(tmp1, 2.0_dp)
    1969             : 
    1970          86 :             flop3 = 0
    1971             : 
    1972         274 :             DO j = 2, order
    1973         188 :                IF (j .EQ. 2) THEN
    1974          86 :                   CALL dbcsr_copy(tmp2, Rmat)
    1975             :                ELSE
    1976             :                   f = 0
    1977         102 :                   CALL dbcsr_copy(tmp3, tmp2)
    1978             :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, Rmat, 0.0_dp, tmp2, &
    1979         102 :                                       filter_eps=threshold, flop=f)
    1980         102 :                   flop3 = flop3 + f
    1981             :                END IF
    1982         274 :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
    1983             :             END DO
    1984             :          ELSE
    1985             :             CALL dbcsr_create(BK2A, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1986             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrixS, 0.0_dp, tmp3, &
    1987             :                                 filter_eps=threshold, flop=flop1)
    1988             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, tmp3, 0.0_dp, BK2A, &
    1989             :                                 filter_eps=threshold, flop=flop2)
    1990             :             CALL dbcsr_copy(Rmat, BK2A)
    1991             :             CALL dbcsr_add_on_diag(Rmat, -1.0_dp)
    1992             : 
    1993             :             CALL dbcsr_set(tmp1, 0.0_dp)
    1994             :             CALL dbcsr_add_on_diag(tmp1, 1.0_dp)
    1995             : 
    1996             :             CALL dbcsr_set(tmp2, 0.0_dp)
    1997             :             CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
    1998             : 
    1999             :             flop3 = 0
    2000             :             DO j = 1, order
    2001             :                !choose=factorial(order)/(factorial(j)*factorial(order-j))
    2002             :                choose = PRODUCT((/(ii, ii=1, order)/))/(PRODUCT((/(ii, ii=1, j)/))*PRODUCT((/(ii, ii=1, order - j)/)))
    2003             :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, -1.0_dp*(-1)**j*choose)
    2004             :                IF (j .LT. order) THEN
    2005             :                   f = 0
    2006             :                   CALL dbcsr_copy(tmp3, tmp2)
    2007             :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, BK2A, 0.0_dp, tmp2, &
    2008             :                                       filter_eps=threshold, flop=f)
    2009             :                   flop3 = flop3 + f
    2010             :                END IF
    2011             :             END DO
    2012             :             CALL dbcsr_release(BK2A)
    2013             :          END IF
    2014             : 
    2015          86 :          CALL dbcsr_copy(tmp3, matrix_sqrt_inv)
    2016             :          CALL dbcsr_multiply("N", "N", 0.5_dp, tmp3, tmp1, 0.0_dp, matrix_sqrt_inv, &
    2017          86 :                              filter_eps=threshold, flop=flop4)
    2018             : 
    2019          86 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2020             : 
    2021             :          ! done iterating
    2022          86 :          t2 = m_walltime()
    2023             : 
    2024          86 :          conv = dbcsr_frobenius_norm(Rmat)
    2025             : 
    2026          86 :          IF (unit_nr > 0) THEN
    2027          43 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "PROOT sqrt iter ", i, occ_matrix, &
    2028          43 :                conv, t2 - t1, &
    2029          86 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    2030          43 :             CALL m_flush(unit_nr)
    2031             :          END IF
    2032             : 
    2033          86 :          IF (abnormal_value(conv)) &
    2034           0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    2035             : 
    2036             :          ! conv < SQRT(threshold)
    2037          86 :          IF ((conv*conv) < threshold) THEN
    2038          16 :             IF (PRESENT(converged)) converged = .TRUE.
    2039             :             EXIT
    2040             :          END IF
    2041             : 
    2042             :       END DO
    2043             : 
    2044             :       ! scale to proper end results
    2045          16 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    2046             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix, 0.0_dp, matrix_sqrt, &
    2047          16 :                           filter_eps=threshold, flop=flop5)
    2048             : 
    2049             :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    2050          16 :       IF (tsym) THEN
    2051           8 :          IF (unit_nr > 0) THEN
    2052           4 :             WRITE (unit_nr, '(A20)') "SYMMETRIZING RESULTS"
    2053             :          END IF
    2054           8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    2055           8 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    2056           8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    2057           8 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    2058             :       END IF
    2059             : 
    2060             :       ! this check is not really needed
    2061             :       IF (test) THEN
    2062             :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    2063             :                              filter_eps=threshold)
    2064             :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2065             :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2066             :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2067             :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2068             :          IF (unit_nr > 0) THEN
    2069             :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{1/2}-Eins error ", i, occ_matrix, &
    2070             :                frob_matrix/frob_matrix_base
    2071             :             WRITE (unit_nr, '()')
    2072             :             CALL m_flush(unit_nr)
    2073             :          END IF
    2074             : 
    2075             :          ! this check is not really needed
    2076             :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp2, &
    2077             :                              filter_eps=threshold)
    2078             :          CALL dbcsr_multiply("N", "N", +1.0_dp, tmp2, matrix, 0.0_dp, tmp1, &
    2079             :                              filter_eps=threshold)
    2080             :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2081             :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2082             :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2083             :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2084             :          IF (unit_nr > 0) THEN
    2085             :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{-1/2} S-Eins error ", i, occ_matrix, &
    2086             :                frob_matrix/frob_matrix_base
    2087             :             WRITE (unit_nr, '()')
    2088             :             CALL m_flush(unit_nr)
    2089             :          END IF
    2090             :       END IF
    2091             : 
    2092          16 :       CALL dbcsr_release(tmp1)
    2093          16 :       CALL dbcsr_release(tmp2)
    2094          16 :       CALL dbcsr_release(tmp3)
    2095          16 :       CALL dbcsr_release(Rmat)
    2096          16 :       CALL dbcsr_release(matrixS)
    2097             : 
    2098          16 :       CALL timestop(handle)
    2099          16 :    END SUBROUTINE matrix_sqrt_proot
    2100             : 
    2101             : ! **************************************************************************************************
    2102             : !> \brief Computes matrix_exp = omega*exp(alpha*matrix) via scaling, a truncated Taylor series and repeated squaring
    2103             : !> \param matrix_exp on exit holds omega*exp(alpha*matrix); it is overwritten by dbcsr_copy and then scaled
    2104             : !> \param matrix input matrix; also used as the template for every work matrix created here
    2105             : !> \param omega scalar prefactor applied to the final result
    2106             : !> \param alpha scalar that multiplies matrix inside the exponential
    2107             : !> \param threshold passed as filter_eps to every dbcsr_multiply in this routine
    2108             : ! **************************************************************************************************
    2109        1146 :    SUBROUTINE matrix_exponential(matrix_exp, matrix, omega, alpha, threshold)
    2110             :       ! compute matrix_exp=omega*exp(alpha*matrix)
    2111             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_exp, matrix
    2112             :       REAL(KIND=dp), INTENT(IN)                          :: omega, alpha, threshold
    2113             : 
    2114             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_exponential'
    2115             :       REAL(dp), PARAMETER                                :: one = 1.0_dp, toll = 1.E-17_dp, &
    2116             :                                                             zero = 0.0_dp
    2117             : 
    2118             :       INTEGER                                            :: handle, i, k, unit_nr
    2119             :       REAL(dp)                                           :: factorial, norm_C, norm_D, norm_scalar
    2120             :       TYPE(cp_logger_type), POINTER                      :: logger
    2121             :       TYPE(dbcsr_type)                                   :: B, B_square, C, D, D_product
    2122             : 
    2123        1146 :       CALL timeset(routineN, handle)
    2124             : 
    2125        1146 :       logger => cp_get_default_logger()
    2126        1146 :       IF (logger%para_env%is_source()) THEN
    2127        1058 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2128             :       ELSE
    2129             :          unit_nr = -1
    2130             :       END IF
    2131             : 
    2132             :       ! Frobenius norm of alpha*matrix; it determines how many times the argument is halved below
    2133        1146 :       norm_scalar = ABS(alpha)*dbcsr_frobenius_norm(matrix)
    2134             : 
    2135             :       ! k = number of halvings (and of squarings at the end): smallest k >= 1 with norm_scalar/2**k <= 1
    2136        1146 :       k = 1
    2137        1008 :       DO
    2138        2154 :          IF ((norm_scalar/2.0_dp**k) <= one) EXIT
    2139        1008 :          k = k + 1
    2140             :       END DO
    2141             : 
    2142             :       ! C = (alpha/2**k)*matrix is the scaled argument; D starts as a copy of C and later holds its powers
    2143        1146 :       CALL dbcsr_create(C, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2144        1146 :       CALL dbcsr_copy(C, matrix)
    2145        1146 :       CALL dbcsr_scale(C, alpha_scalar=alpha/2.0_dp**k)
    2146             : 
    2147        1146 :       CALL dbcsr_create(D, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2148        1146 :       CALL dbcsr_copy(D, C)
    2149             : 
    2150             :       !   write(*,*)
    2151             :       !   write(*,*)
    2152             :       !   CALL dbcsr_print(D, nodata=.FALSE., matlab_format=.TRUE., variable_name="D", unit_nr=6)
    2153             : 
    2154             :       ! set the B matrix as B=Identity+D (zeroth and first order terms of the series)
    2155        1146 :       CALL dbcsr_create(B, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2156        1146 :       CALL dbcsr_copy(B, D)
    2157        1146 :       CALL dbcsr_add_on_diag(B, alpha_scalar=one)
    2158             : 
    2159             :       !   CALL dbcsr_print(B, nodata=.FALSE., matlab_format=.TRUE., variable_name="B", unit_nr=6)
    2160             : 
    2161             :       ! Series cut-off: toll times the Frobenius norm of the input matrix (note: of matrix, not of the scaled C)
    2162        1146 :       norm_C = toll*dbcsr_frobenius_norm(matrix)
    2163             : 
    2164             :       ! Truncated Taylor series: after each pass D = C**i; the loop has no iteration cap and ends only on the norm test
    2165        1146 :       CALL dbcsr_create(D_product, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2166        1146 :       i = 1
    2167             :       DO
    2168       12676 :          i = i + 1
    2169             :          ! compute D_product=D*C
    2170             :          CALL dbcsr_multiply("N", "N", one, D, C, &
    2171       12676 :                              zero, D_product, filter_eps=threshold)
    2172             : 
    2173             :          ! copy D_product in D
    2174       12676 :          CALL dbcsr_copy(D, D_product)
    2175             : 
    2176             :          ! B = B + ifac(i)*D_product (ifac(i) is presumably 1/i! - TODO confirm against its definition)
    2177       12676 :          factorial = ifac(i)
    2178       12676 :          CALL dbcsr_add(B, D_product, one, factorial)
    2179             : 
    2180             :          ! converged once the Frobenius norm of the term just added, ifac(i)*||D||, drops below norm_C
    2181       12676 :          norm_D = factorial*dbcsr_frobenius_norm(D)
    2182       12676 :          IF (norm_D < norm_C) EXIT
    2183             :       END DO
    2184             : 
    2185             :       ! undo the scaling: square B k times (k >= 1, so B_square is always assigned)
    2186        1146 :       CALL dbcsr_create(B_square, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2187        3300 :       DO i = 1, k
    2188             :          !compute B_square=B*B
    2189             :          CALL dbcsr_multiply("N", "N", one, B, B, &
    2190        2154 :                              zero, B_square, filter_eps=threshold)
    2191             :          ! copy Bsquare in B to iterate
    2192        3300 :          CALL dbcsr_copy(B, B_square)
    2193             :       END DO
    2194             : 
    2195             :       ! copy B_square into matrix_exp
    2196        1146 :       CALL dbcsr_copy(matrix_exp, B_square)
    2197             : 
    2198             :       ! scale matrix_exp by omega, matrix_exp=omega*B_square
    2199        1146 :       CALL dbcsr_scale(matrix_exp, alpha_scalar=omega)
    2200             :       ! write(*,*) alpha,omega
    2201             : 
    2202        1146 :       CALL dbcsr_release(B)
    2203        1146 :       CALL dbcsr_release(C)
    2204        1146 :       CALL dbcsr_release(D)
    2205        1146 :       CALL dbcsr_release(D_product)
    2206        1146 :       CALL dbcsr_release(B_square)
    2207             : 
    2208        1146 :       CALL timestop(handle)
    2209             : 
    2210        1146 :    END SUBROUTINE matrix_exponential
    2211             : 
    2212             : ! **************************************************************************************************
    2213             : !> \brief McWeeny purification P <- 3PP - 2PPP of a matrix in the orthonormal basis
    2214             : !> \param matrix_p Matrices to purify, processed one after another (need to be almost idempotent already)
    2215             : !> \param threshold Threshold used as filter_eps and, multiplied by the trace, as convergence criterion
    2216             : !> \param max_steps Max number of iterations per matrix
    2217             : !> \par History
    2218             : !>       2013.01 created [Florian Schiffmann]
    2219             : !>       2014.07 slightly refactored [Ole Schuett]
    2220             : !> \author Florian Schiffmann
    2221             : ! **************************************************************************************************
    2222         234 :    SUBROUTINE purify_mcweeny_orth(matrix_p, threshold, max_steps)
    2223             :       TYPE(dbcsr_type), DIMENSION(:)                     :: matrix_p
    2224             :       REAL(KIND=dp)                                      :: threshold
    2225             :       INTEGER                                            :: max_steps
    2226             : 
    2227             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_orth'
    2228             : 
    2229             :       INTEGER                                            :: handle, i, ispin, unit_nr
    2230             :       REAL(KIND=dp)                                      :: frob_norm, trace
    2231             :       TYPE(cp_logger_type), POINTER                      :: logger
    2232             :       TYPE(dbcsr_type)                                   :: matrix_pp, matrix_tmp
    2233             : 
    2234         234 :       CALL timeset(routineN, handle)
    2235         234 :       logger => cp_get_default_logger()
    2236         234 :       IF (logger%para_env%is_source()) THEN
    2237         117 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2238             :       ELSE
    2239         117 :          unit_nr = -1
    2240             :       END IF
    2241             : 
    2242         234 :       CALL dbcsr_create(matrix_pp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2243         234 :       CALL dbcsr_create(matrix_tmp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2244         234 :       CALL dbcsr_trace(matrix_p(1), trace) ! NOTE(review): trace of matrix_p(1) is reused in the exit test of every ispin - confirm intended
    2245             : 
    2246         476 :       DO ispin = 1, SIZE(matrix_p)
    2247         476 :          DO i = 1, max_steps
    2248             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_p(ispin), &
    2249         242 :                                 0.0_dp, matrix_pp, filter_eps=threshold)
    2250             : 
    2251             :             ! measure the deviation from idempotency of the current P (before it is updated below)
    2252         242 :             CALL dbcsr_copy(matrix_tmp, matrix_pp)
    2253         242 :             CALL dbcsr_add(matrix_tmp, matrix_p(ispin), 1.0_dp, -1.0_dp)
    2254         242 :             frob_norm = dbcsr_frobenius_norm(matrix_tmp) ! tmp = PP - P
    2255         242 :             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,f16.8)') "McWeeny: Deviation of idempotency", frob_norm
    2256         242 :             IF (unit_nr > 0) CALL m_flush(unit_nr)
    2257             : 
    2258             :             ! construct new P: tmp starts as PP and the multiply forms -2*PP*P + 3*tmp
    2259         242 :             CALL dbcsr_copy(matrix_tmp, matrix_pp)
    2260             :             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_pp, matrix_p(ispin), &
    2261         242 :                                 3.0_dp, matrix_tmp, filter_eps=threshold)
    2262         242 :             CALL dbcsr_copy(matrix_p(ispin), matrix_tmp) ! tmp = 3PP - 2PPP
    2263             : 
    2264             :             ! frob_norm < SQRT(trace*threshold); P has already been updated once more when this exits
    2265         242 :             IF (frob_norm*frob_norm < trace*threshold) EXIT
    2266             :          END DO
    2267             :       END DO
    2268             : 
    2269         234 :       CALL dbcsr_release(matrix_pp)
    2270         234 :       CALL dbcsr_release(matrix_tmp)
    2271         234 :       CALL timestop(handle)
    2272         234 :    END SUBROUTINE purify_mcweeny_orth
    2273             : 
    2274             : ! **************************************************************************************************
    2275             : !> \brief McWeeny purification P <- 3PSP - 2PSPSP of a matrix in the non-orthonormal basis
    2276             : !> \param matrix_p Matrices to purify, processed one after another (need to be almost idempotent already)
    2277             : !> \param matrix_s Overlap-Matrix
    2278             : !> \param threshold Threshold used as filter_eps and, multiplied by trace(PS), as convergence criterion
    2279             : !> \param max_steps Max number of iterations per matrix
    2280             : !> \par History
    2281             : !>       2013.01 created [Florian Schiffmann]
    2282             : !>       2014.07 slightly refactored [Ole Schuett]
    2283             : !> \author Florian Schiffmann
    2284             : ! **************************************************************************************************
    2285         184 :    SUBROUTINE purify_mcweeny_nonorth(matrix_p, matrix_s, threshold, max_steps)
    2286             :       TYPE(dbcsr_type), DIMENSION(:)                     :: matrix_p
    2287             :       TYPE(dbcsr_type)                                   :: matrix_s
    2288             :       REAL(KIND=dp)                                      :: threshold
    2289             :       INTEGER                                            :: max_steps
    2290             : 
    2291             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_nonorth'
    2292             : 
    2293             :       INTEGER                                            :: handle, i, ispin, unit_nr
    2294             :       REAL(KIND=dp)                                      :: frob_norm, trace
    2295             :       TYPE(cp_logger_type), POINTER                      :: logger
    2296             :       TYPE(dbcsr_type)                                   :: matrix_ps, matrix_psp, matrix_test
    2297             : 
    2298         184 :       CALL timeset(routineN, handle)
    2299             : 
    2300         184 :       logger => cp_get_default_logger()
    2301         184 :       IF (logger%para_env%is_source()) THEN
    2302          92 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2303             :       ELSE
    2304          92 :          unit_nr = -1
    2305             :       END IF
    2306             : 
    2307         184 :       CALL dbcsr_create(matrix_ps, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2308         184 :       CALL dbcsr_create(matrix_psp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2309         184 :       CALL dbcsr_create(matrix_test, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2310             : 
    2311         368 :       DO ispin = 1, SIZE(matrix_p)
    2312         380 :          DO i = 1, max_steps
    2313             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_s, &
    2314         196 :                                 0.0_dp, matrix_ps, filter_eps=threshold)
    2315             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_ps, matrix_p(ispin), &
    2316         196 :                                 0.0_dp, matrix_psp, filter_eps=threshold)
    2317         196 :             IF (i == 1) CALL dbcsr_trace(matrix_ps, trace) ! trace(PS) of the incoming P, taken once per ispin
    2318             : 
    2319             :             ! measure the deviation from idempotency of the current P (before it is updated below)
    2320         196 :             CALL dbcsr_copy(matrix_test, matrix_psp)
    2321         196 :             CALL dbcsr_add(matrix_test, matrix_p(ispin), 1.0_dp, -1.0_dp)
    2322         196 :             frob_norm = dbcsr_frobenius_norm(matrix_test) ! test = PSP - P
    2323         196 :             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,2f16.8)') "McWeeny: Deviation of idempotency", frob_norm
    2324         196 :             IF (unit_nr > 0) CALL m_flush(unit_nr)
    2325             : 
    2326             :             ! construct new P = 3PSP - 2PSPSP: P is first set to PSP, then the multiply forms -2*PS*PSP + 3*P
    2327         196 :             CALL dbcsr_copy(matrix_p(ispin), matrix_psp)
    2328             :             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_ps, matrix_psp, &
    2329         196 :                                 3.0_dp, matrix_p(ispin), filter_eps=threshold)
    2330             : 
    2331             :             ! frob_norm < SQRT(trace*threshold); P has already been updated once more when this exits
    2332         196 :             IF (frob_norm*frob_norm < trace*threshold) EXIT
    2333             :          END DO
    2334             :       END DO
    2335             : 
    2336         184 :       CALL dbcsr_release(matrix_ps)
    2337         184 :       CALL dbcsr_release(matrix_psp)
    2338         184 :       CALL dbcsr_release(matrix_test)
    2339         184 :       CALL timestop(handle)
    2340         184 :    END SUBROUTINE purify_mcweeny_nonorth
    2341             : 
    2342           0 : END MODULE iterate_matrix

Generated by: LCOV version 1.15