LCOV - code coverage report
Current view: top level - src - iterate_matrix.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:422ac0d) Lines: 824 874 94.3 %
Date: 2025-04-02 06:58:30 Functions: 17 19 89.5 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : ! **************************************************************************************************
       8             : !> \brief Routines useful for iterative matrix calculations
       9             : !> \par History
      10             : !>       2010.10 created [Joost VandeVondele]
      11             : !> \author Joost VandeVondele
      12             : ! **************************************************************************************************
      13             : MODULE iterate_matrix
      14             :    USE arnoldi_api,                     ONLY: arnoldi_env_type,&
      15             :                                               arnoldi_extremal
      16             :    USE bibliography,                    ONLY: Richters2018,&
      17             :                                               cite_reference
      18             :    USE cp_dbcsr_api,                    ONLY: &
      19             :         dbcsr_add, dbcsr_copy, dbcsr_create, dbcsr_desymmetrize, dbcsr_distribution_get, &
      20             :         dbcsr_distribution_type, dbcsr_filter, dbcsr_get_info, dbcsr_get_matrix_type, &
      21             :         dbcsr_get_occupation, dbcsr_multiply, dbcsr_p_type, dbcsr_release, dbcsr_scale, dbcsr_set, &
      22             :         dbcsr_transposed, dbcsr_type, dbcsr_type_no_symmetry
      23             :    USE cp_dbcsr_contrib,                ONLY: dbcsr_add_on_diag,&
      24             :                                               dbcsr_frobenius_norm,&
      25             :                                               dbcsr_gershgorin_norm,&
      26             :                                               dbcsr_get_diag,&
      27             :                                               dbcsr_maxabs,&
      28             :                                               dbcsr_set_diag,&
      29             :                                               dbcsr_trace
      30             :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      31             :                                               cp_logger_get_default_unit_nr,&
      32             :                                               cp_logger_type
      33             :    USE input_constants,                 ONLY: ls_scf_submatrix_sign_direct,&
      34             :                                               ls_scf_submatrix_sign_direct_muadj,&
      35             :                                               ls_scf_submatrix_sign_direct_muadj_lowmem,&
      36             :                                               ls_scf_submatrix_sign_ns
      37             :    USE kinds,                           ONLY: dp,&
      38             :                                               int_8
      39             :    USE machine,                         ONLY: m_flush,&
      40             :                                               m_walltime
      41             :    USE mathconstants,                   ONLY: ifac
      42             :    USE mathlib,                         ONLY: abnormal_value
      43             :    USE message_passing,                 ONLY: mp_comm_type
      44             :    USE submatrix_dissection,            ONLY: submatrix_dissection_type
      45             : #include "./base/base_uses.f90"
      46             : 
      47             :    IMPLICIT NONE
      48             : 
      49             :    PRIVATE
      50             : 
      51             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'iterate_matrix'
      52             : 
      53             :    TYPE :: eigbuf
      54             :       REAL(KIND=dp), DIMENSION(:), ALLOCATABLE    :: eigvals
      55             :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: eigvecs
      56             :    END TYPE eigbuf
      57             : 
      58             :    INTERFACE purify_mcweeny
      59             :       MODULE PROCEDURE purify_mcweeny_orth, purify_mcweeny_nonorth
      60             :    END INTERFACE
      61             : 
      62             :    PUBLIC :: invert_Hotelling, matrix_sign_Newton_Schulz, matrix_sqrt_Newton_Schulz, &
      63             :              matrix_sqrt_proot, matrix_sign_proot, matrix_sign_submatrix, matrix_exponential, &
      64             :              matrix_sign_submatrix_mu_adjust, purify_mcweeny, invert_Taylor, determinant
      65             : 
      66             : CONTAINS
      67             : 
      68             : ! *****************************************************************************
      69             : !> \brief Computes the determinant of a symmetric positive definite matrix
      70             : !>        using the trace of the matrix logarithm via Mercator series:
      71             : !>         det(A) = det(S)det(I+X)det(S), where S=diag(sqrt(Aii),..,sqrt(Ann))
      72             : !>         det(I+X) = Exp(Trace(Ln(I+X)))
      73             : !>         Ln(I+X) = X - X^2/2 + X^3/3 - X^4/4 + ..
      74             : !>        The series converges only if the Frobenius norm of X is less than 1.
      75             : !>        If it is more than one we compute (recursevily) the determinant of
      76             : !>        the square root of (I+X).
      77             : !> \param matrix ...
      78             : !> \param det - determinant
      79             : !> \param threshold ...
      80             : !> \par History
      81             : !>       2015.04 created [Rustam Z Khaliullin]
      82             : !> \author Rustam Z. Khaliullin
      83             : ! **************************************************************************************************
      84         132 :    RECURSIVE SUBROUTINE determinant(matrix, det, threshold)
      85             : 
      86             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix
      87             :       REAL(KIND=dp), INTENT(INOUT)                       :: det
      88             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
      89             : 
      90             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'determinant'
      91             : 
      92             :       INTEGER                                            :: handle, i, max_iter_lanczos, nsize, &
      93             :                                                             order_lanczos, sign_iter, unit_nr
      94             :       INTEGER(KIND=int_8)                                :: flop1
      95             :       INTEGER, SAVE                                      :: recursion_depth = 0
      96             :       REAL(KIND=dp)                                      :: det0, eps_lanczos, frobnorm, maxnorm, &
      97             :                                                             occ_matrix, t1, t2, trace
      98         132 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: diagonal
      99             :       TYPE(cp_logger_type), POINTER                      :: logger
     100             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3
     101             : 
     102         132 :       CALL timeset(routineN, handle)
     103             : 
     104             :       ! get a useful output_unit
     105         132 :       logger => cp_get_default_logger()
     106         132 :       IF (logger%para_env%is_source()) THEN
     107          66 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     108             :       ELSE
     109          66 :          unit_nr = -1
     110             :       END IF
     111             : 
     112             :       ! Note: tmp1 and tmp2 have the same matrix type as the
     113             :       ! initial matrix (tmp3 does not have symmetry constraints)
     114             :       ! this might lead to uninteded results with anti-symmetric
     115             :       ! matrices
     116             :       CALL dbcsr_create(tmp1, template=matrix, &
     117         132 :                         matrix_type=dbcsr_type_no_symmetry)
     118             :       CALL dbcsr_create(tmp2, template=matrix, &
     119         132 :                         matrix_type=dbcsr_type_no_symmetry)
     120             :       CALL dbcsr_create(tmp3, template=matrix, &
     121         132 :                         matrix_type=dbcsr_type_no_symmetry)
     122             : 
     123             :       ! compute the product of the diagonal elements
     124             :       BLOCK
     125             :          TYPE(mp_comm_type) :: group
     126         132 :          CALL dbcsr_get_info(matrix, nfullrows_total=nsize, group=group)
     127         396 :          ALLOCATE (diagonal(nsize))
     128         132 :          CALL dbcsr_get_diag(matrix, diagonal)
     129         132 :          CALL group%sum(diagonal)
     130        2308 :          det = PRODUCT(diagonal)
     131             :       END BLOCK
     132             : 
     133             :       ! create diagonal SQRTI matrix
     134        2176 :       diagonal(:) = 1.0_dp/(SQRT(diagonal(:)))
     135             :       !ROLL CALL dbcsr_copy(tmp1,matrix)
     136         132 :       CALL dbcsr_desymmetrize(matrix, tmp1)
     137         132 :       CALL dbcsr_set(tmp1, 0.0_dp)
     138         132 :       CALL dbcsr_set_diag(tmp1, diagonal)
     139         132 :       CALL dbcsr_filter(tmp1, threshold)
     140         132 :       DEALLOCATE (diagonal)
     141             : 
     142             :       ! normalize the main diagonal, off-diagonal elements are scaled to
     143             :       ! make the norm of the matrix less than 1
     144             :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     145             :                           matrix, &
     146             :                           tmp1, &
     147             :                           0.0_dp, tmp3, &
     148         132 :                           filter_eps=threshold)
     149             :       CALL dbcsr_multiply("N", "N", 1.0_dp, &
     150             :                           tmp1, &
     151             :                           tmp3, &
     152             :                           0.0_dp, tmp2, &
     153         132 :                           filter_eps=threshold)
     154             : 
     155             :       ! subtract the main diagonal to create matrix X
     156         132 :       CALL dbcsr_add_on_diag(tmp2, -1.0_dp)
     157         132 :       frobnorm = dbcsr_frobenius_norm(tmp2)
     158         132 :       IF (unit_nr > 0) THEN
     159          66 :          IF (recursion_depth .EQ. 0) THEN
     160          41 :             WRITE (unit_nr, '()')
     161             :          ELSE
     162             :             WRITE (unit_nr, '(T6,A28,1X,I15)') &
     163          25 :                "Recursive iteration:", recursion_depth
     164             :          END IF
     165             :          WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     166          66 :             "Frobenius norm:", frobnorm
     167          66 :          CALL m_flush(unit_nr)
     168             :       END IF
     169             : 
     170         132 :       IF (frobnorm .GE. 1.0_dp) THEN
     171             : 
     172          50 :          CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
     173             :          ! these controls should be provided as input
     174          50 :          order_lanczos = 3
     175          50 :          eps_lanczos = 1.0E-4_dp
     176          50 :          max_iter_lanczos = 40
     177             :          CALL matrix_sqrt_Newton_Schulz( &
     178             :             tmp3, & ! output sqrt
     179             :             tmp1, & ! output sqrti
     180             :             tmp2, & ! input original
     181             :             threshold=threshold, &
     182             :             order=order_lanczos, &
     183             :             eps_lanczos=eps_lanczos, &
     184          50 :             max_iter_lanczos=max_iter_lanczos)
     185          50 :          recursion_depth = recursion_depth + 1
     186          50 :          CALL determinant(tmp3, det0, threshold)
     187          50 :          recursion_depth = recursion_depth - 1
     188          50 :          det = det*det0*det0
     189             : 
     190             :       ELSE
     191             : 
     192             :          ! create accumulator
     193          82 :          CALL dbcsr_copy(tmp1, tmp2)
     194             :          ! re-create to make use of symmetry
     195             :          !ROLL CALL dbcsr_create(tmp3,template=matrix)
     196             : 
     197          82 :          IF (unit_nr > 0) WRITE (unit_nr, *)
     198             : 
     199             :          ! initialize the sign of the term
     200          82 :          sign_iter = -1
     201        1078 :          DO i = 1, 100
     202             : 
     203        1078 :             t1 = m_walltime()
     204             : 
     205             :             ! multiply X^i by X
     206             :             ! note that the first iteration evaluates X^2
     207             :             ! because the trace of X^1 is zero by construction
     208             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, &
     209             :                                 0.0_dp, tmp3, &
     210             :                                 filter_eps=threshold, &
     211        1078 :                                 flop=flop1)
     212        1078 :             CALL dbcsr_copy(tmp1, tmp3)
     213             : 
     214             :             ! get trace
     215        1078 :             CALL dbcsr_trace(tmp1, trace)
     216        1078 :             trace = trace*sign_iter/(1.0_dp*(i + 1))
     217        1078 :             sign_iter = -sign_iter
     218             : 
     219             :             ! update the determinant
     220        1078 :             det = det*EXP(trace)
     221             : 
     222        1078 :             occ_matrix = dbcsr_get_occupation(tmp1)
     223        1078 :             maxnorm = dbcsr_maxabs(tmp1)
     224             : 
     225        1078 :             t2 = m_walltime()
     226             : 
     227        1078 :             IF (unit_nr > 0) THEN
     228             :                WRITE (unit_nr, '(T6,A,1X,I3,1X,F7.5,F16.10,F10.3,F11.3)') &
     229         539 :                   "Determinant iter", i, occ_matrix, &
     230         539 :                   det, t2 - t1, &
     231        1078 :                   flop1/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     232         539 :                CALL m_flush(unit_nr)
     233             :             END IF
     234             : 
     235             :             ! exit if the trace is close to zero
     236        2156 :             IF (maxnorm < threshold) EXIT
     237             : 
     238             :          END DO ! end iterations
     239             : 
     240          82 :          IF (unit_nr > 0) THEN
     241          41 :             WRITE (unit_nr, '()')
     242          41 :             CALL m_flush(unit_nr)
     243             :          END IF
     244             : 
     245             :       END IF ! decide to do sqrt or not
     246             : 
     247         132 :       IF (unit_nr > 0) THEN
     248          66 :          IF (recursion_depth .EQ. 0) THEN
     249             :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     250          41 :                "Final determinant:", det
     251          41 :             WRITE (unit_nr, '()')
     252             :          ELSE
     253             :             WRITE (unit_nr, '(T6,A28,1X,F15.10)') &
     254          25 :                "Recursive determinant:", det
     255             :          END IF
     256          66 :          CALL m_flush(unit_nr)
     257             :       END IF
     258             : 
     259         132 :       CALL dbcsr_release(tmp1)
     260         132 :       CALL dbcsr_release(tmp2)
     261         132 :       CALL dbcsr_release(tmp3)
     262             : 
     263         132 :       CALL timestop(handle)
     264             : 
     265         132 :    END SUBROUTINE determinant
     266             : 
     267             : ! **************************************************************************************************
     268             : !> \brief invert a symmetric positive definite diagonally dominant matrix
     269             : !> \param matrix_inverse ...
     270             : !> \param matrix ...
     271             : !> \param threshold convergence threshold nased on the max abs
     272             : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     273             : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     274             : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filteres
     275             : !> \param accelerator_order ...
     276             : !> \param max_iter_lanczos ...
     277             : !> \param eps_lanczos ...
     278             : !> \param silent ...
     279             : !> \par History
     280             : !>       2010.10 created [Joost VandeVondele]
     281             : !>       2011.10 guess option added [Rustam Z Khaliullin]
     282             : !> \author Joost VandeVondele
     283             : ! **************************************************************************************************
     284          26 :    SUBROUTINE invert_Taylor(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     285             :                             norm_convergence, filter_eps, accelerator_order, &
     286             :                             max_iter_lanczos, eps_lanczos, silent)
     287             : 
     288             :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     289             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     290             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     291             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     292             :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     293             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     294             :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     295             : 
     296             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Taylor'
     297             : 
     298             :       INTEGER                                            :: accelerator_type, handle, i, &
     299             :                                                             my_max_iter_lanczos, nrows, unit_nr
     300             :       INTEGER(KIND=int_8)                                :: flop2
     301             :       LOGICAL                                            :: converged, use_inv_guess
     302             :       REAL(KIND=dp)                                      :: coeff, convergence, maxnorm_matrix, &
     303             :                                                             my_eps_lanczos, occ_matrix, t1, t2
     304          26 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: p_diagonal
     305             :       TYPE(cp_logger_type), POINTER                      :: logger
     306             :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2, tmp3_sym
     307             : 
     308          26 :       CALL timeset(routineN, handle)
     309             : 
     310          26 :       logger => cp_get_default_logger()
     311          26 :       IF (logger%para_env%is_source()) THEN
     312          13 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     313             :       ELSE
     314          13 :          unit_nr = -1
     315             :       END IF
     316          26 :       IF (PRESENT(silent)) THEN
     317          26 :          IF (silent) unit_nr = -1
     318             :       END IF
     319             : 
     320          26 :       convergence = threshold
     321          26 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     322             : 
     323          26 :       accelerator_type = 0
     324          26 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     325           0 :       IF (accelerator_type .GT. 1) accelerator_type = 1
     326             : 
     327          26 :       use_inv_guess = .FALSE.
     328          26 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     329             : 
     330          26 :       my_max_iter_lanczos = 64
     331          26 :       my_eps_lanczos = 1.0E-3_dp
     332          26 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     333          26 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     334             : 
     335          26 :       CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     336          26 :       CALL dbcsr_create(tmp2, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     337          26 :       CALL dbcsr_create(tmp3_sym, template=matrix_inverse)
     338             : 
     339          26 :       CALL dbcsr_get_info(matrix, nfullrows_total=nrows)
     340          78 :       ALLOCATE (p_diagonal(nrows))
     341             : 
     342             :       ! generate the initial guess
     343          26 :       IF (.NOT. use_inv_guess) THEN
     344             : 
     345          26 :          SELECT CASE (accelerator_type)
     346             :          CASE (0)
     347             :             ! use tmp1 to hold off-diagonal elements
     348          26 :             CALL dbcsr_desymmetrize(matrix, tmp1)
     349         858 :             p_diagonal(:) = 0.0_dp
     350          26 :             CALL dbcsr_set_diag(tmp1, p_diagonal)
     351             :             !CALL dbcsr_print(tmp1)
     352             :             ! invert the main diagonal
     353          26 :             CALL dbcsr_get_diag(matrix, p_diagonal)
     354         858 :             DO i = 1, nrows
     355         858 :                IF (p_diagonal(i) .NE. 0.0_dp) THEN
     356         416 :                   p_diagonal(i) = 1.0_dp/p_diagonal(i)
     357             :                END IF
     358             :             END DO
     359          26 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     360          26 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     361          26 :             CALL dbcsr_set_diag(matrix_inverse, p_diagonal)
     362             :          CASE DEFAULT
     363          26 :             CPABORT("Illegal accelerator order")
     364             :          END SELECT
     365             : 
     366             :       ELSE
     367             : 
     368           0 :          CPABORT("Guess is NYI")
     369             : 
     370             :       END IF
     371             : 
     372             :       CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, &
     373          26 :                           0.0_dp, tmp2, filter_eps=filter_eps)
     374             : 
     375          26 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     376             : 
     377             :       ! scale the approximate inverse to be within the convergence radius
     378          26 :       t1 = m_walltime()
     379             : 
     380             :       ! done with the initial guess, start iterations
     381          26 :       converged = .FALSE.
     382          26 :       CALL dbcsr_desymmetrize(matrix_inverse, tmp1)
     383          26 :       coeff = 1.0_dp
     384         284 :       DO i = 1, 100
     385             : 
     386             :          ! coeff = +/- 1
     387         284 :          coeff = -1.0_dp*coeff
     388             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, &
     389             :                              tmp3_sym, &
     390         284 :                              flop=flop2, filter_eps=filter_eps)
     391             :          !flop=flop2)
     392         284 :          CALL dbcsr_add(matrix_inverse, tmp3_sym, 1.0_dp, coeff)
     393         284 :          CALL dbcsr_release(tmp1)
     394         284 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     395         284 :          CALL dbcsr_desymmetrize(tmp3_sym, tmp1)
     396             : 
     397             :          ! for the convergence check
     398         284 :          maxnorm_matrix = dbcsr_maxabs(tmp3_sym)
     399             : 
     400         284 :          t2 = m_walltime()
     401         284 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     402             : 
     403         284 :          IF (unit_nr > 0) THEN
     404         142 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Taylor iter", i, occ_matrix, &
     405         142 :                maxnorm_matrix, t2 - t1, &
     406         284 :                flop2/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     407         142 :             CALL m_flush(unit_nr)
     408             :          END IF
     409             : 
     410         284 :          IF (maxnorm_matrix < convergence) THEN
     411             :             converged = .TRUE.
     412             :             EXIT
     413             :          END IF
     414             : 
     415         258 :          t1 = m_walltime()
     416             : 
     417             :       END DO
     418             : 
     419             :       !last convergence check
     420             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_inverse, 0.0_dp, tmp1, &
     421          26 :                           filter_eps=filter_eps)
     422          26 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     423             :       !frob_matrix =  dbcsr_frobenius_norm(tmp1)
     424          26 :       maxnorm_matrix = dbcsr_maxabs(tmp1)
     425          26 :       IF (unit_nr > 0) THEN
     426          13 :          WRITE (unit_nr, '(T6,A,E12.5)') "Final Taylor error", maxnorm_matrix
     427          13 :          WRITE (unit_nr, '()')
     428          13 :          CALL m_flush(unit_nr)
     429             :       END IF
     430          26 :       IF (maxnorm_matrix > convergence) THEN
     431           0 :          converged = .FALSE.
     432           0 :          IF (unit_nr > 0) THEN
     433           0 :             WRITE (unit_nr, *) 'Final convergence check failed'
     434             :          END IF
     435             :       END IF
     436             : 
     437          26 :       IF (.NOT. converged) THEN
     438           0 :          CPABORT("Taylor inversion did not converge")
     439             :       END IF
     440             : 
     441          26 :       CALL dbcsr_release(tmp1)
     442          26 :       CALL dbcsr_release(tmp2)
     443          26 :       CALL dbcsr_release(tmp3_sym)
     444             : 
     445          26 :       DEALLOCATE (p_diagonal)
     446             : 
     447          26 :       CALL timestop(handle)
     448             : 
     449          52 :    END SUBROUTINE invert_Taylor
     450             : 
     451             : ! **************************************************************************************************
     452             : !> \brief invert a symmetric positive definite matrix by Hotelling's method
     453             : !>        explicit symmetrization makes this code not suitable for other matrix types
     454             : !>        Currently a bit messy with the options, to to be cleaned soon
     455             : !> \param matrix_inverse ...
     456             : !> \param matrix ...
     457             : !> \param threshold convergence threshold nased on the max abs
     458             : !> \param use_inv_as_guess logical whether input can be used as guess for inverse
     459             : !> \param norm_convergence convergence threshold for the 2-norm, useful for approximate solutions
     460             : !> \param filter_eps filter_eps for matrix multiplications, if not passed nothing is filteres
     461             : !> \param accelerator_order ...
     462             : !> \param max_iter_lanczos ...
     463             : !> \param eps_lanczos ...
     464             : !> \param silent ...
     465             : !> \par History
     466             : !>       2010.10 created [Joost VandeVondele]
     467             : !>       2011.10 guess option added [Rustam Z Khaliullin]
     468             : !> \author Joost VandeVondele
     469             : ! **************************************************************************************************
     470        2032 :    SUBROUTINE invert_Hotelling(matrix_inverse, matrix, threshold, use_inv_as_guess, &
     471             :                                norm_convergence, filter_eps, accelerator_order, &
     472             :                                max_iter_lanczos, eps_lanczos, silent)
     473             : 
     474             :       TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: matrix_inverse, matrix
     475             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     476             :       LOGICAL, INTENT(IN), OPTIONAL                      :: use_inv_as_guess
     477             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: norm_convergence, filter_eps
     478             :       INTEGER, INTENT(IN), OPTIONAL                      :: accelerator_order, max_iter_lanczos
     479             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: eps_lanczos
     480             :       LOGICAL, INTENT(IN), OPTIONAL                      :: silent
     481             : 
     482             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'invert_Hotelling'
     483             : 
     484             :       INTEGER                                            :: accelerator_type, handle, i, &
     485             :                                                             my_max_iter_lanczos, unit_nr
     486             :       INTEGER(KIND=int_8)                                :: flop1, flop2
     487             :       LOGICAL                                            :: arnoldi_converged, converged, &
     488             :                                                             use_inv_guess
     489             :       REAL(KIND=dp) :: convergence, frob_matrix, gershgorin_norm, max_ev, maxnorm_matrix, min_ev, &
     490             :          my_eps_lanczos, my_filter_eps, occ_matrix, scalingf, t1, t2
     491             :       TYPE(cp_logger_type), POINTER                      :: logger
     492             :       TYPE(dbcsr_type), TARGET                           :: tmp1, tmp2
     493             : 
     494             :       !TYPE(arnoldi_env_type)                            :: arnoldi_env
     495             :       !TYPE(dbcsr_p_type), DIMENSION(1)                   :: mymat
     496             : 
     497        2032 :       CALL timeset(routineN, handle)
     498             : 
     499        2032 :       logger => cp_get_default_logger()
     500        2032 :       IF (logger%para_env%is_source()) THEN
     501        1016 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     502             :       ELSE
     503        1016 :          unit_nr = -1
     504             :       END IF
     505        2032 :       IF (PRESENT(silent)) THEN
     506        2014 :          IF (silent) unit_nr = -1
     507             :       END IF
     508             : 
     509        2032 :       convergence = threshold
     510        2032 :       IF (PRESENT(norm_convergence)) convergence = norm_convergence
     511             : 
     512        2032 :       accelerator_type = 1
     513        2032 :       IF (PRESENT(accelerator_order)) accelerator_type = accelerator_order
     514        1436 :       IF (accelerator_type .GT. 1) accelerator_type = 1
     515             : 
     516        2032 :       use_inv_guess = .FALSE.
     517        2032 :       IF (PRESENT(use_inv_as_guess)) use_inv_guess = use_inv_as_guess
     518             : 
     519        2032 :       my_max_iter_lanczos = 64
     520        2032 :       my_eps_lanczos = 1.0E-3_dp
     521        2032 :       IF (PRESENT(max_iter_lanczos)) my_max_iter_lanczos = max_iter_lanczos
     522        2032 :       IF (PRESENT(eps_lanczos)) my_eps_lanczos = eps_lanczos
     523             : 
     524        2032 :       my_filter_eps = threshold
     525        2032 :       IF (PRESENT(filter_eps)) my_filter_eps = filter_eps
     526             : 
     527             :       ! generate the initial guess
     528        2032 :       IF (.NOT. use_inv_guess) THEN
     529             : 
     530           0 :          SELECT CASE (accelerator_type)
     531             :          CASE (0)
     532           0 :             gershgorin_norm = dbcsr_gershgorin_norm(matrix)
     533           0 :             frob_matrix = dbcsr_frobenius_norm(matrix)
     534           0 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     535           0 :             CALL dbcsr_add_on_diag(matrix_inverse, 1/MIN(gershgorin_norm, frob_matrix))
     536             :          CASE (1)
     537             :             ! initialize matrix to unity and use arnoldi (below) to scale it into the convergence range
     538        1558 :             CALL dbcsr_set(matrix_inverse, 0.0_dp)
     539        1558 :             CALL dbcsr_add_on_diag(matrix_inverse, 1.0_dp)
     540             :          CASE DEFAULT
     541        1558 :             CPABORT("Illegal accelerator order")
     542             :          END SELECT
     543             : 
     544             :          ! everything commutes, therefore our all products will be symmetric
     545        1558 :          CALL dbcsr_create(tmp1, template=matrix_inverse)
     546             : 
     547             :       ELSE
     548             : 
     549             :          ! It is unlikely that our guess will commute with the matrix, therefore the first product will
     550             :          ! be non symmetric
     551         474 :          CALL dbcsr_create(tmp1, template=matrix_inverse, matrix_type=dbcsr_type_no_symmetry)
     552             : 
     553             :       END IF
     554             : 
     555        2032 :       CALL dbcsr_create(tmp2, template=matrix_inverse)
     556             : 
     557        2032 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     558             : 
     559             :       ! scale the approximate inverse to be within the convergence radius
     560        2032 :       t1 = m_walltime()
     561             : 
     562             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     563        2032 :                           0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     564             : 
     565        2032 :       IF (accelerator_type == 1) THEN
     566             : 
     567             :          ! scale the matrix to get into the convergence range
     568             :          CALL arnoldi_extremal(tmp1, max_eV, min_eV, threshold=my_eps_lanczos, &
     569        2032 :                                max_iter=my_max_iter_lanczos, converged=arnoldi_converged)
     570             :          !mymat(1)%matrix => tmp1
     571             :          !CALL setup_arnoldi_env(arnoldi_env, mymat, max_iter=30, threshold=1.0E-3_dp, selection_crit=1, &
     572             :          !                        nval_request=2, nrestarts=2, generalized_ev=.FALSE., iram=.TRUE.)
     573             :          !CALL arnoldi_ev(mymat, arnoldi_env)
     574             :          !max_eV = REAL(get_selected_ritz_val(arnoldi_env, 2), dp)
     575             :          !min_eV = REAL(get_selected_ritz_val(arnoldi_env, 1), dp)
     576             :          !CALL deallocate_arnoldi_env(arnoldi_env)
     577             : 
     578        2032 :          IF (unit_nr > 0) THEN
     579         768 :             WRITE (unit_nr, *)
     580         768 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", my_eps_lanczos
     581         768 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_eV, min_eV
     582         768 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_eV/MAX(min_eV, EPSILON(min_eV))
     583             :          END IF
     584             : 
     585             :          ! 2.0 would be the correct scaling however, we should make sure here, that we are in the convergence radius
     586        2032 :          scalingf = 1.9_dp/(max_eV + min_eV)
     587        2032 :          CALL dbcsr_scale(tmp1, scalingf)
     588        2032 :          CALL dbcsr_scale(matrix_inverse, scalingf)
     589        2032 :          min_ev = min_ev*scalingf
     590             : 
     591             :       END IF
     592             : 
     593             :       ! done with the initial guess, start iterations
     594        2032 :       converged = .FALSE.
     595        8998 :       DO i = 1, 100
     596             : 
     597             :          ! tmp1 = S^-1 S
     598             : 
     599             :          ! for the convergence check
     600        8998 :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     601        8998 :          maxnorm_matrix = dbcsr_maxabs(tmp1)
     602        8998 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     603             : 
     604             :          ! tmp2 = S^-1 S S^-1
     605             :          CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_inverse, 0.0_dp, tmp2, &
     606        8998 :                              flop=flop2, filter_eps=my_filter_eps)
     607             :          ! S^-1_{n+1} = 2 S^-1 - S^-1 S S^-1
     608        8998 :          CALL dbcsr_add(matrix_inverse, tmp2, 2.0_dp, -1.0_dp)
     609             : 
     610        8998 :          CALL dbcsr_filter(matrix_inverse, my_filter_eps)
     611        8998 :          t2 = m_walltime()
     612        8998 :          occ_matrix = dbcsr_get_occupation(matrix_inverse)
     613             : 
     614             :          ! use the scalar form of the algorithm to trace the EV
     615        8998 :          IF (accelerator_type == 1) THEN
     616        8998 :             min_ev = min_ev*(2.0_dp - min_ev)
     617        8998 :             IF (PRESENT(norm_convergence)) maxnorm_matrix = ABS(min_eV - 1.0_dp)
     618             :          END IF
     619             : 
     620        8998 :          IF (unit_nr > 0) THEN
     621        3718 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "Hotelling iter", i, occ_matrix, &
     622        3718 :                maxnorm_matrix, t2 - t1, &
     623        7436 :                (flop1 + flop2)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     624        3718 :             CALL m_flush(unit_nr)
     625             :          END IF
     626             : 
     627        8998 :          IF (maxnorm_matrix < convergence) THEN
     628             :             converged = .TRUE.
     629             :             EXIT
     630             :          END IF
     631             : 
     632             :          ! scale the matrix for improved convergence
     633        6966 :          IF (accelerator_type == 1) THEN
     634        6966 :             min_ev = min_ev*2.0_dp/(min_ev + 1.0_dp)
     635        6966 :             CALL dbcsr_scale(matrix_inverse, 2.0_dp/(min_ev + 1.0_dp))
     636             :          END IF
     637             : 
     638        6966 :          t1 = m_walltime()
     639             :          CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_inverse, matrix, &
     640        6966 :                              0.0_dp, tmp1, flop=flop1, filter_eps=my_filter_eps)
     641             : 
     642             :       END DO
     643             : 
     644        2032 :       IF (.NOT. converged) THEN
     645           0 :          CPABORT("Hotelling inversion did not converge")
     646             :       END IF
     647             : 
     648             :       ! try to symmetrize the output matrix
     649        2032 :       IF (dbcsr_get_matrix_type(matrix_inverse) == dbcsr_type_no_symmetry) THEN
     650         100 :          CALL dbcsr_transposed(tmp2, matrix_inverse)
     651        2132 :          CALL dbcsr_add(matrix_inverse, tmp2, 0.5_dp, 0.5_dp)
     652             :       END IF
     653             : 
     654        2032 :       IF (unit_nr > 0) THEN
     655             : !           WRITE(unit_nr,'(T6,A,1X,I3,1X,F10.8,E12.3)') "Final Hotelling ",i,occ_matrix,&
     656             : !              !frob_matrix/frob_matrix_base
     657             : !              maxnorm_matrix
     658         768 :          WRITE (unit_nr, '()')
     659         768 :          CALL m_flush(unit_nr)
     660             :       END IF
     661             : 
     662        2032 :       CALL dbcsr_release(tmp1)
     663        2032 :       CALL dbcsr_release(tmp2)
     664             : 
     665        2032 :       CALL timestop(handle)
     666             : 
     667        2032 :    END SUBROUTINE invert_Hotelling
     668             : 
     669             : ! **************************************************************************************************
     670             : !> \brief compute the sign a matrix using Newton-Schulz iterations
     671             : !> \param matrix_sign ...
     672             : !> \param matrix ...
     673             : !> \param threshold ...
     674             : !> \param sign_order ...
     675             : !> \param iounit ...
     676             : !> \par History
     677             : !>       2010.10 created [Joost VandeVondele]
     678             : !>       2019.05 extended to order byxond 2 [Robert Schade]
     679             : !> \author Joost VandeVondele, Robert Schade
     680             : ! **************************************************************************************************
     681        1058 :    SUBROUTINE matrix_sign_Newton_Schulz(matrix_sign, matrix, threshold, sign_order, iounit)
     682             : 
     683             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
     684             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
     685             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order, iounit
     686             : 
     687             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_Newton_Schulz'
     688             : 
     689             :       INTEGER                                            :: count, handle, i, order, unit_nr
     690             :       INTEGER(KIND=int_8)                                :: flops
     691             :       REAL(KIND=dp)                                      :: a0, a1, a2, a3, a4, a5, floptot, &
     692             :                                                             frob_matrix, frob_matrix_base, &
     693             :                                                             gersh_matrix, occ_matrix, prefactor, &
     694             :                                                             t1, t2
     695             :       TYPE(cp_logger_type), POINTER                      :: logger
     696             :       TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3, tmp4
     697             : 
     698        1058 :       CALL timeset(routineN, handle)
     699             : 
     700        1058 :       IF (PRESENT(iounit)) THEN
     701        1058 :          unit_nr = iounit
     702             :       ELSE
     703           0 :          logger => cp_get_default_logger()
     704           0 :          IF (logger%para_env%is_source()) THEN
     705           0 :             unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
     706             :          ELSE
     707           0 :             unit_nr = -1
     708             :          END IF
     709             :       END IF
     710             : 
     711        1058 :       IF (PRESENT(sign_order)) THEN
     712        1058 :          order = sign_order
     713             :       ELSE
     714             :          order = 2
     715             :       END IF
     716             : 
     717        1058 :       CALL dbcsr_create(tmp1, template=matrix_sign)
     718             : 
     719        1058 :       CALL dbcsr_create(tmp2, template=matrix_sign)
     720        1058 :       IF (ABS(order) .GE. 4) THEN
     721           8 :          CALL dbcsr_create(tmp3, template=matrix_sign)
     722             :       END IF
     723           8 :       IF (ABS(order) .GT. 4) THEN
     724           6 :          CALL dbcsr_create(tmp4, template=matrix_sign)
     725             :       END IF
     726             : 
     727        1058 :       CALL dbcsr_copy(matrix_sign, matrix)
     728        1058 :       CALL dbcsr_filter(matrix_sign, threshold)
     729             : 
     730             :       ! scale the matrix to get into the convergence range
     731        1058 :       frob_matrix = dbcsr_frobenius_norm(matrix_sign)
     732        1058 :       gersh_matrix = dbcsr_gershgorin_norm(matrix_sign)
     733        1058 :       CALL dbcsr_scale(matrix_sign, 1/MIN(frob_matrix, gersh_matrix))
     734             : 
     735        1058 :       IF (unit_nr > 0) WRITE (unit_nr, *)
     736             : 
     737        1058 :       count = 0
     738       13202 :       DO i = 1, 100
     739       13202 :          floptot = 0_dp
     740       13202 :          t1 = m_walltime()
     741             :          ! tmp1 = X * X
     742             :          CALL dbcsr_multiply("N", "N", -1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     743       13202 :                              filter_eps=threshold, flop=flops)
     744       13202 :          floptot = floptot + flops
     745             : 
     746             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
     747       13202 :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     748       13202 :          CALL dbcsr_add_on_diag(tmp1, +1.0_dp)
     749       13202 :          frob_matrix = dbcsr_frobenius_norm(tmp1)
     750             : 
     751             :          ! f(y) approx 1/sqrt(1-y)
     752             :          ! f(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     753             :          ! f2(y)=1+y/2=1/2*(2+y)
     754             :          ! f3(y)=1+y/2+3/8*y^2=3/8*(8/3+4/3*y+y^2)
     755             :          ! f4(y)=1+y/2+3/8*y^2+5/16*y^3=5/16*(16/5+8/5*y+6/5*y^2+y^3)
     756             :          ! f5(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4=35/128*(128/35+128/70*y+48/35*y^2+8/7*y^3+y^4)
     757             :          !      z(y)=(y+a_0)*y+a_1
     758             :          ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     759             :          !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     760             :          !    a_0=1/14
     761             :          !    a_1=23819/13720
     762             :          !    a_2=1269/980-2a_1=-3734/1715
     763             :          !    a_3=832591127/188238400
     764             :          ! f6(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5
     765             :          !      =63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     766             :          ! f7(y)=1+y/2+3/8*y^2+5/16*y^3+35/128*y^4+63/256*y^5+231/1024*y^6
     767             :          !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     768             :          ! z(y)=(y+a_0)*y+a_1
     769             :          ! w(y)=(y+a_2)*z(y)+a_3
     770             :          ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     771             :          ! a_0= 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422
     772             :          ! a_1= 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568
     773             :          ! a_2=-1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968
     774             :          ! a_3= 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453
     775             :          ! a_4=-3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667
     776             :          ! a_5= 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578
     777             :          !
     778             :          ! y=1-X*X
     779             : 
     780             :          ! tmp1 = I-x*x
     781             :          IF (order .EQ. 2) THEN
     782       13156 :             prefactor = 0.5_dp
     783             : 
     784             :             ! update the above to 3*I-X*X
     785       13156 :             CALL dbcsr_add_on_diag(tmp1, +2.0_dp)
     786       13156 :             occ_matrix = dbcsr_get_occupation(matrix_sign)
     787             :          ELSE IF (order .EQ. 3) THEN
     788             :             ! with one multiplication
     789             :             ! tmp1=y
     790          12 :             CALL dbcsr_copy(tmp2, tmp1)
     791          12 :             CALL dbcsr_scale(tmp1, 4.0_dp/3.0_dp)
     792          12 :             CALL dbcsr_add_on_diag(tmp1, 8.0_dp/3.0_dp)
     793             : 
     794             :             ! tmp2=y^2
     795             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp2, 1.0_dp, tmp1, &
     796          12 :                                 filter_eps=threshold, flop=flops)
     797          12 :             floptot = floptot + flops
     798          12 :             prefactor = 3.0_dp/8.0_dp
     799             : 
     800             :          ELSE IF (order .EQ. 4) THEN
     801             :             ! with two multiplications
     802             :             ! tmp1=y
     803          10 :             CALL dbcsr_copy(tmp3, tmp1)
     804          10 :             CALL dbcsr_scale(tmp1, 8.0_dp/5.0_dp)
     805          10 :             CALL dbcsr_add_on_diag(tmp1, 16.0_dp/5.0_dp)
     806             : 
     807             :             !
     808             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     809          10 :                                 filter_eps=threshold, flop=flops)
     810          10 :             floptot = floptot + flops
     811             : 
     812          10 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp/5.0_dp)
     813             : 
     814             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     815          10 :                                 filter_eps=threshold, flop=flops)
     816          10 :             floptot = floptot + flops
     817             : 
     818          10 :             prefactor = 5.0_dp/16.0_dp
     819             :          ELSE IF (order .EQ. -5) THEN
     820             :             ! with three multiplications
     821             :             ! tmp1=y
     822           0 :             CALL dbcsr_copy(tmp3, tmp1)
     823           0 :             CALL dbcsr_scale(tmp1, 128.0_dp/70.0_dp)
     824           0 :             CALL dbcsr_add_on_diag(tmp1, 128.0_dp/35.0_dp)
     825             : 
     826             :             !
     827             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     828           0 :                                 filter_eps=threshold, flop=flops)
     829           0 :             floptot = floptot + flops
     830             : 
     831           0 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 48.0_dp/35.0_dp)
     832             : 
     833             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     834           0 :                                 filter_eps=threshold, flop=flops)
     835           0 :             floptot = floptot + flops
     836             : 
     837           0 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 8.0_dp/7.0_dp)
     838             : 
     839             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 1.0_dp, tmp1, &
     840           0 :                                 filter_eps=threshold, flop=flops)
     841           0 :             floptot = floptot + flops
     842             : 
     843           0 :             prefactor = 35.0_dp/128.0_dp
     844             :          ELSE IF (order .EQ. 5) THEN
     845             :             ! with two multiplications
     846             :             !      z(y)=(y+a_0)*y+a_1
     847             :             ! f5(y)=35/128*((z(y)+y+a_2)*z(y)+a_3)
     848             :             !      =35/128*((a_1^2+a_1a_2+a_3)+(2*a_0a_1+a_1+a_0a_2)y+(a_0^2+a_0+2a_1+a_2)y^2+(2a_0+1)y^3+y^4)
     849             :             !    a_0=1/14
     850             :             !    a_1=23819/13720
     851             :             !    a_2=1269/980-2a_1=-3734/1715
     852             :             !    a_3=832591127/188238400
     853           8 :             a0 = 1.0_dp/14.0_dp
     854           8 :             a1 = 23819.0_dp/13720.0_dp
     855           8 :             a2 = -3734_dp/1715.0_dp
     856           8 :             a3 = 832591127_dp/188238400.0_dp
     857             : 
     858             :             ! tmp1=y
     859             :             ! tmp3=z
     860           8 :             CALL dbcsr_copy(tmp3, tmp1)
     861           8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     862             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     863           8 :                                 filter_eps=threshold, flop=flops)
     864           8 :             floptot = floptot + flops
     865           8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     866             : 
     867           8 :             CALL dbcsr_add_on_diag(tmp1, a2)
     868           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
     869             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp2, 0.0_dp, tmp3, &
     870           8 :                                 filter_eps=threshold, flop=flops)
     871           8 :             floptot = floptot + flops
     872           8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     873           8 :             CALL dbcsr_copy(tmp1, tmp3)
     874             : 
     875           8 :             prefactor = 35.0_dp/128.0_dp
     876             :          ELSE IF (order .EQ. 6) THEN
     877             :             ! with four multiplications
     878             :             ! f6(y)=63/256*(256/63 + (128 y)/63 + (32 y^2)/21 + (80 y^3)/63 + (10 y^4)/9 + y^5)
     879             :             ! tmp1=y
     880           8 :             CALL dbcsr_copy(tmp3, tmp1)
     881           8 :             CALL dbcsr_scale(tmp1, 128.0_dp/63.0_dp)
     882           8 :             CALL dbcsr_add_on_diag(tmp1, 256.0_dp/63.0_dp)
     883             : 
     884             :             !
     885             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp3, 0.0_dp, tmp2, &
     886           8 :                                 filter_eps=threshold, flop=flops)
     887           8 :             floptot = floptot + flops
     888             : 
     889           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 32.0_dp/21.0_dp)
     890             : 
     891             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp4, &
     892           8 :                                 filter_eps=threshold, flop=flops)
     893           8 :             floptot = floptot + flops
     894             : 
     895           8 :             CALL dbcsr_add(tmp1, tmp4, 1.0_dp, 80.0_dp/63.0_dp)
     896             : 
     897             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp3, 0.0_dp, tmp2, &
     898           8 :                                 filter_eps=threshold, flop=flops)
     899           8 :             floptot = floptot + flops
     900             : 
     901           8 :             CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 10.0_dp/9.0_dp)
     902             : 
     903             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 1.0_dp, tmp1, &
     904           8 :                                 filter_eps=threshold, flop=flops)
     905           8 :             floptot = floptot + flops
     906             : 
     907           8 :             prefactor = 63.0_dp/256.0_dp
     908             :          ELSE IF (order .EQ. 7) THEN
     909             :             ! with three multiplications
     910             : 
     911           8 :             a0 = 1.3686502058092053653287666647611728507211996691324048468010382350359929055186612505791532871573242422_dp
     912           8 :             a1 = 1.7089671854477436685850554669524985556296280184497503489303331821456795715195510972774979091893741568_dp
     913           8 :             a2 = -1.3231956603546599107833121193066273961757451236778593922555836895814474509732067051246078326118696968_dp
     914           8 :             a3 = 3.9876642330847931291749479958277754186675336169578593000744380254770411483327581042259415937710270453_dp
     915           8 :             a4 = -3.7273299006476825027065704937541279833880400042556351139273912137942678919776364526511485025132991667_dp
     916           8 :             a5 = 4.9369932474103023792021351907971943220607580694533770325967170245194362399287150565595441897740173578_dp
     917             :             !      =231/1024*(1024/231+512/231*y+128/77*y^2+320/231*y^3+40/33*y^4+12/11*y^5+y^6)
     918             :             ! z(y)=(y+a_0)*y+a_1
     919             :             ! w(y)=(y+a_2)*z(y)+a_3
     920             :             ! f7(y)=(w(y)+z(y)+a_4)*w(y)+a_5
     921             : 
     922             :             ! tmp1=y
     923             :             ! tmp3=z
     924           8 :             CALL dbcsr_copy(tmp3, tmp1)
     925           8 :             CALL dbcsr_add_on_diag(tmp3, a0)
     926             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, tmp1, 0.0_dp, tmp2, &
     927           8 :                                 filter_eps=threshold, flop=flops)
     928           8 :             floptot = floptot + flops
     929           8 :             CALL dbcsr_add_on_diag(tmp2, a1)
     930             : 
     931             :             ! tmp4=w
     932           8 :             CALL dbcsr_copy(tmp4, tmp1)
     933           8 :             CALL dbcsr_add_on_diag(tmp4, a2)
     934             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp4, tmp2, 0.0_dp, tmp3, &
     935           8 :                                 filter_eps=threshold, flop=flops)
     936           8 :             floptot = floptot + flops
     937           8 :             CALL dbcsr_add_on_diag(tmp3, a3)
     938             : 
     939           8 :             CALL dbcsr_add(tmp2, tmp3, 1.0_dp, 1.0_dp)
     940           8 :             CALL dbcsr_add_on_diag(tmp2, a4)
     941             :             CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
     942           8 :                                 filter_eps=threshold, flop=flops)
     943           8 :             floptot = floptot + flops
     944           8 :             CALL dbcsr_add_on_diag(tmp1, a5)
     945             : 
     946           8 :             prefactor = 231.0_dp/1024.0_dp
     947             :          ELSE
     948           0 :             CPABORT("requested order is not implemented.")
     949             :          END IF
     950             : 
     951             :          ! tmp2 = X * prefactor *
     952             :          CALL dbcsr_multiply("N", "N", prefactor, matrix_sign, tmp1, 0.0_dp, tmp2, &
     953       13202 :                              filter_eps=threshold, flop=flops)
     954       13202 :          floptot = floptot + flops
     955             : 
     956             :          ! done iterating
     957             :          ! CALL dbcsr_filter(tmp2,threshold)
     958       13202 :          CALL dbcsr_copy(matrix_sign, tmp2)
     959       13202 :          t2 = m_walltime()
     960             : 
     961       13202 :          occ_matrix = dbcsr_get_occupation(matrix_sign)
     962             : 
     963       13202 :          IF (unit_nr > 0) THEN
     964           0 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sign iter ", i, occ_matrix, &
     965           0 :                frob_matrix/frob_matrix_base, t2 - t1, &
     966           0 :                floptot/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
     967           0 :             CALL m_flush(unit_nr)
     968             :          END IF
     969             : 
     970             :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
     971       13202 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) EXIT
     972             : 
     973             :       END DO
     974             : 
     975             :       ! this check is not really needed
     976             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
     977        1058 :                           filter_eps=threshold)
     978        1058 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
     979        1058 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
     980        1058 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
     981        1058 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
     982        1058 :       IF (unit_nr > 0) THEN
     983           0 :          WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sign iter", i, occ_matrix, &
     984           0 :             frob_matrix/frob_matrix_base
     985           0 :          WRITE (unit_nr, '()')
     986           0 :          CALL m_flush(unit_nr)
     987             :       END IF
     988             : 
     989        1058 :       CALL dbcsr_release(tmp1)
     990        1058 :       CALL dbcsr_release(tmp2)
     991        1058 :       IF (ABS(order) .GE. 4) THEN
     992           8 :          CALL dbcsr_release(tmp3)
     993             :       END IF
     994           8 :       IF (ABS(order) .GT. 4) THEN
     995           6 :          CALL dbcsr_release(tmp4)
     996             :       END IF
     997             : 
     998        1058 :       CALL timestop(handle)
     999             : 
    1000        1058 :    END SUBROUTINE matrix_sign_Newton_Schulz
    1001             : 
    1002             :    ! **************************************************************************************************
    1003             : !> \brief compute the sign a matrix using the general algorithm for the p-th root of Richters et al.
    1004             : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1005             : !> \param matrix_sign ...
    1006             : !> \param matrix ...
    1007             : !> \param threshold ...
    1008             : !> \param sign_order ...
    1009             : !> \par History
    1010             : !>       2019.03 created [Robert Schade]
    1011             : !> \author Robert Schade
    1012             : ! **************************************************************************************************
    1013          16 :    SUBROUTINE matrix_sign_proot(matrix_sign, matrix, threshold, sign_order)
    1014             : 
    1015             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1016             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1017             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1018             : 
    1019             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sign_proot'
    1020             : 
    1021             :       INTEGER                                            :: handle, order, unit_nr
    1022             :       INTEGER(KIND=int_8)                                :: flop0, flop1, flop2
    1023             :       LOGICAL                                            :: converged, symmetrize
    1024             :       REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, occ_matrix
    1025             :       TYPE(cp_logger_type), POINTER                      :: logger
    1026             :       TYPE(dbcsr_type)                                   :: matrix2, matrix_sqrt, matrix_sqrt_inv, &
    1027             :                                                             tmp1, tmp2
    1028             : 
    1029           8 :       CALL cite_reference(Richters2018)
    1030             : 
    1031           8 :       CALL timeset(routineN, handle)
    1032             : 
    1033           8 :       logger => cp_get_default_logger()
    1034           8 :       IF (logger%para_env%is_source()) THEN
    1035           4 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1036             :       ELSE
    1037           4 :          unit_nr = -1
    1038             :       END IF
    1039             : 
    1040           8 :       IF (PRESENT(sign_order)) THEN
    1041           8 :          order = sign_order
    1042             :       ELSE
    1043           0 :          order = 2
    1044             :       END IF
    1045             : 
    1046           8 :       CALL dbcsr_create(tmp1, template=matrix_sign)
    1047             : 
    1048           8 :       CALL dbcsr_create(tmp2, template=matrix_sign)
    1049             : 
    1050           8 :       CALL dbcsr_create(matrix2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1051             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix, 0.0_dp, matrix2, &
    1052           8 :                           filter_eps=threshold, flop=flop0)
    1053             :       !CALL dbcsr_filter(matrix2, threshold)
    1054             : 
    1055             :       !CALL dbcsr_copy(matrix_sign, matrix)
    1056             :       !CALL dbcsr_filter(matrix_sign, threshold)
    1057             : 
    1058           8 :       IF (unit_nr > 0) WRITE (unit_nr, *)
    1059             : 
    1060           8 :       CALL dbcsr_create(matrix_sqrt, template=matrix2)
    1061           8 :       CALL dbcsr_create(matrix_sqrt_inv, template=matrix2)
    1062           8 :       IF (unit_nr > 0) WRITE (unit_nr, *) "Threshold=", threshold
    1063             : 
    1064           8 :       symmetrize = .FALSE.
    1065             :       CALL matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix2, threshold, order, &
    1066           8 :                              0.01_dp, 100, symmetrize, converged)
    1067             : 
    1068             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix, matrix_sqrt_inv, 0.0_dp, matrix_sign, &
    1069           8 :                           filter_eps=threshold, flop=flop1)
    1070             : 
    1071             :       ! this check is not really needed
    1072             :       CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sign, matrix_sign, 0.0_dp, tmp1, &
    1073           8 :                           filter_eps=threshold, flop=flop2)
    1074           8 :       frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    1075           8 :       CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    1076           8 :       frob_matrix = dbcsr_frobenius_norm(tmp1)
    1077           8 :       occ_matrix = dbcsr_get_occupation(matrix_sign)
    1078           8 :       IF (unit_nr > 0) THEN
    1079           4 :          WRITE (unit_nr, '(T6,A,F10.8,E12.3)') "Final proot sign iter", occ_matrix, &
    1080           8 :             frob_matrix/frob_matrix_base
    1081           4 :          WRITE (unit_nr, '()')
    1082           4 :          CALL m_flush(unit_nr)
    1083             :       END IF
    1084             : 
    1085           8 :       CALL dbcsr_release(tmp1)
    1086           8 :       CALL dbcsr_release(tmp2)
    1087           8 :       CALL dbcsr_release(matrix2)
    1088           8 :       CALL dbcsr_release(matrix_sqrt)
    1089           8 :       CALL dbcsr_release(matrix_sqrt_inv)
    1090             : 
    1091           8 :       CALL timestop(handle)
    1092             : 
    1093           8 :    END SUBROUTINE matrix_sign_proot
    1094             : 
    1095             : ! **************************************************************************************************
    1096             : !> \brief compute the sign of a dense matrix using Newton-Schulz iterations
    1097             : !> \param matrix_sign ...
    1098             : !> \param matrix ...
    1099             : !> \param matrix_id ...
    1100             : !> \param threshold ...
    1101             : !> \param sign_order ...
    1102             : !> \author Michael Lass, Robert Schade
    1103             : ! **************************************************************************************************
    1104           2 :    SUBROUTINE dense_matrix_sign_Newton_Schulz(matrix_sign, matrix, matrix_id, threshold, sign_order)
    1105             : 
    1106             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: matrix_sign
    1107             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: matrix
    1108             :       INTEGER, INTENT(IN), OPTIONAL                      :: matrix_id
    1109             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1110             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1111             : 
    1112             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dense_matrix_sign_Newton_Schulz'
    1113             : 
    1114             :       INTEGER                                            :: handle, i, j, sz, unit_nr
    1115             :       LOGICAL                                            :: converged
    1116             :       REAL(KIND=dp)                                      :: frob_matrix, frob_matrix_base, &
    1117             :                                                             gersh_matrix, prefactor, scaling_factor
    1118           2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp1, tmp2
    1119             :       REAL(KIND=dp), DIMENSION(1)                        :: work
    1120             :       REAL(KIND=dp), EXTERNAL                            :: dlange
    1121             :       TYPE(cp_logger_type), POINTER                      :: logger
    1122             : 
    1123           2 :       CALL timeset(routineN, handle)
    1124             : 
    1125             :       ! print output on all ranks
    1126           2 :       logger => cp_get_default_logger()
    1127           2 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1128             : 
    1129             :       ! scale the matrix to get into the convergence range
    1130           2 :       sz = SIZE(matrix, 1)
    1131           2 :       frob_matrix = dlange('F', sz, sz, matrix, sz, work) !dbcsr_frobenius_norm(matrix_sign)
    1132           2 :       gersh_matrix = dlange('1', sz, sz, matrix, sz, work) !dbcsr_gershgorin_norm(matrix_sign)
    1133           2 :       scaling_factor = 1/MIN(frob_matrix, gersh_matrix)
    1134          86 :       matrix_sign = matrix*scaling_factor
    1135           8 :       ALLOCATE (tmp1(sz, sz))
    1136           6 :       ALLOCATE (tmp2(sz, sz))
    1137             : 
    1138           2 :       converged = .FALSE.
    1139          14 :       DO i = 1, 100
    1140          14 :          CALL dgemm('N', 'N', sz, sz, sz, -1.0_dp, matrix_sign, sz, matrix_sign, sz, 0.0_dp, tmp1, sz)
    1141             : 
    1142             :          ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
    1143          14 :          frob_matrix_base = dlange('F', sz, sz, tmp1, sz, work)
    1144          98 :          DO j = 1, sz
    1145          98 :             tmp1(j, j) = tmp1(j, j) + 1.0_dp
    1146             :          END DO
    1147          14 :          frob_matrix = dlange('F', sz, sz, tmp1, sz, work)
    1148             : 
    1149          14 :          IF (sign_order .EQ. 2) THEN
    1150           8 :             prefactor = 0.5_dp
    1151             :             ! update the above to 3*I-X*X
    1152          56 :             DO j = 1, sz
    1153          56 :                tmp1(j, j) = tmp1(j, j) + 2.0_dp
    1154             :             END DO
    1155           6 :          ELSE IF (sign_order .EQ. 3) THEN
    1156         258 :             tmp2(:, :) = tmp1
    1157         258 :             tmp1 = tmp1*4.0_dp/3.0_dp
    1158          42 :             DO j = 1, sz
    1159          42 :                tmp1(j, j) = tmp1(j, j) + 8.0_dp/3.0_dp
    1160             :             END DO
    1161           6 :             CALL dgemm('N', 'N', sz, sz, sz, 1.0_dp, tmp2, sz, tmp2, sz, 1.0_dp, tmp1, sz)
    1162           6 :             prefactor = 3.0_dp/8.0_dp
    1163             :          ELSE
    1164           0 :             CPABORT("requested order is not implemented.")
    1165             :          END IF
    1166             : 
    1167          14 :          CALL dgemm('N', 'N', sz, sz, sz, prefactor, matrix_sign, sz, tmp1, sz, 0.0_dp, tmp2, sz)
    1168         602 :          matrix_sign = tmp2
    1169             : 
    1170             :          ! frob_matrix/frob_matrix_base < SQRT(threshold)
    1171          14 :          IF (frob_matrix*frob_matrix < (threshold*frob_matrix_base*frob_matrix_base)) THEN
    1172             :             WRITE (unit_nr, '(T6,A,1X,I6,1X,A,1X,I3,E12.3)') &
    1173           2 :                "Submatrix", matrix_id, "final NS sign iter", i, frob_matrix/frob_matrix_base
    1174           2 :             CALL m_flush(unit_nr)
    1175             :             converged = .TRUE.
    1176             :             EXIT
    1177             :          END IF
    1178             :       END DO
    1179             : 
    1180             :       IF (.NOT. converged) &
    1181           0 :          CPABORT("dense_matrix_sign_Newton_Schulz did not converge within 100 iterations")
    1182             : 
    1183           2 :       DEALLOCATE (tmp1)
    1184           2 :       DEALLOCATE (tmp2)
    1185             : 
    1186           2 :       CALL timestop(handle)
    1187             : 
    1188           2 :    END SUBROUTINE dense_matrix_sign_Newton_Schulz
    1189             : 
    1190             : ! **************************************************************************************************
    1191             : !> \brief Perform eigendecomposition of a dense matrix
    1192             : !> \param sm ...
    1193             : !> \param N ...
    1194             : !> \param eigvals ...
    1195             : !> \param eigvecs ...
    1196             : !> \par History
    1197             : !>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
    1198             : !> \author Michael Lass, Robert Schade
    1199             : ! **************************************************************************************************
    1200           4 :    SUBROUTINE eigdecomp(sm, N, eigvals, eigvecs)
    1201             :       INTEGER, INTENT(IN)                                :: N
    1202             :       REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
    1203             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1204             :          INTENT(OUT)                                     :: eigvals
    1205             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1206             :          INTENT(OUT)                                     :: eigvecs
    1207             : 
    1208             :       INTEGER                                            :: info, liwork, lwork
    1209           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: iwork
    1210           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: work
    1211           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: tmp
    1212             : 
    1213          24 :       ALLOCATE (eigvecs(N, N), tmp(N, N))
    1214          12 :       ALLOCATE (eigvals(N))
    1215             : 
    1216             :       ! symmetrize sm
    1217         172 :       eigvecs(:, :) = 0.5*(sm + TRANSPOSE(sm))
    1218             : 
    1219             :       ! probe optimal sizes for WORK and IWORK
    1220           4 :       LWORK = -1
    1221           4 :       LIWORK = -1
    1222           4 :       ALLOCATE (WORK(1))
    1223           4 :       ALLOCATE (IWORK(1))
    1224           4 :       CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, WORK, LWORK, IWORK, LIWORK, INFO)
    1225           4 :       LWORK = INT(WORK(1))
    1226           4 :       LIWORK = INT(IWORK(1))
    1227           4 :       DEALLOCATE (IWORK)
    1228           4 :       DEALLOCATE (WORK)
    1229             : 
    1230             :       ! calculate eigenvalues and eigenvectors
    1231          12 :       ALLOCATE (WORK(LWORK))
    1232          12 :       ALLOCATE (IWORK(LIWORK))
    1233           4 :       CALL dsyevd('V', 'U', N, eigvecs, N, eigvals, WORK, LWORK, IWORK, LIWORK, INFO)
    1234           4 :       DEALLOCATE (IWORK)
    1235           4 :       DEALLOCATE (WORK)
    1236           4 :       IF (INFO .NE. 0) CPABORT("dsyevd did not succeed")
    1237             : 
    1238           4 :       DEALLOCATE (tmp)
    1239           4 :    END SUBROUTINE eigdecomp
    1240             : 
    1241             : ! **************************************************************************************************
    1242             : !> \brief Calculate the sign matrix from eigenvalues and eigenvectors of a matrix
    1243             : !> \param sm_sign ...
    1244             : !> \param eigvals ...
    1245             : !> \param eigvecs ...
    1246             : !> \param N ...
    1247             : !> \param mu_correction ...
    1248             : !> \par History
    1249             : !>       2020.05 Extracted from dense_matrix_sign_direct [Michael Lass]
    1250             : !> \author Michael Lass, Robert Schade
    1251             : ! **************************************************************************************************
    1252           4 :    SUBROUTINE sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, mu_correction)
    1253             :       INTEGER                                            :: N
    1254             :       REAL(KIND=dp), INTENT(IN)                          :: eigvecs(N, N), eigvals(N)
    1255             :       REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
    1256             :       REAL(KIND=dp), INTENT(IN)                          :: mu_correction
    1257             : 
    1258             :       INTEGER                                            :: i
    1259           4 :       REAL(KIND=dp)                                      :: modified_eigval, tmp(N, N)
    1260             : 
    1261         172 :       sm_sign = 0
    1262          28 :       DO i = 1, N
    1263          24 :          modified_eigval = eigvals(i) - mu_correction
    1264          28 :          IF (modified_eigval > 0) THEN
    1265           6 :             sm_sign(i, i) = 1.0
    1266          18 :          ELSE IF (modified_eigval < 0) THEN
    1267          18 :             sm_sign(i, i) = -1.0
    1268             :          ELSE
    1269           0 :             sm_sign(i, i) = 0.0
    1270             :          END IF
    1271             :       END DO
    1272             : 
    1273             :       ! Create matrix with eigenvalues in {-1,0,1} and eigenvectors of sm:
    1274             :       ! sm_sign = eigvecs * sm_sign * eigvecs.T
    1275           4 :       CALL dgemm('N', 'N', N, N, N, 1.0_dp, eigvecs, N, sm_sign, N, 0.0_dp, tmp, N)
    1276           4 :       CALL dgemm('N', 'T', N, N, N, 1.0_dp, tmp, N, eigvecs, N, 0.0_dp, sm_sign, N)
    1277           4 :    END SUBROUTINE sign_from_eigdecomp
    1278             : 
    1279             : ! **************************************************************************************************
    1280             : !> \brief Compute partial trace of a matrix from its eigenvalues and eigenvectors
    1281             : !> \param eigvals ...
    1282             : !> \param eigvecs ...
    1283             : !> \param firstcol ...
    1284             : !> \param lastcol ...
    1285             : !> \param mu_correction ...
    1286             : !> \return ...
    1287             : !> \par History
    1288             : !>       2020.05 Created [Michael Lass]
    1289             : !> \author Michael Lass
    1290             : ! **************************************************************************************************
    1291          36 :    FUNCTION trace_from_eigdecomp(eigvals, eigvecs, firstcol, lastcol, mu_correction) RESULT(trace)
    1292             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1293             :          INTENT(IN)                                      :: eigvals
    1294             :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1295             :          INTENT(IN)                                      :: eigvecs
    1296             :       INTEGER, INTENT(IN)                                :: firstcol, lastcol
    1297             :       REAL(KIND=dp), INTENT(IN)                          :: mu_correction
    1298             :       REAL(KIND=dp)                                      :: trace
    1299             : 
    1300             :       INTEGER                                            :: i, j, sm_size
    1301             :       REAL(KIND=dp)                                      :: modified_eigval, tmpsum
    1302          36 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: mapped_eigvals
    1303             : 
    1304          36 :       sm_size = SIZE(eigvals)
    1305         108 :       ALLOCATE (mapped_eigvals(sm_size))
    1306             : 
    1307         252 :       DO i = 1, sm_size
    1308         216 :          modified_eigval = eigvals(i) - mu_correction
    1309         252 :          IF (modified_eigval > 0) THEN
    1310          26 :             mapped_eigvals(i) = 1.0
    1311         190 :          ELSE IF (modified_eigval < 0) THEN
    1312         190 :             mapped_eigvals(i) = -1.0
    1313             :          ELSE
    1314           0 :             mapped_eigvals(i) = 0.0
    1315             :          END IF
    1316             :       END DO
    1317             : 
    1318          36 :       trace = 0.0_dp
    1319         252 :       DO i = firstcol, lastcol
    1320             :          tmpsum = 0.0_dp
    1321        1512 :          DO j = 1, sm_size
    1322        1512 :             tmpsum = tmpsum + (eigvecs(i, j)*mapped_eigvals(j)*eigvecs(i, j))
    1323             :          END DO
    1324         252 :          trace = trace - 0.5_dp*tmpsum + 0.5_dp
    1325             :       END DO
    1326          36 :    END FUNCTION trace_from_eigdecomp
    1327             : 
    1328             : ! **************************************************************************************************
    1329             : !> \brief Calculate the sign matrix by direct calculation of all eigenvalues and eigenvectors
    1330             : !> \param sm_sign ...
    1331             : !> \param sm ...
    1332             : !> \param N ...
    1333             : !> \par History
    1334             : !>       2020.02 Created [Michael Lass, Robert Schade]
    1335             : !>       2020.05 Extracted eigdecomp and sign_from_eigdecomp [Michael Lass]
    1336             : !> \author Michael Lass, Robert Schade
    1337             : ! **************************************************************************************************
    1338           2 :    SUBROUTINE dense_matrix_sign_direct(sm_sign, sm, N)
    1339             :       INTEGER, INTENT(IN)                                :: N
    1340             :       REAL(KIND=dp), INTENT(IN)                          :: sm(N, N)
    1341             :       REAL(KIND=dp), INTENT(INOUT)                       :: sm_sign(N, N)
    1342             : 
    1343           2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: eigvals
    1344           2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: eigvecs
    1345             : 
    1346             :       CALL eigdecomp(sm, N, eigvals, eigvecs)
    1347           2 :       CALL sign_from_eigdecomp(sm_sign, eigvals, eigvecs, N, 0.0_dp)
    1348             : 
    1349           2 :       DEALLOCATE (eigvals, eigvecs)
    1350           2 :    END SUBROUTINE dense_matrix_sign_direct
    1351             : 
    1352             : ! **************************************************************************************************
    1353             : !> \brief Submatrix method
    1354             : !> \param matrix_sign ...
    1355             : !> \param matrix ...
    1356             : !> \param threshold ...
    1357             : !> \param sign_order ...
    1358             : !> \param submatrix_sign_method ...
    1359             : !> \par History
    1360             : !>       2019.03 created [Robert Schade]
    1361             : !>       2019.06 impl. submatrix method [Michael Lass]
    1362             : !> \author Robert Schade, Michael Lass
    1363             : ! **************************************************************************************************
    1364           6 :    SUBROUTINE matrix_sign_submatrix(matrix_sign, matrix, threshold, sign_order, submatrix_sign_method)
    1365             : 
    1366             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1367             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1368             :       INTEGER, INTENT(IN), OPTIONAL                      :: sign_order
    1369             :       INTEGER, INTENT(IN)                                :: submatrix_sign_method
    1370             : 
    1371             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix'
    1372             : 
    1373             :       INTEGER                                            :: handle, i, myrank, nblkcols, order, &
    1374             :                                                             sm_size, unit_nr
    1375           6 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1376           6 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign
    1377             :       TYPE(cp_logger_type), POINTER                      :: logger
    1378             :       TYPE(dbcsr_distribution_type)                      :: dist
    1379           6 :       TYPE(submatrix_dissection_type)                    :: dissection
    1380             : 
    1381           6 :       CALL timeset(routineN, handle)
    1382             : 
    1383             :       ! print output on all ranks
    1384           6 :       logger => cp_get_default_logger()
    1385           6 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1386             : 
    1387           6 :       IF (PRESENT(sign_order)) THEN
    1388           6 :          order = sign_order
    1389             :       ELSE
    1390           0 :          order = 2
    1391             :       END IF
    1392             : 
    1393           6 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist)
    1394           6 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1395             : 
    1396           6 :       CALL dissection%init(matrix)
    1397           6 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1398             : 
    1399             :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1400             :       !$OMP          PRIVATE(sm, sm_sign, sm_size) &
    1401           6 :       !$OMP          SHARED(dissection, myrank, my_sms, order, submatrix_sign_method, threshold, unit_nr)
    1402             :       !$OMP DO SCHEDULE(GUIDED)
    1403             :       DO i = 1, SIZE(my_sms)
    1404             :          WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "processing submatrix", my_sms(i)
    1405             :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1406             :          sm_size = SIZE(sm, 1)
    1407             :          ALLOCATE (sm_sign(sm_size, sm_size))
    1408             :          SELECT CASE (submatrix_sign_method)
    1409             :          CASE (ls_scf_submatrix_sign_ns)
    1410             :             CALL dense_matrix_sign_Newton_Schulz(sm_sign, sm, my_sms(i), threshold, order)
    1411             :          CASE (ls_scf_submatrix_sign_direct, ls_scf_submatrix_sign_direct_muadj, ls_scf_submatrix_sign_direct_muadj_lowmem)
    1412             :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1413             :          CASE DEFAULT
    1414             :             CPABORT("Unkown submatrix sign method.")
    1415             :          END SELECT
    1416             :          CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1417             :          DEALLOCATE (sm, sm_sign)
    1418             :       END DO
    1419             :       !$OMP END DO
    1420             :       !$OMP END PARALLEL
    1421             : 
    1422           6 :       CALL dissection%communicate_results(matrix_sign)
    1423           6 :       CALL dissection%final
    1424             : 
    1425           6 :       CALL timestop(handle)
    1426             : 
    1427          12 :    END SUBROUTINE matrix_sign_submatrix
    1428             : 
    1429             : ! **************************************************************************************************
    1430             : !> \brief Submatrix method with internal adjustment of chemical potential
    1431             : !> \param matrix_sign ...
    1432             : !> \param matrix ...
    1433             : !> \param mu ...
    1434             : !> \param nelectron ...
    1435             : !> \param threshold ...
    1436             : !> \param variant ...
    1437             : !> \par History
    1438             : !>       2020.05 Created [Michael Lass]
    1439             : !> \author Robert Schade, Michael Lass
    1440             : ! **************************************************************************************************
    1441           4 :    SUBROUTINE matrix_sign_submatrix_mu_adjust(matrix_sign, matrix, mu, nelectron, threshold, variant)
    1442             : 
    1443             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sign, matrix
    1444             :       REAL(KIND=dp), INTENT(INOUT)                       :: mu
    1445             :       INTEGER, INTENT(IN)                                :: nelectron
    1446             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1447             :       INTEGER, INTENT(IN)                                :: variant
    1448             : 
    1449             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sign_submatrix_mu_adjust'
    1450             :       REAL(KIND=dp), PARAMETER                           :: initial_increment = 0.01_dp
    1451             : 
    1452             :       INTEGER                                            :: handle, i, j, myrank, nblkcols, &
    1453             :                                                             sm_firstcol, sm_lastcol, sm_size, &
    1454             :                                                             unit_nr
    1455           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: my_sms
    1456             :       LOGICAL                                            :: has_mu_high, has_mu_low
    1457             :       REAL(KIND=dp)                                      :: increment, mu_high, mu_low, new_mu, trace
    1458           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: sm, sm_sign, tmp
    1459             :       TYPE(cp_logger_type), POINTER                      :: logger
    1460             :       TYPE(dbcsr_distribution_type)                      :: dist
    1461           4 :       TYPE(eigbuf), ALLOCATABLE, DIMENSION(:)            :: eigbufs
    1462             :       TYPE(mp_comm_type)                                 :: group
    1463           4 :       TYPE(submatrix_dissection_type)                    :: dissection
    1464             : 
    1465           4 :       CALL timeset(routineN, handle)
    1466             : 
    1467             :       ! print output on all ranks
    1468           4 :       logger => cp_get_default_logger()
    1469           4 :       unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1470             : 
    1471           4 :       CALL dbcsr_get_info(matrix=matrix, nblkcols_total=nblkcols, distribution=dist, group=group)
    1472           4 :       CALL dbcsr_distribution_get(dist=dist, mynode=myrank)
    1473             : 
    1474           4 :       CALL dissection%init(matrix)
    1475           4 :       CALL dissection%get_sm_ids_for_rank(myrank, my_sms)
    1476             : 
    1477          12 :       ALLOCATE (eigbufs(SIZE(my_sms)))
    1478             : 
    1479             :       trace = 0.0_dp
    1480             : 
    1481             :       !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1482             :       !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j, tmp) &
    1483             :       !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, threshold, variant) &
    1484           4 :       !$OMP          REDUCTION(+:trace)
    1485             :       !$OMP DO SCHEDULE(GUIDED)
    1486             :       DO i = 1, SIZE(my_sms)
    1487             :          CALL dissection%generate_submatrix(my_sms(i), sm)
    1488             :          sm_size = SIZE(sm, 1)
    1489             :          WRITE (unit_nr, *) "Rank", myrank, "processing submatrix", my_sms(i), "size", sm_size
    1490             : 
    1491             :          CALL dissection%get_relevant_sm_columns(my_sms(i), sm_firstcol, sm_lastcol)
    1492             : 
    1493             :          IF (variant .EQ. ls_scf_submatrix_sign_direct_muadj) THEN
    1494             :             ! Store all eigenvectors in buffer. We will use it to compute sm_sign at the end.
    1495             :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=eigbufs(i)%eigvecs)
    1496             :          ELSE
    1497             :             ! Only store eigenvectors that are required for mu adjustment.
    1498             :             ! Calculate sm_sign right away in the hope that mu is already correct.
    1499             :             CALL eigdecomp(sm, sm_size, eigvals=eigbufs(i)%eigvals, eigvecs=tmp)
    1500             :             ALLOCATE (eigbufs(i)%eigvecs(sm_firstcol:sm_lastcol, 1:sm_size))
    1501             :             eigbufs(i)%eigvecs(:, :) = tmp(sm_firstcol:sm_lastcol, 1:sm_size)
    1502             : 
    1503             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1504             :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, tmp, sm_size, 0.0_dp)
    1505             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1506             :             DEALLOCATE (sm_sign, tmp)
    1507             :          END IF
    1508             : 
    1509             :          DEALLOCATE (sm)
    1510             :          trace = trace + trace_from_eigdecomp(eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_firstcol, sm_lastcol, 0.0_dp)
    1511             :       END DO
    1512             :       !$OMP END DO
    1513             :       !$OMP END PARALLEL
    1514             : 
    1515           4 :       has_mu_low = .FALSE.
    1516           4 :       has_mu_high = .FALSE.
    1517           4 :       increment = initial_increment
    1518           4 :       new_mu = mu
    1519          72 :       DO i = 1, 30
    1520          72 :          CALL group%sum(trace)
    1521          72 :          IF (unit_nr > 0) WRITE (unit_nr, '(T2,A,1X,F13.9,1X,F15.9)') &
    1522          72 :             "Density matrix:  mu, trace error: ", new_mu, trace - nelectron
    1523          72 :          IF (ABS(trace - nelectron) < 0.5_dp) EXIT
    1524          68 :          IF (trace < nelectron) THEN
    1525           8 :             mu_low = new_mu
    1526           8 :             new_mu = new_mu + increment
    1527           8 :             has_mu_low = .TRUE.
    1528           8 :             increment = increment*2
    1529             :          ELSE
    1530          60 :             mu_high = new_mu
    1531          60 :             new_mu = new_mu - increment
    1532          60 :             has_mu_high = .TRUE.
    1533          60 :             increment = increment*2
    1534             :          END IF
    1535             : 
    1536          68 :          IF (has_mu_low .AND. has_mu_high) THEN
    1537          20 :             new_mu = (mu_low + mu_high)/2
    1538          20 :             IF (ABS(mu_high - mu_low) < threshold) EXIT
    1539             :          END IF
    1540             : 
    1541             :          trace = 0
    1542             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1543             :          !$OMP          PRIVATE(i, sm_sign, tmp, sm_size, sm_firstcol, sm_lastcol) &
    1544             :          !$OMP          SHARED(dissection, my_sms, unit_nr, eigbufs, mu, new_mu, nelectron) &
    1545          72 :          !$OMP          REDUCTION(+:trace)
    1546             :          !$OMP DO SCHEDULE(GUIDED)
    1547             :          DO j = 1, SIZE(my_sms)
    1548             :             sm_size = SIZE(eigbufs(j)%eigvals)
    1549             :             CALL dissection%get_relevant_sm_columns(my_sms(j), sm_firstcol, sm_lastcol)
    1550             :             trace = trace + trace_from_eigdecomp(eigbufs(j)%eigvals, eigbufs(j)%eigvecs, sm_firstcol, sm_lastcol, new_mu - mu)
    1551             :          END DO
    1552             :          !$OMP END DO
    1553             :          !$OMP END PARALLEL
    1554             :       END DO
    1555             : 
    1556             :       ! Finalize sign matrix from eigendecompositions if we kept all eigenvectors
    1557           4 :       IF (variant .EQ. ls_scf_submatrix_sign_direct_muadj) THEN
    1558             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1559             :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1560           2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1561             :          !$OMP DO SCHEDULE(GUIDED)
    1562             :          DO i = 1, SIZE(my_sms)
    1563             :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "finalizing submatrix", my_sms(i)
    1564             :             sm_size = SIZE(eigbufs(i)%eigvals)
    1565             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1566             :             CALL sign_from_eigdecomp(sm_sign, eigbufs(i)%eigvals, eigbufs(i)%eigvecs, sm_size, new_mu - mu)
    1567             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1568             :             DEALLOCATE (sm_sign)
    1569             :          END DO
    1570             :          !$OMP END DO
    1571             :          !$OMP END PARALLEL
    1572             :       END IF
    1573             : 
    1574           6 :       DEALLOCATE (eigbufs)
    1575             : 
    1576             :       ! If we only stored parts of the eigenvectors and mu has changed, we need to recompute sm_sign
    1577           4 :       IF ((variant .EQ. ls_scf_submatrix_sign_direct_muadj_lowmem) .AND. (mu .NE. new_mu)) THEN
    1578             :          !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
    1579             :          !$OMP          PRIVATE(sm, sm_sign, sm_size, sm_firstcol, sm_lastcol, j) &
    1580           2 :          !$OMP          SHARED(dissection, myrank, my_sms, unit_nr, eigbufs, mu, new_mu)
    1581             :          !$OMP DO SCHEDULE(GUIDED)
    1582             :          DO i = 1, SIZE(my_sms)
    1583             :             WRITE (unit_nr, '(T3,A,1X,I4,1X,A,1X,I6)') "Rank", myrank, "reprocessing submatrix", my_sms(i)
    1584             :             CALL dissection%generate_submatrix(my_sms(i), sm)
    1585             :             sm_size = SIZE(sm, 1)
    1586             :             DO j = 1, sm_size
    1587             :                sm(j, j) = sm(j, j) + mu - new_mu
    1588             :             END DO
    1589             :             ALLOCATE (sm_sign(sm_size, sm_size))
    1590             :             CALL dense_matrix_sign_direct(sm_sign, sm, sm_size)
    1591             :             CALL dissection%copy_resultcol(my_sms(i), sm_sign)
    1592             :             DEALLOCATE (sm, sm_sign)
    1593             :          END DO
    1594             :          !$OMP END DO
    1595             :          !$OMP END PARALLEL
    1596             :       END IF
    1597             : 
    1598           4 :       mu = new_mu
    1599             : 
    1600           4 :       CALL dissection%communicate_results(matrix_sign)
    1601           4 :       CALL dissection%final
    1602             : 
    1603           4 :       CALL timestop(handle)
    1604             : 
    1605          12 :    END SUBROUTINE matrix_sign_submatrix_mu_adjust
    1606             : 
    ! **************************************************************************************************
!> \brief compute the sqrt of a matrix via the sign function and the corresponding Newton-Schulz iterations
!>        the order of the algorithm should be 2..5, 3 or 5 is recommended
!>        Coupled iteration: Y(0) = s*matrix, Z(0) = I; with X(k) = Z(k)*Y(k) - I the update is
!>        Y(k+1) = Y(k)*p(X(k)) and Z(k+1) = p(X(k))*Z(k), where p is the truncated series of
!>        (I+X)**(-1/2) of the requested order. On exit Y is rescaled by 1/SQRT(s) and Z by SQRT(s).
!> \param matrix_sqrt on exit, the approximation of the square root of matrix
!> \param matrix_sqrt_inv on exit, the approximation of the inverse square root of matrix
!> \param matrix the input matrix
!> \param threshold filtering threshold for all products; the iteration also stops once conv**2 < threshold,
!>        with conv = ||Z*Y - I||_F / ||Z*Y||_F
!> \param order order of the polynomial update (2..5). 0 selects a Gershgorin/Frobenius-norm based scaling
!>        instead of the Arnoldi eigenvalue estimate, combined with the second-order update
!> \param eps_lanczos convergence threshold of the Arnoldi extremal eigenvalue estimate (unused for order 0)
!> \param max_iter_lanczos maximum number of Arnoldi iterations (unused for order 0)
!> \param symmetrize if .TRUE. (default), both results are replaced by (M + M^T)/2 after the iteration
!> \param converged set to .TRUE. if the convergence criterion was met within max_iter iterations
!> \param iounit output unit; if absent, the default logger unit is used on the source rank (-1 elsewhere)
!> \par History
!>       2010.10 created [Joost VandeVondele]
!> \author Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE matrix_sqrt_Newton_Schulz(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
                                        eps_lanczos, max_iter_lanczos, symmetrize, converged, iounit)
      TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
      REAL(KIND=dp), INTENT(IN)                          :: threshold
      INTEGER, INTENT(IN)                                :: order
      REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
      INTEGER, INTENT(IN)                                :: max_iter_lanczos
      LOGICAL, INTENT(IN), OPTIONAL                      :: symmetrize
      LOGICAL, INTENT(OUT), OPTIONAL                     :: converged
      INTEGER, INTENT(IN), OPTIONAL                      :: iounit

      CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_sqrt_Newton_Schulz'
      ! upper bound on the number of Newton-Schulz iterations
      INTEGER, PARAMETER                                 :: max_iter = 100

      INTEGER                                            :: handle, i, unit_nr
      INTEGER(KIND=int_8)                                :: flop1, flop2, flop3, flop4, flop5
      LOGICAL                                            :: arnoldi_converged, tsym
      REAL(KIND=dp)                                      :: a, b, c, conv, d, frob_matrix, &
                                                            frob_matrix_base, gershgorin_norm, &
                                                            max_ev, min_ev, oa, ob, oc, &
                                                            occ_matrix, od, scaling, t1, t2
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_type)                                   :: tmp1, tmp2, tmp3

      CALL timeset(routineN, handle)

      ! reject unsupported orders before any work (matrix creation, Arnoldi) is done
      IF (order /= 0 .AND. (order < 2 .OR. order > 5)) &
         CPABORT("Illegal order value")

      IF (PRESENT(iounit)) THEN
         unit_nr = iounit
      ELSE
         logger => cp_get_default_logger()
         IF (logger%para_env%is_source()) THEN
            unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
         ELSE
            unit_nr = -1
         END IF
      END IF

      IF (PRESENT(converged)) converged = .FALSE.
      IF (PRESENT(symmetrize)) THEN
         tsym = symmetrize
      ELSE
         tsym = .TRUE.
      END IF

      ! for stability symmetry can not be assumed
      CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
      CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
      ! tmp3 is only needed by the higher-order (4, 5) polynomial updates
      IF (order .GE. 4) THEN
         CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
      END IF

      ! Z(0) = I, Y(0) = matrix (scaled below)
      CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
      CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
      CALL dbcsr_filter(matrix_sqrt_inv, threshold)
      CALL dbcsr_copy(matrix_sqrt, matrix)

      ! scale the matrix to get into the convergence range
      IF (order == 0) THEN

         gershgorin_norm = dbcsr_gershgorin_norm(matrix_sqrt)
         frob_matrix = dbcsr_frobenius_norm(matrix_sqrt)
         scaling = 1.0_dp/MIN(frob_matrix, gershgorin_norm)

      ELSE

         ! scale the matrix to get into the convergence range
         CALL arnoldi_extremal(matrix_sqrt, max_ev, min_ev, threshold=eps_lanczos, &
                               max_iter=max_iter_lanczos, converged=arnoldi_converged)
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, *)
            WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
            WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
            WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
         END IF
         ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
         ! and adjust the scaling to be on the safe side
         scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)

      END IF

      CALL dbcsr_scale(matrix_sqrt, scaling)
      CALL dbcsr_filter(matrix_sqrt, threshold)
      IF (unit_nr > 0) THEN
         WRITE (unit_nr, *)
         WRITE (unit_nr, *) "Order=", order
      END IF

      DO i = 1, max_iter

         t1 = m_walltime()

         ! tmp1 = Zk * Yk - I
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
                             filter_eps=threshold, flop=flop1)
         frob_matrix_base = dbcsr_frobenius_norm(tmp1)
         CALL dbcsr_add_on_diag(tmp1, -1.0_dp)

         ! check convergence (frob norm of what should be the identity matrix minus identity matrix)
         frob_matrix = dbcsr_frobenius_norm(tmp1)

         flop4 = 0; flop5 = 0
         SELECT CASE (order)
         CASE (0, 2)
            ! update the above to 0.5*(3*I-Zk*Yk), i.e. tmp1 = I - tmp1/2
            CALL dbcsr_add_on_diag(tmp1, -2.0_dp)
            CALL dbcsr_scale(tmp1, -0.5_dp)
         CASE (3)
            ! tmp2 = tmp1 ** 2
            CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
                                filter_eps=threshold, flop=flop4)
            ! tmp1 = 1/8 * (8*I-4*tmp1+3*tmp1**2)
            CALL dbcsr_add(tmp1, tmp2, -4.0_dp, 3.0_dp)
            CALL dbcsr_add_on_diag(tmp1, 8.0_dp)
            CALL dbcsr_scale(tmp1, 0.125_dp)
         CASE (4) ! as expensive as case(5), so little need to use it
            ! tmp2 = tmp1 ** 2
            CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
                                filter_eps=threshold, flop=flop4)
            ! tmp3 = tmp2 * tmp1
            CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp1, 0.0_dp, tmp3, &
                                filter_eps=threshold, flop=flop5)
            ! tmp1 = 1/16 * (16*I-8*tmp1+6*tmp1**2-5*tmp1**3)
            CALL dbcsr_scale(tmp1, -8.0_dp)
            CALL dbcsr_add_on_diag(tmp1, 16.0_dp)
            CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 6.0_dp)
            CALL dbcsr_add(tmp1, tmp3, 1.0_dp, -5.0_dp)
            CALL dbcsr_scale(tmp1, 1/16.0_dp)
         CASE (5)
            ! Knuth's reformulation to evaluate the polynomial of 4th degree in 2 multiplications
            ! p = y4+A*y3+B*y2+C*y+D
            ! z := y * (y+a); P := (z+y+b) * (z+c) + d.
            ! a=(A-1)/2 ; b=B*(a+1)-C-a*(a+1)*(a+1)
            ! c=B-b-a*(a+1)
            ! d=D-bc
            oa = -40.0_dp/35.0_dp
            ob = 48.0_dp/35.0_dp
            oc = -64.0_dp/35.0_dp
            od = 128.0_dp/35.0_dp
            a = (oa - 1)/2
            b = ob*(a + 1) - oc - a*(a + 1)**2
            c = ob - b - a*(a + 1)
            d = od - b*c
            ! tmp2 = tmp1 ** 2 + a * tmp1
            CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, tmp1, 0.0_dp, tmp2, &
                                filter_eps=threshold, flop=flop4)
            CALL dbcsr_add(tmp2, tmp1, 1.0_dp, a)
            ! tmp3 = tmp2 + tmp1 + b
            CALL dbcsr_copy(tmp3, tmp2)
            CALL dbcsr_add(tmp3, tmp1, 1.0_dp, 1.0_dp)
            CALL dbcsr_add_on_diag(tmp3, b)
            ! tmp2 = tmp2 + c
            CALL dbcsr_add_on_diag(tmp2, c)
            ! tmp1 = tmp2 * tmp3
            CALL dbcsr_multiply("N", "N", 1.0_dp, tmp2, tmp3, 0.0_dp, tmp1, &
                                filter_eps=threshold, flop=flop5)
            ! tmp1 = tmp1 + d
            CALL dbcsr_add_on_diag(tmp1, d)
            ! final scale
            CALL dbcsr_scale(tmp1, 35.0_dp/128.0_dp)
         CASE DEFAULT
            ! unreachable after the up-front check, kept as a safeguard
            CPABORT("Illegal order value")
         END SELECT

         ! tmp2 = Yk * tmp1 = Y(k+1)
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt, tmp1, 0.0_dp, tmp2, &
                             filter_eps=threshold, flop=flop2)
         CALL dbcsr_copy(matrix_sqrt, tmp2)

         ! tmp2 = tmp1 * Zk = Z(k+1)
         CALL dbcsr_multiply("N", "N", 1.0_dp, tmp1, matrix_sqrt_inv, 0.0_dp, tmp2, &
                             filter_eps=threshold, flop=flop3)
         CALL dbcsr_copy(matrix_sqrt_inv, tmp2)

         occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)

         ! done iterating
         t2 = m_walltime()

         ! relative deviation of Zk*Yk from the identity
         conv = frob_matrix/frob_matrix_base

         IF (unit_nr > 0) THEN
            WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "NS sqrt iter ", i, occ_matrix, &
               conv, t2 - t1, &
               (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
            CALL m_flush(unit_nr)
         END IF

         IF (abnormal_value(conv)) &
            CPABORT("conv is an abnormal value (NaN/Inf).")

         ! conv < SQRT(threshold)
         IF ((conv*conv) < threshold) THEN
            IF (PRESENT(converged)) converged = .TRUE.
            EXIT
         END IF

      END DO

      ! symmetrize the matrices as this is not guaranteed by the algorithm
      IF (tsym) THEN
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, '(T6,A20)') "Symmetrizing Results"
         END IF
         CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
         CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
         CALL dbcsr_transposed(tmp1, matrix_sqrt)
         CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
      END IF

      ! this check is not really needed
      CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
                          filter_eps=threshold)
      frob_matrix_base = dbcsr_frobenius_norm(tmp1)
      CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
      frob_matrix = dbcsr_frobenius_norm(tmp1)
      occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
      IF (unit_nr > 0) THEN
         ! i equals max_iter + 1 here if the loop ran to completion without converging
         WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final NS sqrt iter ", i, occ_matrix, &
            frob_matrix/frob_matrix_base
         WRITE (unit_nr, '()')
         CALL m_flush(unit_nr)
      END IF

      ! scale to proper end results (undo the initial scaling of the input matrix)
      CALL dbcsr_scale(matrix_sqrt, 1/SQRT(scaling))
      CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))

      CALL dbcsr_release(tmp1)
      CALL dbcsr_release(tmp2)
      IF (order .GE. 4) THEN
         CALL dbcsr_release(tmp3)
      END IF

      CALL timestop(handle)

   END SUBROUTINE matrix_sqrt_Newton_Schulz
    1859             : 
    1860             : ! **************************************************************************************************
    1861             : !> \brief compute the sqrt of a matrix via the general algorithm for the p-th root of Richters et al.
    1862             : !>                   Commun. Comput. Phys., 25 (2019), pp. 564-585.
    1863             : !> \param matrix_sqrt ...
    1864             : !> \param matrix_sqrt_inv ...
    1865             : !> \param matrix ...
    1866             : !> \param threshold ...
    1867             : !> \param order ...
    1868             : !> \param eps_lanczos ...
    1869             : !> \param max_iter_lanczos ...
    1870             : !> \param symmetrize ...
    1871             : !> \param converged ...
    1872             : !> \par History
    1873             : !>       2019.04 created [Robert Schade]
    1874             : !> \author Robert Schade
    1875             : ! **************************************************************************************************
    1876          48 :    SUBROUTINE matrix_sqrt_proot(matrix_sqrt, matrix_sqrt_inv, matrix, threshold, order, &
    1877             :                                 eps_lanczos, max_iter_lanczos, symmetrize, converged)
    1878             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_sqrt, matrix_sqrt_inv, matrix
    1879             :       REAL(KIND=dp), INTENT(IN)                          :: threshold
    1880             :       INTEGER, INTENT(IN)                                :: order
    1881             :       REAL(KIND=dp), INTENT(IN)                          :: eps_lanczos
    1882             :       INTEGER, INTENT(IN)                                :: max_iter_lanczos
    1883             :       LOGICAL, OPTIONAL                                  :: symmetrize, converged
    1884             : 
    1885             :       CHARACTER(LEN=*), PARAMETER                        :: routineN = 'matrix_sqrt_proot'
    1886             : 
    1887             :       INTEGER                                            :: choose, handle, i, ii, j, unit_nr
    1888             :       INTEGER(KIND=int_8)                                :: f, flop1, flop2, flop3, flop4, flop5
    1889             :       LOGICAL                                            :: arnoldi_converged, test, tsym
    1890             :       REAL(KIND=dp)                                      :: conv, frob_matrix, frob_matrix_base, &
    1891             :                                                             max_ev, min_ev, occ_matrix, scaling, &
    1892             :                                                             t1, t2
    1893             :       TYPE(cp_logger_type), POINTER                      :: logger
    1894             :       TYPE(dbcsr_type)                                   :: BK2A, matrixS, Rmat, tmp1, tmp2, tmp3
    1895             : 
    1896          16 :       CALL cite_reference(Richters2018)
    1897             : 
    1898          16 :       test = .FALSE.
    1899             : 
    1900          16 :       CALL timeset(routineN, handle)
    1901             : 
    1902          16 :       logger => cp_get_default_logger()
    1903          16 :       IF (logger%para_env%is_source()) THEN
    1904           8 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    1905             :       ELSE
    1906           8 :          unit_nr = -1
    1907             :       END IF
    1908             : 
    1909          16 :       IF (PRESENT(converged)) converged = .FALSE.
    1910          16 :       IF (PRESENT(symmetrize)) THEN
    1911          16 :          tsym = symmetrize
    1912             :       ELSE
    1913             :          tsym = .TRUE.
    1914             :       END IF
    1915             : 
    1916             :       ! for stability symmetry can not be assumed
    1917          16 :       CALL dbcsr_create(tmp1, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1918          16 :       CALL dbcsr_create(tmp2, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1919          16 :       CALL dbcsr_create(tmp3, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1920          16 :       CALL dbcsr_create(Rmat, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1921          16 :       CALL dbcsr_create(matrixS, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1922             : 
    1923          16 :       CALL dbcsr_copy(matrixS, matrix)
    1924             :       IF (1 .EQ. 1) THEN
    1925             :          ! scale the matrix to get into the convergence range
    1926             :          CALL arnoldi_extremal(matrixS, max_ev, min_ev, threshold=eps_lanczos, &
    1927          16 :                                max_iter=max_iter_lanczos, converged=arnoldi_converged)
    1928          16 :          IF (unit_nr > 0) THEN
    1929           8 :             WRITE (unit_nr, *)
    1930           8 :             WRITE (unit_nr, '(T6,A,1X,L1,A,E12.3)') "Lanczos converged: ", arnoldi_converged, " threshold:", eps_lanczos
    1931           8 :             WRITE (unit_nr, '(T6,A,1X,E12.3,E12.3)') "Est. extremal eigenvalues:", max_ev, min_ev
    1932           8 :             WRITE (unit_nr, '(T6,A,1X,E12.3)') "Est. condition number :", max_ev/MAX(min_ev, EPSILON(min_ev))
    1933             :          END IF
    1934             :          ! conservatively assume we get a relatively large error (100*threshold_lanczos) in the estimates
    1935             :          ! and adjust the scaling to be on the safe side
    1936          16 :          scaling = 2.0_dp/(max_ev + min_ev + 100*eps_lanczos)
    1937          16 :          CALL dbcsr_scale(matrixS, scaling)
    1938          16 :          CALL dbcsr_filter(matrixS, threshold)
    1939             :       ELSE
    1940             :          scaling = 1.0_dp
    1941             :       END IF
    1942             : 
    1943          16 :       CALL dbcsr_set(matrix_sqrt_inv, 0.0_dp)
    1944          16 :       CALL dbcsr_add_on_diag(matrix_sqrt_inv, 1.0_dp)
    1945             :       !CALL dbcsr_filter(matrix_sqrt_inv, threshold)
    1946             : 
    1947          16 :       IF (unit_nr > 0) THEN
    1948           8 :          WRITE (unit_nr, *)
    1949           8 :          WRITE (unit_nr, *) "Order=", order
    1950             :       END IF
    1951             : 
    1952          86 :       DO i = 1, 100
    1953             : 
    1954          86 :          t1 = m_walltime()
    1955             :          IF (1 .EQ. 1) THEN
    1956             :             !build R=1-A B_K^2
    1957             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp1, &
    1958          86 :                                 filter_eps=threshold, flop=flop1)
    1959             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrixS, tmp1, 0.0_dp, Rmat, &
    1960          86 :                                 filter_eps=threshold, flop=flop2)
    1961          86 :             CALL dbcsr_scale(Rmat, -1.0_dp)
    1962          86 :             CALL dbcsr_add_on_diag(Rmat, 1.0_dp)
    1963             : 
    1964          86 :             flop4 = 0; flop5 = 0
    1965          86 :             CALL dbcsr_set(tmp1, 0.0_dp)
    1966          86 :             CALL dbcsr_add_on_diag(tmp1, 2.0_dp)
    1967             : 
    1968          86 :             flop3 = 0
    1969             : 
    1970         274 :             DO j = 2, order
    1971         188 :                IF (j .EQ. 2) THEN
    1972          86 :                   CALL dbcsr_copy(tmp2, Rmat)
    1973             :                ELSE
    1974             :                   f = 0
    1975         102 :                   CALL dbcsr_copy(tmp3, tmp2)
    1976             :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, Rmat, 0.0_dp, tmp2, &
    1977         102 :                                       filter_eps=threshold, flop=f)
    1978         102 :                   flop3 = flop3 + f
    1979             :                END IF
    1980         274 :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, 1.0_dp)
    1981             :             END DO
    1982             :          ELSE
    1983             :             CALL dbcsr_create(BK2A, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    1984             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrixS, 0.0_dp, tmp3, &
    1985             :                                 filter_eps=threshold, flop=flop1)
    1986             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, tmp3, 0.0_dp, BK2A, &
    1987             :                                 filter_eps=threshold, flop=flop2)
    1988             :             CALL dbcsr_copy(Rmat, BK2A)
    1989             :             CALL dbcsr_add_on_diag(Rmat, -1.0_dp)
    1990             : 
    1991             :             CALL dbcsr_set(tmp1, 0.0_dp)
    1992             :             CALL dbcsr_add_on_diag(tmp1, 1.0_dp)
    1993             : 
    1994             :             CALL dbcsr_set(tmp2, 0.0_dp)
    1995             :             CALL dbcsr_add_on_diag(tmp2, 1.0_dp)
    1996             : 
    1997             :             flop3 = 0
    1998             :             DO j = 1, order
    1999             :                !choose=factorial(order)/(factorial(j)*factorial(order-j))
    2000             :                choose = PRODUCT((/(ii, ii=1, order)/))/(PRODUCT((/(ii, ii=1, j)/))*PRODUCT((/(ii, ii=1, order - j)/)))
    2001             :                CALL dbcsr_add(tmp1, tmp2, 1.0_dp, -1.0_dp*(-1)**j*choose)
    2002             :                IF (j .LT. order) THEN
    2003             :                   f = 0
    2004             :                   CALL dbcsr_copy(tmp3, tmp2)
    2005             :                   CALL dbcsr_multiply("N", "N", 1.0_dp, tmp3, BK2A, 0.0_dp, tmp2, &
    2006             :                                       filter_eps=threshold, flop=f)
    2007             :                   flop3 = flop3 + f
    2008             :                END IF
    2009             :             END DO
    2010             :             CALL dbcsr_release(BK2A)
    2011             :          END IF
    2012             : 
    2013          86 :          CALL dbcsr_copy(tmp3, matrix_sqrt_inv)
    2014             :          CALL dbcsr_multiply("N", "N", 0.5_dp, tmp3, tmp1, 0.0_dp, matrix_sqrt_inv, &
    2015          86 :                              filter_eps=threshold, flop=flop4)
    2016             : 
    2017          86 :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2018             : 
    2019             :          ! done iterating
    2020          86 :          t2 = m_walltime()
    2021             : 
    2022          86 :          conv = dbcsr_frobenius_norm(Rmat)
    2023             : 
    2024          86 :          IF (unit_nr > 0) THEN
    2025          43 :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3,F12.3,F13.3)') "PROOT sqrt iter ", i, occ_matrix, &
    2026          43 :                conv, t2 - t1, &
    2027          86 :                (flop1 + flop2 + flop3 + flop4 + flop5)/(1.0E6_dp*MAX(0.001_dp, t2 - t1))
    2028          43 :             CALL m_flush(unit_nr)
    2029             :          END IF
    2030             : 
    2031          86 :          IF (abnormal_value(conv)) &
    2032           0 :             CPABORT("conv is an abnormal value (NaN/Inf).")
    2033             : 
    2034             :          ! conv < SQRT(threshold)
    2035          86 :          IF ((conv*conv) < threshold) THEN
    2036          16 :             IF (PRESENT(converged)) converged = .TRUE.
    2037             :             EXIT
    2038             :          END IF
    2039             : 
    2040             :       END DO
    2041             : 
    2042             :       ! scale to proper end results
    2043          16 :       CALL dbcsr_scale(matrix_sqrt_inv, SQRT(scaling))
    2044             :       CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_sqrt_inv, matrix, 0.0_dp, matrix_sqrt, &
    2045          16 :                           filter_eps=threshold, flop=flop5)
    2046             : 
    2047             :       ! symmetrize the matrices as this is not guaranteed by the algorithm
    2048          16 :       IF (tsym) THEN
    2049           8 :          IF (unit_nr > 0) THEN
    2050           4 :             WRITE (unit_nr, '(A20)') "SYMMETRIZING RESULTS"
    2051             :          END IF
    2052           8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt_inv)
    2053           8 :          CALL dbcsr_add(matrix_sqrt_inv, tmp1, 0.5_dp, 0.5_dp)
    2054           8 :          CALL dbcsr_transposed(tmp1, matrix_sqrt)
    2055           8 :          CALL dbcsr_add(matrix_sqrt, tmp1, 0.5_dp, 0.5_dp)
    2056             :       END IF
    2057             : 
    2058             :       ! this check is not really needed
    2059             :       IF (test) THEN
    2060             :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt, 0.0_dp, tmp1, &
    2061             :                              filter_eps=threshold)
    2062             :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2063             :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2064             :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2065             :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2066             :          IF (unit_nr > 0) THEN
    2067             :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{1/2}-Eins error ", i, occ_matrix, &
    2068             :                frob_matrix/frob_matrix_base
    2069             :             WRITE (unit_nr, '()')
    2070             :             CALL m_flush(unit_nr)
    2071             :          END IF
    2072             : 
    2073             :          ! this check is not really needed
    2074             :          CALL dbcsr_multiply("N", "N", +1.0_dp, matrix_sqrt_inv, matrix_sqrt_inv, 0.0_dp, tmp2, &
    2075             :                              filter_eps=threshold)
    2076             :          CALL dbcsr_multiply("N", "N", +1.0_dp, tmp2, matrix, 0.0_dp, tmp1, &
    2077             :                              filter_eps=threshold)
    2078             :          frob_matrix_base = dbcsr_frobenius_norm(tmp1)
    2079             :          CALL dbcsr_add_on_diag(tmp1, -1.0_dp)
    2080             :          frob_matrix = dbcsr_frobenius_norm(tmp1)
    2081             :          occ_matrix = dbcsr_get_occupation(matrix_sqrt_inv)
    2082             :          IF (unit_nr > 0) THEN
    2083             :             WRITE (unit_nr, '(T6,A,1X,I3,1X,F10.8,E12.3)') "Final PROOT S^{-1/2} S^{-1/2} S-Eins error ", i, occ_matrix, &
    2084             :                frob_matrix/frob_matrix_base
    2085             :             WRITE (unit_nr, '()')
    2086             :             CALL m_flush(unit_nr)
    2087             :          END IF
    2088             :       END IF
    2089             : 
    2090          16 :       CALL dbcsr_release(tmp1)
    2091          16 :       CALL dbcsr_release(tmp2)
    2092          16 :       CALL dbcsr_release(tmp3)
    2093          16 :       CALL dbcsr_release(Rmat)
    2094          16 :       CALL dbcsr_release(matrixS)
    2095             : 
    2096          16 :       CALL timestop(handle)
    2097          16 :    END SUBROUTINE matrix_sqrt_proot
    2098             : 
    2099             : ! **************************************************************************************************
    2100             : !> \brief ...
    2101             : !> \param matrix_exp ...
    2102             : !> \param matrix ...
    2103             : !> \param omega ...
    2104             : !> \param alpha ...
    2105             : !> \param threshold ...
    2106             : ! **************************************************************************************************
    2107        1146 :    SUBROUTINE matrix_exponential(matrix_exp, matrix, omega, alpha, threshold)
    2108             :       ! compute matrix_exp=omega*exp(alpha*matrix)
    2109             :       TYPE(dbcsr_type), INTENT(INOUT)                    :: matrix_exp, matrix
    2110             :       REAL(KIND=dp), INTENT(IN)                          :: omega, alpha, threshold
    2111             : 
    2112             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'matrix_exponential'
    2113             :       REAL(dp), PARAMETER                                :: one = 1.0_dp, toll = 1.E-17_dp, &
    2114             :                                                             zero = 0.0_dp
    2115             : 
    2116             :       INTEGER                                            :: handle, i, k, unit_nr
    2117             :       REAL(dp)                                           :: factorial, norm_C, norm_D, norm_scalar
    2118             :       TYPE(cp_logger_type), POINTER                      :: logger
    2119             :       TYPE(dbcsr_type)                                   :: B, B_square, C, D, D_product
    2120             : 
    2121        1146 :       CALL timeset(routineN, handle)
    2122             : 
    2123        1146 :       logger => cp_get_default_logger()
    2124        1146 :       IF (logger%para_env%is_source()) THEN
    2125        1058 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2126             :       ELSE
    2127             :          unit_nr = -1
    2128             :       END IF
    2129             : 
    2130             :       ! Calculate the norm of the matrix alpha*matrix, and scale it until it is less than 1.0
    2131        1146 :       norm_scalar = ABS(alpha)*dbcsr_frobenius_norm(matrix)
    2132             : 
    2133             :       ! k=scaling parameter
    2134        1146 :       k = 1
    2135        1008 :       DO
    2136        2154 :          IF ((norm_scalar/2.0_dp**k) <= one) EXIT
    2137        1008 :          k = k + 1
    2138             :       END DO
    2139             : 
    2140             :       ! copy and scale the input matrix in matrix C and in matrix D
    2141        1146 :       CALL dbcsr_create(C, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2142        1146 :       CALL dbcsr_copy(C, matrix)
    2143        1146 :       CALL dbcsr_scale(C, alpha_scalar=alpha/2.0_dp**k)
    2144             : 
    2145        1146 :       CALL dbcsr_create(D, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2146        1146 :       CALL dbcsr_copy(D, C)
    2147             : 
    2148             :       !   write(*,*)
    2149             :       !   write(*,*)
    2150             :       !   CALL dbcsr_print(D, variable_name="D")
    2151             : 
    2152             :       ! set the B matrix as B=Identity+D
    2153        1146 :       CALL dbcsr_create(B, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2154        1146 :       CALL dbcsr_copy(B, D)
    2155        1146 :       CALL dbcsr_add_on_diag(B, alpha=one)
    2156             : 
    2157             :       !   CALL dbcsr_print(B, variable_name="B")
    2158             : 
    2159             :       ! Calculate the norm of C and moltiply by toll to be used as a threshold
    2160        1146 :       norm_C = toll*dbcsr_frobenius_norm(matrix)
    2161             : 
    2162             :       ! iteration for the truncated taylor series expansion
    2163        1146 :       CALL dbcsr_create(D_product, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2164        1146 :       i = 1
    2165             :       DO
    2166       12676 :          i = i + 1
    2167             :          ! compute D_product=D*C
    2168             :          CALL dbcsr_multiply("N", "N", one, D, C, &
    2169       12676 :                              zero, D_product, filter_eps=threshold)
    2170             : 
    2171             :          ! copy D_product in D
    2172       12676 :          CALL dbcsr_copy(D, D_product)
    2173             : 
    2174             :          ! calculate B=B+D_product/fat(i)
    2175       12676 :          factorial = ifac(i)
    2176       12676 :          CALL dbcsr_add(B, D_product, one, factorial)
    2177             : 
    2178             :          ! check for convergence using the norm of D (copy of the matrix D_product) and C
    2179       12676 :          norm_D = factorial*dbcsr_frobenius_norm(D)
    2180       12676 :          IF (norm_D < norm_C) EXIT
    2181             :       END DO
    2182             : 
    2183             :       ! start the k iteration for the squaring of the matrix
    2184        1146 :       CALL dbcsr_create(B_square, template=matrix, matrix_type=dbcsr_type_no_symmetry)
    2185        3300 :       DO i = 1, k
    2186             :          !compute B_square=B*B
    2187             :          CALL dbcsr_multiply("N", "N", one, B, B, &
    2188        2154 :                              zero, B_square, filter_eps=threshold)
    2189             :          ! copy Bsquare in B to iterate
    2190        3300 :          CALL dbcsr_copy(B, B_square)
    2191             :       END DO
    2192             : 
    2193             :       ! copy B_square in matrix_exp and
    2194        1146 :       CALL dbcsr_copy(matrix_exp, B_square)
    2195             : 
    2196             :       ! scale matrix_exp by omega, matrix_exp=omega*B_square
    2197        1146 :       CALL dbcsr_scale(matrix_exp, alpha_scalar=omega)
    2198             :       ! write(*,*) alpha,omega
    2199             : 
    2200        1146 :       CALL dbcsr_release(B)
    2201        1146 :       CALL dbcsr_release(C)
    2202        1146 :       CALL dbcsr_release(D)
    2203        1146 :       CALL dbcsr_release(D_product)
    2204        1146 :       CALL dbcsr_release(B_square)
    2205             : 
    2206        1146 :       CALL timestop(handle)
    2207             : 
    2208        1146 :    END SUBROUTINE matrix_exponential
    2209             : 
    2210             : ! **************************************************************************************************
    2211             : !> \brief McWeeny purification of a matrix in the orthonormal basis
    2212             : !> \param matrix_p Matrix to purify (needs to be almost idempotent already)
    2213             : !> \param threshold Threshold used as filter_eps and convergence criteria
    2214             : !> \param max_steps Max number of iterations
    2215             : !> \par History
    2216             : !>       2013.01 created [Florian Schiffmann]
    2217             : !>       2014.07 slightly refactored [Ole Schuett]
    2218             : !> \author Florian Schiffmann
    2219             : ! **************************************************************************************************
    2220         234 :    SUBROUTINE purify_mcweeny_orth(matrix_p, threshold, max_steps)
    2221             :       TYPE(dbcsr_type), DIMENSION(:)                     :: matrix_p
    2222             :       REAL(KIND=dp)                                      :: threshold
    2223             :       INTEGER                                            :: max_steps
    2224             : 
    2225             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_orth'
    2226             : 
    2227             :       INTEGER                                            :: handle, i, ispin, unit_nr
    2228             :       REAL(KIND=dp)                                      :: frob_norm, trace
    2229             :       TYPE(cp_logger_type), POINTER                      :: logger
    2230             :       TYPE(dbcsr_type)                                   :: matrix_pp, matrix_tmp
    2231             : 
    2232         234 :       CALL timeset(routineN, handle)
    2233         234 :       logger => cp_get_default_logger()
    2234         234 :       IF (logger%para_env%is_source()) THEN
    2235         117 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2236             :       ELSE
    2237         117 :          unit_nr = -1
    2238             :       END IF
    2239             : 
    2240         234 :       CALL dbcsr_create(matrix_pp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2241         234 :       CALL dbcsr_create(matrix_tmp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2242         234 :       CALL dbcsr_trace(matrix_p(1), trace)
    2243             : 
    2244         476 :       DO ispin = 1, SIZE(matrix_p)
    2245         476 :          DO i = 1, max_steps
    2246             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_p(ispin), &
    2247         242 :                                 0.0_dp, matrix_pp, filter_eps=threshold)
    2248             : 
    2249             :             ! test convergence
    2250         242 :             CALL dbcsr_copy(matrix_tmp, matrix_pp)
    2251         242 :             CALL dbcsr_add(matrix_tmp, matrix_p(ispin), 1.0_dp, -1.0_dp)
    2252         242 :             frob_norm = dbcsr_frobenius_norm(matrix_tmp) ! tmp = PP - P
    2253         242 :             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,f16.8)') "McWeeny: Deviation of idempotency", frob_norm
    2254         242 :             IF (unit_nr > 0) CALL m_flush(unit_nr)
    2255             : 
    2256             :             ! construct new P
    2257         242 :             CALL dbcsr_copy(matrix_tmp, matrix_pp)
    2258             :             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_pp, matrix_p(ispin), &
    2259         242 :                                 3.0_dp, matrix_tmp, filter_eps=threshold)
    2260         242 :             CALL dbcsr_copy(matrix_p(ispin), matrix_tmp) ! tmp = 3PP - 2PPP
    2261             : 
    2262             :             ! frob_norm < SQRT(trace*threshold)
    2263         242 :             IF (frob_norm*frob_norm < trace*threshold) EXIT
    2264             :          END DO
    2265             :       END DO
    2266             : 
    2267         234 :       CALL dbcsr_release(matrix_pp)
    2268         234 :       CALL dbcsr_release(matrix_tmp)
    2269         234 :       CALL timestop(handle)
    2270         234 :    END SUBROUTINE purify_mcweeny_orth
    2271             : 
    2272             : ! **************************************************************************************************
    2273             : !> \brief McWeeny purification of a matrix in the non-orthonormal basis
    2274             : !> \param matrix_p Matrix to purify (needs to be almost idempotent already)
    2275             : !> \param matrix_s Overlap-Matrix
    2276             : !> \param threshold Threshold used as filter_eps and convergence criteria
    2277             : !> \param max_steps Max number of iterations
    2278             : !> \par History
    2279             : !>       2013.01 created [Florian Schiffmann]
    2280             : !>       2014.07 slightly refactored [Ole Schuett]
    2281             : !> \author Florian Schiffmann
    2282             : ! **************************************************************************************************
    2283         196 :    SUBROUTINE purify_mcweeny_nonorth(matrix_p, matrix_s, threshold, max_steps)
    2284             :       TYPE(dbcsr_type), DIMENSION(:)                     :: matrix_p
    2285             :       TYPE(dbcsr_type)                                   :: matrix_s
    2286             :       REAL(KIND=dp)                                      :: threshold
    2287             :       INTEGER                                            :: max_steps
    2288             : 
    2289             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'purify_mcweeny_nonorth'
    2290             : 
    2291             :       INTEGER                                            :: handle, i, ispin, unit_nr
    2292             :       REAL(KIND=dp)                                      :: frob_norm, trace
    2293             :       TYPE(cp_logger_type), POINTER                      :: logger
    2294             :       TYPE(dbcsr_type)                                   :: matrix_ps, matrix_psp, matrix_test
    2295             : 
    2296         196 :       CALL timeset(routineN, handle)
    2297             : 
    2298         196 :       logger => cp_get_default_logger()
    2299         196 :       IF (logger%para_env%is_source()) THEN
    2300          98 :          unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)
    2301             :       ELSE
    2302          98 :          unit_nr = -1
    2303             :       END IF
    2304             : 
    2305         196 :       CALL dbcsr_create(matrix_ps, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2306         196 :       CALL dbcsr_create(matrix_psp, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2307         196 :       CALL dbcsr_create(matrix_test, template=matrix_p(1), matrix_type=dbcsr_type_no_symmetry)
    2308             : 
    2309         392 :       DO ispin = 1, SIZE(matrix_p)
    2310         404 :          DO i = 1, max_steps
    2311             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_p(ispin), matrix_s, &
    2312         208 :                                 0.0_dp, matrix_ps, filter_eps=threshold)
    2313             :             CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_ps, matrix_p(ispin), &
    2314         208 :                                 0.0_dp, matrix_psp, filter_eps=threshold)
    2315         208 :             IF (i == 1) CALL dbcsr_trace(matrix_ps, trace)
    2316             : 
    2317             :             ! test convergence
    2318         208 :             CALL dbcsr_copy(matrix_test, matrix_psp)
    2319         208 :             CALL dbcsr_add(matrix_test, matrix_p(ispin), 1.0_dp, -1.0_dp)
    2320         208 :             frob_norm = dbcsr_frobenius_norm(matrix_test) ! test = PSP - P
    2321         208 :             IF (unit_nr > 0) WRITE (unit_nr, '(t3,a,2f16.8)') "McWeeny: Deviation of idempotency", frob_norm
    2322         208 :             IF (unit_nr > 0) CALL m_flush(unit_nr)
    2323             : 
    2324             :             ! construct new P
    2325         208 :             CALL dbcsr_copy(matrix_p(ispin), matrix_psp)
    2326             :             CALL dbcsr_multiply("N", "N", -2.0_dp, matrix_ps, matrix_psp, &
    2327         208 :                                 3.0_dp, matrix_p(ispin), filter_eps=threshold)
    2328             : 
    2329             :             ! frob_norm < SQRT(trace*threshold)
    2330         208 :             IF (frob_norm*frob_norm < trace*threshold) EXIT
    2331             :          END DO
    2332             :       END DO
    2333             : 
    2334         196 :       CALL dbcsr_release(matrix_ps)
    2335         196 :       CALL dbcsr_release(matrix_psp)
    2336         196 :       CALL dbcsr_release(matrix_test)
    2337         196 :       CALL timestop(handle)
    2338         196 :    END SUBROUTINE purify_mcweeny_nonorth
    2339             : 
    2340           0 : END MODULE iterate_matrix

Generated by: LCOV version 1.15