LCOV - code coverage report
Current view: top level - src/fm - cp_fm_basic_linalg.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b8e0b09) Lines: 591 825 71.6 %
Date: 2024-08-31 06:31:37 Functions: 29 42 69.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief basic linear algebra operations for full matrices
      10             : !> \par History
      11             : !>      08.2002 split out of qs_blacs [fawzi]
      12             : !> \author Fawzi Mohamed
      13             : ! **************************************************************************************************
      14             : MODULE cp_fm_basic_linalg
      15             :    USE cp_blacs_env, ONLY: cp_blacs_env_type
      16             :    USE cp_fm_struct, ONLY: cp_fm_struct_equivalent
      17             :    USE cp_fm_types, ONLY: &
      18             :       cp_fm_create, cp_fm_get_diag, cp_fm_get_info, cp_fm_get_submatrix, cp_fm_p_type, &
      19             :       cp_fm_release, cp_fm_set_all, cp_fm_set_element, cp_fm_set_submatrix, cp_fm_to_fm, &
      20             :       cp_fm_type
      21             :    USE cp_log_handling, ONLY: cp_logger_get_default_unit_nr, &
      22             :                               cp_to_string
      23             :    USE kahan_sum, ONLY: accurate_dot_product, &
      24             :                         accurate_sum
      25             :    USE kinds, ONLY: dp, &
      26             :                     int_8, &
      27             :                     sp
      28             :    USE machine, ONLY: m_memory
      29             :    USE mathlib, ONLY: get_pseudo_inverse_svd, &
      30             :                       invert_matrix
      31             :    USE message_passing, ONLY: mp_comm_type
      32             : #include "../base/base_uses.f90"
      33             : 
      34             :    IMPLICIT NONE
      35             :    PRIVATE
      36             : 
      37             :    LOGICAL, PRIVATE, PARAMETER :: debug_this_module = .TRUE.
      38             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'cp_fm_basic_linalg'
      39             : 
      40             :    PUBLIC :: cp_fm_scale, & ! scale a matrix
      41             :              cp_fm_scale_and_add, & ! scale and add two matrices
      42             :              cp_fm_geadd, & ! general addition
      43             :              cp_fm_column_scale, & ! scale columns of a matrix
      44             :              cp_fm_row_scale, & ! scale rows of a matrix
      45             :              cp_fm_trace, & ! trace of the transpose(A)*B
      46             :              cp_fm_contracted_trace, & ! sum_{i,...,k} Tr [A(i,...,k)^T * B(i,...,k)]
      47             :              cp_fm_norm, & ! different norms of A
      48             :              cp_fm_schur_product, & ! schur product
      49             :              cp_fm_transpose, & ! transpose a matrix
      50             :              cp_fm_upper_to_full, & ! symmetrise an upper symmetric matrix
      51             :              cp_fm_syrk, & ! rank k update
      52             :              cp_fm_triangular_multiply, & ! triangular matrix multiply / solve
      53             :              cp_fm_symm, & ! multiply a symmetric with a non-symmetric matrix
      54             :              cp_fm_gemm, & ! multiply two matrices
      55             :              cp_complex_fm_gemm, & ! multiply two complex matrices, represented by non_complex fm matrices
      56             :              cp_fm_invert, & ! computes the inverse and determinant
      57             :              cp_fm_frobenius_norm, & ! frobenius norm
      58             :              cp_fm_triangular_invert, & ! compute the reciprocal of a triangular matrix
      59             :              cp_fm_qr_factorization, & ! compute the QR factorization of a rectangular matrix
      60             :              cp_fm_solve, & ! solves the equation  A*B=C A and C are input
      61             :              cp_fm_pdgeqpf, & ! compute a QR factorization with column pivoting of a M-by-N distributed matrix
      62             :              cp_fm_pdorgqr, & ! generates an M-by-N as first N columns of a product of K elementary reflectors
      63             :              cp_fm_potrf, & ! Cholesky decomposition
      64             :              cp_fm_potri, & ! Invert triangular matrix
      65             :              cp_fm_rot_rows, & ! rotates two rows
      66             :              cp_fm_rot_cols, & ! rotates two columns
      67             :              cp_fm_cholesky_restore, & ! apply Cholesky decomposition
      68             :              cp_fm_Gram_Schmidt_orthonorm, & ! Gram-Schmidt orthonormalization of columns of a full matrix, &
      69             :              cp_fm_det ! determinant of a real matrix with correct sign
      70             : 
      71             :    REAL(KIND=dp), EXTERNAL :: dlange, pdlange, pdlatra
      72             :    REAL(KIND=sp), EXTERNAL :: slange, pslange, pslatra
      73             : 
      74             :    INTERFACE cp_fm_trace
      75             :       MODULE PROCEDURE cp_fm_trace_a0b0t0
      76             :       MODULE PROCEDURE cp_fm_trace_a1b0t1_a
      77             :       MODULE PROCEDURE cp_fm_trace_a1b0t1_p
      78             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_aa
      79             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_ap
      80             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_pa
      81             :       MODULE PROCEDURE cp_fm_trace_a1b1t1_pp
      82             :    END INTERFACE cp_fm_trace
      83             : 
      84             :    INTERFACE cp_fm_contracted_trace
      85             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_aa
      86             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_ap
      87             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_pa
      88             :       MODULE PROCEDURE cp_fm_contracted_trace_a2b2t2_pp
      89             :    END INTERFACE cp_fm_contracted_trace
      90             : CONTAINS
      91             : 
      92             : ! **************************************************************************************************
      93             : !> \brief Computes the determinant (with a correct sign even in parallel environment!) of a real square matrix
      94             : !> \author A. Sinyavskiy (andrey.sinyavskiy@chem.uzh.ch)
      95             : ! **************************************************************************************************
      96           0 :    SUBROUTINE cp_fm_det(matrix_a, det_a)
      97             : 
      98             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a
      99             :       REAL(KIND=dp), INTENT(OUT)               :: det_a
     100             :       REAL(KIND=dp)                            :: determinant
     101             :       TYPE(cp_fm_type)                         :: matrix_lu
     102             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
     103             :       INTEGER                                  :: n, i, info, P
     104           0 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
     105             :       REAL(KIND=dp), DIMENSION(:), POINTER     :: diag
     106             : 
     107             : #if defined(__parallel)
     108             :       INTEGER                                  :: myprow, nprow, npcol, nrow_local, nrow_block, irow_local
     109             :       INTEGER, DIMENSION(9)                    :: desca
     110             : #endif
     111             : 
     112             :       CALL cp_fm_create(matrix=matrix_lu, &
     113             :                         matrix_struct=matrix_a%matrix_struct, &
     114           0 :                         name="A_lu"//TRIM(ADJUSTL(cp_to_string(1)))//"MATRIX")
     115           0 :       CALL cp_fm_to_fm(matrix_a, matrix_lu)
     116             : 
     117           0 :       a => matrix_lu%local_data
     118           0 :       n = matrix_lu%matrix_struct%nrow_global
     119           0 :       ALLOCATE (ipivot(n))
     120           0 :       ipivot(:) = 0
     121           0 :       P = 0
     122           0 :       ALLOCATE (diag(n))
     123           0 :       diag(:) = 0.0_dp
     124             : #if defined(__parallel)
     125             :       ! Use LU decomposition
     126           0 :       desca(:) = matrix_lu%matrix_struct%descriptor(:)
     127           0 :       CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
     128           0 :       CALL cp_fm_get_diag(matrix_lu, diag)
     129           0 :       determinant = PRODUCT(diag)
     130           0 :       myprow = matrix_lu%matrix_struct%context%mepos(1)
     131           0 :       nprow = matrix_lu%matrix_struct%context%num_pe(1)
     132           0 :       npcol = matrix_lu%matrix_struct%context%num_pe(2)
     133           0 :       nrow_local = matrix_lu%matrix_struct%nrow_locals(myprow)
     134           0 :       nrow_block = matrix_lu%matrix_struct%nrow_block
     135           0 :       DO irow_local = 1, nrow_local
     136           0 :          i = matrix_lu%matrix_struct%row_indices(irow_local)
     137           0 :          IF (ipivot(irow_local) /= i) P = P + 1
     138             :       END DO
     139           0 :       CALL matrix_lu%matrix_struct%para_env%sum(P)
     140             :       ! very important fix
     141           0 :       P = P/npcol
     142             : #else
     143             :       CALL dgetrf(n, n, a, n, ipivot, info)
     144             :       CALL cp_fm_get_diag(matrix_lu, diag)
     145             :       determinant = PRODUCT(diag)
     146             :       DO i = 1, n
     147             :          IF (ipivot(i) /= i) P = P + 1
     148             :       END DO
     149             : #endif
     150           0 :       DEALLOCATE (ipivot)
     151           0 :       DEALLOCATE (diag)
     152           0 :       CALL cp_fm_release(matrix_lu)
     153           0 :       det_a = determinant*(-2*MOD(P, 2) + 1.0_dp)
     154           0 :    END SUBROUTINE cp_fm_det
     155             : 
     156             : ! **************************************************************************************************
     157             : !> \brief calc A <- alpha*A + beta*B
     158             : !>      optimized for alpha == 1.0 (just add beta*B) and beta == 0.0 (just
     159             : !>      scale A)
     160             : !> \param alpha ...
     161             : !> \param matrix_a ...
     162             : !> \param beta ...
     163             : !> \param matrix_b ...
     164             : ! **************************************************************************************************
     165     1054898 :    SUBROUTINE cp_fm_scale_and_add(alpha, matrix_a, beta, matrix_b)
     166             : 
     167             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
     168             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a
     169             :       REAL(KIND=dp), INTENT(in), OPTIONAL                :: beta
     170             :       TYPE(cp_fm_type), INTENT(IN), OPTIONAL             :: matrix_b
     171             : 
     172             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_scale_and_add'
     173             : 
     174             :       INTEGER                                            :: handle, size_a, size_b
     175             :       REAL(KIND=dp)                                      :: my_beta
     176     1054898 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a, b
     177     1054898 :       REAL(KIND=sp), DIMENSION(:, :), POINTER            :: a_sp, b_sp
     178             : 
     179     1054898 :       CALL timeset(routineN, handle)
     180             : 
     181     1054898 :       IF (PRESENT(matrix_b)) THEN
     182     1054898 :          my_beta = 1.0_dp
     183             :       ELSE
     184           0 :          my_beta = 0.0_dp
     185             :       END IF
     186     1054898 :       IF (PRESENT(beta)) my_beta = beta
     187     1054898 :       NULLIFY (a, b)
     188             : 
     189     1054898 :       IF (PRESENT(beta)) THEN
     190     1054898 :          CPASSERT(PRESENT(matrix_b))
     191     1054898 :          IF (ASSOCIATED(matrix_a%local_data, matrix_b%local_data)) THEN
     192           0 :             CPWARN("Bad use of routine. Call cp_fm_scale instead")
     193           0 :             CALL cp_fm_scale(alpha + beta, matrix_a)
     194           0 :             CALL timestop(handle)
     195           0 :             RETURN
     196             :          END IF
     197             :       END IF
     198             : 
     199     1054898 :       a => matrix_a%local_data
     200     1054898 :       a_sp => matrix_a%local_data_sp
     201             : 
     202     1054898 :       IF (matrix_a%use_sp) THEN
     203           0 :          size_a = SIZE(a_sp, 1)*SIZE(a_sp, 2)
     204             :       ELSE
     205     1054898 :          size_a = SIZE(a, 1)*SIZE(a, 2)
     206             :       END IF
     207             : 
     208     1054898 :       IF (alpha /= 1.0_dp) THEN
     209       78616 :          IF (matrix_a%use_sp) THEN
     210           0 :             CALL sscal(size_a, REAL(alpha, sp), a_sp, 1)
     211             :          ELSE
     212       78616 :             CALL dscal(size_a, alpha, a, 1)
     213             :          END IF
     214             :       END IF
     215     1054898 :       IF (my_beta .NE. 0.0_dp) THEN
     216     1045466 :          IF (matrix_a%matrix_struct%context /= matrix_b%matrix_struct%context) &
     217           0 :             CPABORT("matrixes must be in the same blacs context")
     218             : 
     219     1045466 :          IF (cp_fm_struct_equivalent(matrix_a%matrix_struct, &
     220             :                                      matrix_b%matrix_struct)) THEN
     221             : 
     222     1045466 :             b => matrix_b%local_data
     223     1045466 :             b_sp => matrix_b%local_data_sp
     224     1045466 :             IF (matrix_b%use_sp) THEN
     225           0 :                size_b = SIZE(b_sp, 1)*SIZE(b_sp, 2)
     226             :             ELSE
     227     1045466 :                size_b = SIZE(b, 1)*SIZE(b, 2)
     228             :             END IF
     229     1045466 :             IF (size_a /= size_b) &
     230           0 :                CPABORT("Matrixes must have same locale sizes")
     231             : 
     232     1045466 :             IF (matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     233           0 :                CALL saxpy(size_a, REAL(my_beta, sp), b_sp, 1, a_sp, 1)
     234     1045466 :             ELSEIF (matrix_a%use_sp .AND. .NOT. matrix_b%use_sp) THEN
     235           0 :                CALL saxpy(size_a, REAL(my_beta, sp), REAL(b, sp), 1, a_sp, 1)
     236     1045466 :             ELSEIF (.NOT. matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     237           0 :                CALL daxpy(size_a, my_beta, REAL(b_sp, dp), 1, a, 1)
     238             :             ELSE
     239     1045466 :                CALL daxpy(size_a, my_beta, b, 1, a, 1)
     240             :             END IF
     241             : 
     242             :          ELSE
     243             : #ifdef __parallel
     244           0 :             CPABORT("to do (pdscal,pdcopy,pdaxpy)")
     245             : #else
     246             :             CPABORT("")
     247             : #endif
     248             :          END IF
     249             : 
     250             :       END IF
     251             : 
     252     1054898 :       CALL timestop(handle)
     253             : 
     254     1054898 :    END SUBROUTINE cp_fm_scale_and_add
     255             : 
     256             : ! **************************************************************************************************
     257             : !> \brief interface to BLACS geadd:
     258             : !>                matrix_b = beta*matrix_b + alpha*opt(matrix_a)
     259             : !>        where opt(matrix_a) can be either:
     260             : !>              'N':  matrix_a
     261             : !>              'T':  matrix_a^T
     262             : !>              'C':  matrix_a^H (Hermitian conjugate)
     263             : !>        note that this is a level three routine, use cp_fm_scale_and_add if that
     264             : !>        is sufficient for your needs
     265             : !> \param alpha  : complex scalar
     266             : !> \param trans  : 'N' normal, 'T' transposed
     267             : !> \param matrix_a : input matrix_a
     268             : !> \param beta   : complex scalar
     269             : !> \param matrix_b : input matrix_b, upon out put the updated matrix_b
     270             : !> \author  Lianheng Tong
     271             : ! **************************************************************************************************
     272          96 :    SUBROUTINE cp_fm_geadd(alpha, trans, matrix_a, beta, matrix_b)
     273             :       REAL(KIND=dp), INTENT(IN) :: alpha, beta
     274             :       CHARACTER, INTENT(IN) :: trans
     275             :       TYPE(cp_fm_type), INTENT(IN) :: matrix_a, matrix_b
     276             : 
     277             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_geadd'
     278             : 
     279             :       INTEGER :: nrow_global, ncol_global, handle
     280             :       REAL(KIND=dp), DIMENSION(:, :), POINTER :: aa, bb
     281             : #if defined(__parallel)
     282             :       INTEGER, DIMENSION(9) :: desca, descb
     283             : #else
     284             :       INTEGER :: ii, jj
     285             : #endif
     286             : 
     287          96 :       CALL timeset(routineN, handle)
     288             : 
     289          96 :       nrow_global = matrix_a%matrix_struct%nrow_global
     290          96 :       ncol_global = matrix_a%matrix_struct%ncol_global
     291          96 :       CPASSERT(nrow_global .EQ. matrix_b%matrix_struct%nrow_global)
     292          96 :       CPASSERT(ncol_global .EQ. matrix_b%matrix_struct%ncol_global)
     293             : 
     294          96 :       aa => matrix_a%local_data
     295          96 :       bb => matrix_b%local_data
     296             : 
     297             : #if defined(__parallel)
     298         960 :       desca = matrix_a%matrix_struct%descriptor
     299         960 :       descb = matrix_b%matrix_struct%descriptor
     300             :       CALL pdgeadd(trans, &
     301             :                    nrow_global, &
     302             :                    ncol_global, &
     303             :                    alpha, &
     304             :                    aa, &
     305             :                    1, 1, &
     306             :                    desca, &
     307             :                    beta, &
     308             :                    bb, &
     309             :                    1, 1, &
     310          96 :                    descb)
     311             : #else
     312             :       ! dgeadd is not a standard BLAS function, although is implemented
     313             :       ! in some libraries like OpenBLAS, so not going to use it here
     314             :       SELECT CASE (trans)
     315             :       CASE ('T')
     316             :          DO jj = 1, ncol_global
     317             :             DO ii = 1, nrow_global
     318             :                bb(ii, jj) = beta*bb(ii, jj) + alpha*aa(jj, ii)
     319             :             END DO
     320             :          END DO
     321             :       CASE DEFAULT
     322             :          DO jj = 1, ncol_global
     323             :             DO ii = 1, nrow_global
     324             :                bb(ii, jj) = beta*bb(ii, jj) + alpha*aa(ii, jj)
     325             :             END DO
     326             :          END DO
     327             :       END SELECT
     328             : #endif
     329             : 
     330          96 :       CALL timestop(handle)
     331             : 
     332          96 :    END SUBROUTINE cp_fm_geadd
     333             : 
     334             : ! **************************************************************************************************
     335             : !> \brief Computes the LU-decomposition of the matrix, and the determinant of the matrix
     336             : !>      IMPORTANT : the sign of the determinant is not defined correctly yet ....
     337             : !> \param matrix_a ...
     338             : !> \param almost_determinant ...
     339             : !> \param correct_sign ...
     340             : !> \par History
     341             : !>      added correct_sign 02.07 (fschiff)
     342             : !> \author Joost VandeVondele
     343             : !> \note
     344             : !>      - matrix_a is overwritten
     345             : !>      - the sign of the determinant might be wrong
     346             : !>      - SERIOUS WARNING (KNOWN BUG) : the sign of the determinant depends on ipivot
     347             : !>      - one should be able to find out if ipivot is an even or an odd permutation...
     348             : !>        if you need the correct sign, just add correct_sign==.TRUE. (fschiff)
     349             : !>      - Use cp_fm_get_diag instead of n times cp_fm_get_element (A. Bussy)
     350             : ! **************************************************************************************************
     351           0 :    SUBROUTINE cp_fm_lu_decompose(matrix_a, almost_determinant, correct_sign)
     352             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a
     353             :       REAL(KIND=dp), INTENT(OUT)               :: almost_determinant
     354             :       LOGICAL, INTENT(IN), OPTIONAL            :: correct_sign
     355             : 
     356             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_lu_decompose'
     357             : 
     358             :       INTEGER                                  :: handle, i, info, n
     359           0 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
     360             :       REAL(KIND=dp)                            :: determinant
     361             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
     362             : #if defined(__parallel)
     363             :       INTEGER, DIMENSION(9)                    :: desca
     364           0 :       REAL(KIND=dp), DIMENSION(:), POINTER     :: diag
     365             : #else
     366             :       INTEGER                                  :: lda
     367             : #endif
     368             : 
     369           0 :       CALL timeset(routineN, handle)
     370             : 
     371           0 :       a => matrix_a%local_data
     372           0 :       n = matrix_a%matrix_struct%nrow_global
     373           0 :       ALLOCATE (ipivot(n + matrix_a%matrix_struct%nrow_block))
     374             : 
     375             : #if defined(__parallel)
     376             :       MARK_USED(correct_sign)
     377           0 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     378           0 :       CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
     379             : 
     380           0 :       ALLOCATE (diag(n))
     381           0 :       diag(:) = 0.0_dp
     382           0 :       CALL cp_fm_get_diag(matrix_a, diag)
     383           0 :       determinant = 1.0_dp
     384           0 :       DO i = 1, n
     385           0 :          determinant = determinant*diag(i)
     386             :       END DO
     387           0 :       DEALLOCATE (diag)
     388             : #else
     389             :       lda = SIZE(a, 1)
     390             :       CALL dgetrf(n, n, a, lda, ipivot, info)
     391             :       determinant = 1.0_dp
     392             :       IF (correct_sign) THEN
     393             :          DO i = 1, n
     394             :             IF (ipivot(i) .NE. i) THEN
     395             :                determinant = -determinant*a(i, i)
     396             :             ELSE
     397             :                determinant = determinant*a(i, i)
     398             :             END IF
     399             :          END DO
     400             :       ELSE
     401             :          DO i = 1, n
     402             :             determinant = determinant*a(i, i)
     403             :          END DO
     404             :       END IF
     405             : #endif
     406             :       ! info is allowed to be zero
     407             :       ! this does just signal a zero diagonal element
     408           0 :       DEALLOCATE (ipivot)
     409           0 :       almost_determinant = determinant ! notice that the sign is random
     410           0 :       CALL timestop(handle)
     411           0 :    END SUBROUTINE
     412             : 
     413             : ! **************************************************************************************************
     414             : !> \brief computes matrix_c = beta * matrix_c + alpha * ( matrix_a  ** transa ) * ( matrix_b ** transb )
     415             : !> \param transa : 'N' -> normal   'T' -> transpose
     416             : !>      alpha,beta :: can be 0.0_dp and 1.0_dp
     417             : !> \param transb ...
     418             : !> \param m ...
     419             : !> \param n ...
     420             : !> \param k ...
     421             : !> \param alpha ...
     422             : !> \param matrix_a : m x k matrix ( ! for transa = 'N')
     423             : !> \param matrix_b : k x n matrix ( ! for transb = 'N')
     424             : !> \param beta ...
     425             : !> \param matrix_c : m x n matrix
     426             : !> \param a_first_col ...
     427             : !> \param a_first_row ...
     428             : !> \param b_first_col : the k x n matrix starts at col b_first_col of matrix_b (avoid usage)
     429             : !> \param b_first_row ...
     430             : !> \param c_first_col ...
     431             : !> \param c_first_row ...
     432             : !> \author Matthias Krack
     433             : !> \note
     434             : !>      matrix_c should have no overlap with matrix_a, matrix_b
     435             : ! **************************************************************************************************
     436         514 :    SUBROUTINE cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
     437             :                          matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
     438             :                          c_first_col, c_first_row)
     439             : 
     440             :       CHARACTER(LEN=1), INTENT(IN)             :: transa, transb
     441             :       INTEGER, INTENT(IN)                      :: m, n, k
     442             :       REAL(KIND=dp), INTENT(IN)                :: alpha
     443             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a, matrix_b
     444             :       REAL(KIND=dp), INTENT(IN)                :: beta
     445             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_c
     446             :       INTEGER, INTENT(IN), OPTIONAL            :: a_first_col, a_first_row, &
     447             :                                                   b_first_col, b_first_row, &
     448             :                                                   c_first_col, c_first_row
     449             : 
     450             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_gemm'
     451             : 
     452             :       INTEGER                                  :: handle, i_a, i_b, i_c, j_a, &
     453             :                                                   j_b, j_c
     454             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, b, c
     455         514 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp, b_sp, c_sp
     456             : #if defined(__parallel)
     457             :       INTEGER, DIMENSION(9)                    :: desca, descb, descc
     458             : #else
     459             :       INTEGER                                  :: lda, ldb, ldc
     460             : #endif
     461             : 
     462         514 :       CALL timeset(routineN, handle)
     463             : 
     464             :       !sample peak memory
     465         514 :       CALL m_memory()
     466             : 
     467         514 :       a => matrix_a%local_data
     468         514 :       b => matrix_b%local_data
     469         514 :       c => matrix_c%local_data
     470             : 
     471         514 :       a_sp => matrix_a%local_data_sp
     472         514 :       b_sp => matrix_b%local_data_sp
     473         514 :       c_sp => matrix_c%local_data_sp
     474             : 
     475         514 :       IF (PRESENT(a_first_row)) THEN
     476           0 :          i_a = a_first_row
     477             :       ELSE
     478         514 :          i_a = 1
     479             :       END IF
     480         514 :       IF (PRESENT(a_first_col)) THEN
     481           0 :          j_a = a_first_col
     482             :       ELSE
     483         514 :          j_a = 1
     484             :       END IF
     485         514 :       IF (PRESENT(b_first_row)) THEN
     486           0 :          i_b = b_first_row
     487             :       ELSE
     488         514 :          i_b = 1
     489             :       END IF
     490         514 :       IF (PRESENT(b_first_col)) THEN
     491           0 :          j_b = b_first_col
     492             :       ELSE
     493         514 :          j_b = 1
     494             :       END IF
     495         514 :       IF (PRESENT(c_first_row)) THEN
     496           0 :          i_c = c_first_row
     497             :       ELSE
     498         514 :          i_c = 1
     499             :       END IF
     500         514 :       IF (PRESENT(c_first_col)) THEN
     501           0 :          j_c = c_first_col
     502             :       ELSE
     503         514 :          j_c = 1
     504             :       END IF
     505             : 
     506             : #if defined(__parallel)
     507             : 
     508        5140 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     509        5140 :       descb(:) = matrix_b%matrix_struct%descriptor(:)
     510        5140 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     511             : 
     512         514 :       IF (matrix_a%use_sp .AND. matrix_b%use_sp .AND. matrix_c%use_sp) THEN
     513             : 
     514             :          CALL psgemm(transa, transb, m, n, k, REAL(alpha, sp), a_sp(1, 1), i_a, j_a, desca, b_sp(1, 1), i_b, j_b, &
     515           0 :                      descb, REAL(beta, sp), c_sp(1, 1), i_c, j_c, descc)
     516             : 
     517         514 :       ELSEIF ((.NOT. matrix_a%use_sp) .AND. (.NOT. matrix_b%use_sp) .AND. (.NOT. matrix_c%use_sp)) THEN
     518             : 
     519             :          CALL pdgemm(transa, transb, m, n, k, alpha, a, i_a, j_a, desca, b, i_b, j_b, &
     520         514 :                      descb, beta, c, i_c, j_c, descc)
     521             : 
     522             :       ELSE
     523           0 :          CPABORT("Mixed precision gemm NYI")
     524             :       END IF
     525             : #else
     526             : 
     527             :       IF (matrix_a%use_sp .AND. matrix_b%use_sp .AND. matrix_c%use_sp) THEN
     528             : 
     529             :          lda = SIZE(a_sp, 1)
     530             :          ldb = SIZE(b_sp, 1)
     531             :          ldc = SIZE(c_sp, 1)
     532             : 
     533             :          CALL sgemm(transa, transb, m, n, k, REAL(alpha, sp), a_sp(i_a, j_a), lda, b_sp(i_b, j_b), ldb, &
     534             :                     REAL(beta, sp), c_sp(i_c, j_c), ldc)
     535             : 
     536             :       ELSEIF ((.NOT. matrix_a%use_sp) .AND. (.NOT. matrix_b%use_sp) .AND. (.NOT. matrix_c%use_sp)) THEN
     537             : 
     538             :          lda = SIZE(a, 1)
     539             :          ldb = SIZE(b, 1)
     540             :          ldc = SIZE(c, 1)
     541             : 
     542             :          CALL dgemm(transa, transb, m, n, k, alpha, a(i_a, j_a), lda, b(i_b, j_b), ldb, beta, c(i_c, j_c), ldc)
     543             : 
     544             :       ELSE
     545             :          CPABORT("Mixed precision gemm NYI")
     546             :       END IF
     547             : 
     548             : #endif
     549         514 :       CALL timestop(handle)
     550             : 
     551         514 :    END SUBROUTINE cp_fm_gemm
     552             : 
     553             : ! **************************************************************************************************
     554             : !> \brief computes matrix_c = beta * matrix_c + alpha *  matrix_a  *  matrix_b
     555             : !>      computes matrix_c = beta * matrix_c + alpha *  matrix_b  *  matrix_a
     556             : !>      where matrix_a is symmetric
     557             : !> \param side : 'L' -> matrix_a is on the left 'R' -> matrix_a is on the right
     558             : !>      alpha,beta :: can be 0.0_dp and 1.0_dp
     559             : !> \param uplo ...
     560             : !> \param m ...
     561             : !> \param n ...
     562             : !> \param alpha ...
     563             : !> \param matrix_a : m x m matrix
     564             : !> \param matrix_b : m x n matrix
     565             : !> \param beta ...
     566             : !> \param matrix_c : m x n matrix
     567             : !> \author Matthias Krack
     568             : !> \note
     569             : !>      matrix_c should have no overlap with matrix_a, matrix_b
     570             : !>      all matrices in QS are upper triangular, so uplo should be 'U' always
     571             : !>      matrix_a is always an m x m matrix
     572             : !>      it is typically slower to do cp_fm_symm than cp_fm_gemm (especially in parallel easily 50 percent !)
     573             : ! **************************************************************************************************
     574      142356 :    SUBROUTINE cp_fm_symm(side, uplo, m, n, alpha, matrix_a, matrix_b, beta, matrix_c)
     575             : 
     576             :       CHARACTER(LEN=1), INTENT(IN)             :: side, uplo
     577             :       INTEGER, INTENT(IN)                      :: m, n
     578             :       REAL(KIND=dp), INTENT(IN)                :: alpha
     579             :       TYPE(cp_fm_type), INTENT(IN)                :: matrix_a, matrix_b
     580             :       REAL(KIND=dp), INTENT(IN)                :: beta
     581             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_c
     582             : 
     583             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_symm'
     584             : 
     585             :       INTEGER                                  :: handle
     586      142356 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, b, c
     587             : #if defined(__parallel)
     588             :       INTEGER, DIMENSION(9)                    :: desca, descb, descc
     589             : #else
     590             :       INTEGER                                  :: lda, ldb, ldc
     591             : #endif
     592             : 
     593      142356 :       CALL timeset(routineN, handle)
     594             : 
     595      142356 :       a => matrix_a%local_data
     596      142356 :       b => matrix_b%local_data
     597      142356 :       c => matrix_c%local_data
     598             : 
     599             : #if defined(__parallel)
     600             : 
     601     1423560 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     602     1423560 :       descb(:) = matrix_b%matrix_struct%descriptor(:)
     603     1423560 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     604             : 
     605      142356 :       CALL pdsymm(side, uplo, m, n, alpha, a(1, 1), 1, 1, desca, b(1, 1), 1, 1, descb, beta, c(1, 1), 1, 1, descc)
     606             : 
     607             : #else
     608             : 
     609             :       lda = matrix_a%matrix_struct%local_leading_dimension
     610             :       ldb = matrix_b%matrix_struct%local_leading_dimension
     611             :       ldc = matrix_c%matrix_struct%local_leading_dimension
     612             : 
     613             :       CALL dsymm(side, uplo, m, n, alpha, a(1, 1), lda, b(1, 1), ldb, beta, c(1, 1), ldc)
     614             : 
     615             : #endif
     616      142356 :       CALL timestop(handle)
     617             : 
     618      142356 :    END SUBROUTINE cp_fm_symm
     619             : 
     620             : ! **************************************************************************************************
     621             : !> \brief computes the Frobenius norm of matrix_a
     622             : !> \brief computes the Frobenius norm of matrix_a
     623             : !> \param matrix_a : m x n matrix
     624             : !> \return ...
     625             : !> \author VW
     626             : ! **************************************************************************************************
     627        8030 :    FUNCTION cp_fm_frobenius_norm(matrix_a) RESULT(norm)
     628             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a
     629             :       REAL(KIND=dp)                            :: norm
     630             : 
     631             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_frobenius_norm'
     632             : 
     633             :       INTEGER                                  :: handle, size_a
     634        8030 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
     635             :       REAL(KIND=dp), EXTERNAL                  :: DDOT
     636             : #if defined(__parallel)
     637             :       TYPE(mp_comm_type)                       :: group
     638             : #endif
     639             : 
     640        8030 :       CALL timeset(routineN, handle)
     641             : 
     642             :       norm = 0.0_dp
     643        8030 :       a => matrix_a%local_data
     644        8030 :       size_a = SIZE(a, 1)*SIZE(a, 2)
     645        8030 :       norm = DDOT(size_a, a(1, 1), 1, a(1, 1), 1)
     646             : #if defined(__parallel)
     647        8030 :       group = matrix_a%matrix_struct%para_env
     648        8030 :       CALL group%sum(norm)
     649             : #endif
     650        8030 :       norm = SQRT(norm)
     651             : 
     652        8030 :       CALL timestop(handle)
     653             : 
     654        8030 :    END FUNCTION cp_fm_frobenius_norm
     655             : 
     656             : ! **************************************************************************************************
     657             : !> \brief performs a rank-k update of a symmetric matrix_c
     658             : !>         matrix_c = beta * matrix_c + alpha * matrix_a * transpose ( matrix_a )
     659             : !> \param uplo : 'U'   ('L')
     660             : !> \param trans : 'N'  ('T')
     661             : !> \param k : number of cols to use in matrix_a
     662             : !>      ia,ja ::  1,1 (could be used for selecting subblock of a)
     663             : !> \param alpha ...
     664             : !> \param matrix_a ...
     665             : !> \param ia ...
     666             : !> \param ja ...
     667             : !> \param beta ...
     668             : !> \param matrix_c ...
     669             : !> \author Matthias Krack
     670             : !> \note
     671             : !>      In QS uplo should 'U' (upper part updated)
     672             : ! **************************************************************************************************
     673        6296 :    SUBROUTINE cp_fm_syrk(uplo, trans, k, alpha, matrix_a, ia, ja, beta, matrix_c)
     674             :       CHARACTER(LEN=1), INTENT(IN)             :: uplo, trans
     675             :       INTEGER, INTENT(IN)                      :: k
     676             :       REAL(KIND=dp), INTENT(IN)                :: alpha
     677             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix_a
     678             :       INTEGER, INTENT(IN)                      :: ia, ja
     679             :       REAL(KIND=dp), INTENT(IN)                :: beta
     680             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_c
     681             : 
     682             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_syrk'
     683             : 
     684             :       INTEGER                                  :: handle, n
     685        6296 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, c
     686             : #if defined(__parallel)
     687             :       INTEGER, DIMENSION(9)                    :: desca, descc
     688             : #else
     689             :       INTEGER                                  :: lda, ldc
     690             : #endif
     691             : 
     692        6296 :       CALL timeset(routineN, handle)
     693             : 
     694        6296 :       n = matrix_c%matrix_struct%nrow_global
     695             : 
     696        6296 :       a => matrix_a%local_data
     697        6296 :       c => matrix_c%local_data
     698             : 
     699             : #if defined(__parallel)
     700             : 
     701       62960 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
     702       62960 :       descc(:) = matrix_c%matrix_struct%descriptor(:)
     703             : 
     704        6296 :       CALL pdsyrk(uplo, trans, n, k, alpha, a(1, 1), ia, ja, desca, beta, c(1, 1), 1, 1, descc)
     705             : 
     706             : #else
     707             : 
     708             :       lda = SIZE(a, 1)
     709             :       ldc = SIZE(c, 1)
     710             : 
     711             :       CALL dsyrk(uplo, trans, n, k, alpha, a(ia, ja), lda, beta, c(1, 1), ldc)
     712             : 
     713             : #endif
     714        6296 :       CALL timestop(handle)
     715             : 
     716        6296 :    END SUBROUTINE cp_fm_syrk
     717             : 
     718             : ! **************************************************************************************************
     719             : !> \brief computes the schur product of two matrices
     720             : !>       c_ij = a_ij * b_ij
     721             : !> \param matrix_a ...
     722             : !> \param matrix_b ...
     723             : !> \param matrix_c ...
     724             : !> \author Joost VandeVondele
     725             : ! **************************************************************************************************
     726        9190 :    SUBROUTINE cp_fm_schur_product(matrix_a, matrix_b, matrix_c)
     727             : 
     728             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b, matrix_c
     729             : 
     730             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_schur_product'
     731             : 
     732             :       INTEGER                                            :: handle, icol_local, irow_local, mypcol, &
     733             :                                                             myprow, ncol_local, npcol, nprow, &
     734             :                                                             nrow_local
     735        9190 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a, b, c
     736             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     737             : 
     738        9190 :       CALL timeset(routineN, handle)
     739             : 
     740        9190 :       context => matrix_a%matrix_struct%context
     741        9190 :       myprow = context%mepos(1)
     742        9190 :       mypcol = context%mepos(2)
     743        9190 :       nprow = context%num_pe(1)
     744        9190 :       npcol = context%num_pe(2)
     745             : 
     746        9190 :       a => matrix_a%local_data
     747        9190 :       b => matrix_b%local_data
     748        9190 :       c => matrix_c%local_data
     749             : 
     750        9190 :       nrow_local = matrix_a%matrix_struct%nrow_locals(myprow)
     751        9190 :       ncol_local = matrix_a%matrix_struct%ncol_locals(mypcol)
     752             : 
     753       99952 :       DO icol_local = 1, ncol_local
     754     6860227 :          DO irow_local = 1, nrow_local
     755     6851037 :             c(irow_local, icol_local) = a(irow_local, icol_local)*b(irow_local, icol_local)
     756             :          END DO
     757             :       END DO
     758             : 
     759        9190 :       CALL timestop(handle)
     760             : 
     761        9190 :    END SUBROUTINE cp_fm_schur_product
     762             : 
     763             : ! **************************************************************************************************
     764             : !> \brief returns the trace of matrix_a^T matrix_b, i.e
     765             : !>      sum_{i,j}(matrix_a(i,j)*matrix_b(i,j))
     766             : !> \param matrix_a a matrix
     767             : !> \param matrix_b another matrix
     768             : !> \param trace ...
     769             : !> \par History
     770             : !>      11.06.2001 Creation (Matthias Krack)
     771             : !>      12.2002 added doc [fawzi]
     772             : !> \author Matthias Krack
     773             : !> \note
     774             : !>      note the transposition of matrix_a!
     775             : ! **************************************************************************************************
     776      684132 :    SUBROUTINE cp_fm_trace_a0b0t0(matrix_a, matrix_b, trace)
     777             : 
     778             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
     779             :       REAL(KIND=dp), INTENT(OUT)                         :: trace
     780             : 
     781             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_trace_a0b0t0'
     782             : 
     783             :       INTEGER                                            :: handle, mypcol, myprow, ncol_local, &
     784             :                                                             npcol, nprow, nrow_local
     785      684132 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a, b
     786      684132 :       REAL(KIND=sp), DIMENSION(:, :), POINTER            :: a_sp, b_sp
     787             :       TYPE(cp_blacs_env_type), POINTER                   :: context
     788             :       TYPE(mp_comm_type)                                 :: group
     789             : 
     790      684132 :       CALL timeset(routineN, handle)
     791             : 
     792      684132 :       context => matrix_a%matrix_struct%context
     793      684132 :       myprow = context%mepos(1)
     794      684132 :       mypcol = context%mepos(2)
     795      684132 :       nprow = context%num_pe(1)
     796      684132 :       npcol = context%num_pe(2)
     797             : 
     798      684132 :       group = matrix_a%matrix_struct%para_env
     799             : 
     800      684132 :       a => matrix_a%local_data
     801      684132 :       b => matrix_b%local_data
     802             : 
     803      684132 :       a_sp => matrix_a%local_data_sp
     804      684132 :       b_sp => matrix_b%local_data_sp
     805             : 
     806      684132 :       nrow_local = MIN(matrix_a%matrix_struct%nrow_locals(myprow), matrix_b%matrix_struct%nrow_locals(myprow))
     807      684132 :       ncol_local = MIN(matrix_a%matrix_struct%ncol_locals(mypcol), matrix_b%matrix_struct%ncol_locals(mypcol))
     808             : 
     809             :       ! cries for an accurate_dot_product
     810      684132 :       IF (matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     811             :          trace = accurate_sum(REAL(a_sp(1:nrow_local, 1:ncol_local)* &
     812           0 :                                    b_sp(1:nrow_local, 1:ncol_local), dp))
     813      684132 :       ELSEIF (matrix_a%use_sp .AND. .NOT. matrix_b%use_sp) THEN
     814             :          trace = accurate_sum(REAL(a_sp(1:nrow_local, 1:ncol_local), dp)* &
     815           0 :                               b(1:nrow_local, 1:ncol_local))
     816      684132 :       ELSEIF (.NOT. matrix_a%use_sp .AND. matrix_b%use_sp) THEN
     817             :          trace = accurate_sum(a(1:nrow_local, 1:ncol_local)* &
     818           0 :                               REAL(b_sp(1:nrow_local, 1:ncol_local), dp))
     819             :       ELSE
     820             :          trace = accurate_dot_product(a(1:nrow_local, 1:ncol_local), &
     821      684132 :                                       b(1:nrow_local, 1:ncol_local))
     822             :       END IF
     823             : 
     824      684132 :       CALL group%sum(trace)
     825             : 
     826      684132 :       CALL timestop(handle)
     827             : 
     828      684132 :    END SUBROUTINE cp_fm_trace_a0b0t0
     829             : 
   #! types: list of (derived type of the list elements, suffix appended to the generated
   #!        procedure name, component selector that yields the underlying cp_fm_type)
   #:mute
      #:set types = [("cp_fm_type", "a", ""), ("cp_fm_p_type", "p","%matrix")]
   #:endmute
     833             : 
! **************************************************************************************************
!> \brief Compute trace(k) = Tr (matrix_a(k)^T matrix_b) for each pair of matrices A_k and B.
!> \param matrix_a list of A matrices; each is assumed to have the local shape and distribution of B
!> \param matrix_b B matrix
!> \param trace    computed traces, one per element of matrix_a (SIZE(trace) == SIZE(matrix_a))
!> \par History
!>    * 08.2018 forked from cp_fm_trace() [Sergey Chulkov]
!> \note \parblock
!>      Computing the trace requires collective communication between involved MPI processes
!>      that implies a synchronisation point between them. The aim of this subroutine is to reduce
!>      the amount of time wasted in such synchronisation by performing one large collective
!>      operation which involves all the matrices in question.
!>
!>      The subroutine's suffix reflects dimensionality of dummy arrays; 'a1b0t1' means that
!>      the dummy variables 'matrix_a' and 'trace' are 1-dimensional arrays, while the variable
!>      'matrix_b' is a single matrix.
!>
!>      A and B must both use double or both use single precision storage; mixed pairs abort.
!>      \endparblock
! **************************************************************************************************
   #:for longname, shortname, appendix in types
      SUBROUTINE cp_fm_trace_a1b0t1_${shortname}$ (matrix_a, matrix_b, trace)
         TYPE(${longname}$), DIMENSION(:), INTENT(in)       :: matrix_a
         TYPE(cp_fm_type), INTENT(IN)                       :: matrix_b
         REAL(kind=dp), DIMENSION(:), INTENT(out)           :: trace

         CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_trace_a1b0t1_${shortname}$'

         INTEGER                                            :: handle, imatrix, n_matrices, &
                                                               ncols_local, nrows_local
         LOGICAL                                            :: use_sp_a, use_sp_b
         REAL(kind=dp), DIMENSION(:, :), POINTER            :: ldata_a, ldata_b
         REAL(kind=sp), DIMENSION(:, :), POINTER            :: ldata_a_sp, ldata_b_sp
         TYPE(mp_comm_type)                                 :: group

         CALL timeset(routineN, handle)

         n_matrices = SIZE(trace)
         CPASSERT(SIZE(matrix_a) == n_matrices)

         ! B is the same for every pair: fetch its local block once, outside the parallel loop;
         ! only the pointer matching its storage precision is associated
         CALL cp_fm_get_info(matrix_b, nrow_local=nrows_local, ncol_local=ncols_local)
         use_sp_b = matrix_b%use_sp

         IF (use_sp_b) THEN
            ldata_b_sp => matrix_b%local_data_sp(1:nrows_local, 1:ncols_local)
         ELSE
            ldata_b => matrix_b%local_data(1:nrows_local, 1:ncols_local)
         END IF

         ! each thread computes the local (per-process) partial traces of a subset of the A matrices
!$OMP PARALLEL DO DEFAULT(NONE), &
!$OMP             PRIVATE(imatrix, ldata_a, ldata_a_sp, use_sp_a), &
!$OMP             SHARED(ldata_b, ldata_b_sp, matrix_a, matrix_b), &
!$OMP             SHARED(ncols_local, nrows_local, n_matrices, trace, use_sp_b)

         DO imatrix = 1, n_matrices

            use_sp_a = matrix_a(imatrix) ${appendix}$%use_sp

            ! assume that the matrices A(i) and B have identical shapes and distribution schemes
            IF (use_sp_a .AND. use_sp_b) THEN
               ldata_a_sp => matrix_a(imatrix) ${appendix}$%local_data_sp(1:nrows_local, 1:ncols_local)
               trace(imatrix) = accurate_dot_product(ldata_a_sp, ldata_b_sp)
            ELSE IF (.NOT. use_sp_a .AND. .NOT. use_sp_b) THEN
               ldata_a => matrix_a(imatrix) ${appendix}$%local_data(1:nrows_local, 1:ncols_local)
               trace(imatrix) = accurate_dot_product(ldata_a, ldata_b)
            ELSE
               CPABORT("Matrices A and B are of different types")
            END IF
         END DO
!$OMP END PARALLEL DO

         ! a single collective reduction for all traces (see \note above)
         group = matrix_b%matrix_struct%para_env
         CALL group%sum(trace)

         CALL timestop(handle)
      END SUBROUTINE cp_fm_trace_a1b0t1_${shortname}$
   #:endfor
     909             : 
     910             : ! **************************************************************************************************
     911             : !> \brief Compute trace(k) = Tr (matrix_a(k)^T matrix_b(k)) for each pair of matrices A_k and B_k.
     912             : !> \param matrix_a list of A matrices
     913             : !> \param matrix_b list of B matrices
     914             : !> \param trace    computed traces
     915             : !> \param accurate ...
     916             : !> \par History
     917             : !>    * 11.2016 forked from cp_fm_trace() [Sergey Chulkov]
     918             : !> \note \parblock
     919             : !>      Computing the trace requires collective communication between involved MPI processes
     920             : !>      that implies a synchronisation point between them. The aim of this subroutine is to reduce
     921             : !>      the amount of time wasted in such synchronisation by performing one large collective
     922             : !>      operation which involves all the matrices in question.
     923             : !>
     924             : !>      The subroutine's suffix reflects dimensionality of dummy arrays; 'a1b1t1' means that
     925             : !>      all dummy variables (matrix_a, matrix_b, and trace) are 1-dimensional arrays.
     926             : !>      \endparblock
     927             : ! **************************************************************************************************
     928             :    #:for longname1, shortname1, appendix1 in types
     929             :       #:for longname2, shortname2, appendix2 in types
     930      138648 :          SUBROUTINE cp_fm_trace_a1b1t1_${shortname1}$${shortname2}$ (matrix_a, matrix_b, trace, accurate)
     931             :             TYPE(${longname1}$), DIMENSION(:), INTENT(in)       :: matrix_a
     932             :             TYPE(${longname2}$), DIMENSION(:), INTENT(in)       :: matrix_b
     933             :             REAL(kind=dp), DIMENSION(:), INTENT(out)           :: trace
     934             :             LOGICAL, INTENT(IN), OPTIONAL                      :: accurate
     935             : 
     936             :             CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_trace_a1b1t1_${shortname1}$${shortname2}$'
     937             : 
     938             :             INTEGER                                            :: handle, imatrix, n_matrices, &
     939             :                                                                   ncols_local, nrows_local
     940             :             LOGICAL                                            :: use_accurate_sum, use_sp_a, use_sp_b
     941      138648 :             REAL(kind=dp), DIMENSION(:, :), POINTER            :: ldata_a, ldata_b
     942      138648 :             REAL(kind=sp), DIMENSION(:, :), POINTER            :: ldata_a_sp, ldata_b_sp
     943             :             TYPE(mp_comm_type)                                 :: group
     944             : 
     945      138648 :             CALL timeset(routineN, handle)
     946             : 
     947      138648 :             n_matrices = SIZE(trace)
     948      138648 :             CPASSERT(SIZE(matrix_a) == n_matrices)
     949      138648 :             CPASSERT(SIZE(matrix_b) == n_matrices)
     950             : 
     951      138648 :             use_accurate_sum = .TRUE.
     952      138648 :             IF (PRESENT(accurate)) use_accurate_sum = accurate
     953             : 
     954             : !$OMP PARALLEL DO DEFAULT(NONE), &
     955             : !$OMP             PRIVATE(imatrix, ldata_a, ldata_a_sp, ldata_b, ldata_b_sp, ncols_local), &
     956             : !$OMP             PRIVATE(nrows_local, use_sp_a, use_sp_b), &
     957      138648 : !$OMP             SHARED(matrix_a, matrix_b, n_matrices, trace, use_accurate_sum)
     958             :             DO imatrix = 1, n_matrices
     959             :                CALL cp_fm_get_info(matrix_a(imatrix) ${appendix1}$, nrow_local=nrows_local, ncol_local=ncols_local)
     960             : 
     961             :                use_sp_a = matrix_a(imatrix) ${appendix1}$%use_sp
     962             :                use_sp_b = matrix_b(imatrix) ${appendix2}$%use_sp
     963             : 
     964             :                ! assume that the matrices A(i) and B(i) have identical shapes and distribution schemes
     965             :                IF (use_sp_a .AND. use_sp_b) THEN
     966             :                   ldata_a_sp => matrix_a(imatrix) ${appendix1}$%local_data_sp(1:nrows_local, 1:ncols_local)
     967             :                   ldata_b_sp => matrix_b(imatrix) ${appendix2}$%local_data_sp(1:nrows_local, 1:ncols_local)
     968             :                   IF (use_accurate_sum) THEN
     969             :                      trace(imatrix) = accurate_dot_product(ldata_a_sp, ldata_b_sp)
     970             :                   ELSE
     971             :                      trace(imatrix) = SUM(ldata_a_sp*ldata_b_sp)
     972             :                   END IF
     973             :                ELSE IF (.NOT. use_sp_a .AND. .NOT. use_sp_b) THEN
     974             :                   ldata_a => matrix_a(imatrix) ${appendix1}$%local_data(1:nrows_local, 1:ncols_local)
     975             :                   ldata_b => matrix_b(imatrix) ${appendix2}$%local_data(1:nrows_local, 1:ncols_local)
     976             :                   IF (use_accurate_sum) THEN
     977             :                      trace(imatrix) = accurate_dot_product(ldata_a, ldata_b)
     978             :                   ELSE
     979             :                      trace(imatrix) = SUM(ldata_a*ldata_b)
     980             :                   END IF
     981             :                ELSE
     982             :                   CPABORT("Matrices A and B are of different types")
     983             :                END IF
     984             :             END DO
     985             : !$OMP END PARALLEL DO
     986             : 
     987      138648 :             group = matrix_a(1) ${appendix1}$%matrix_struct%para_env
     988      460424 :             CALL group%sum(trace)
     989             : 
     990      138648 :             CALL timestop(handle)
     991      138648 :          END SUBROUTINE cp_fm_trace_a1b1t1_${shortname1}$${shortname2}$
     992             :       #:endfor
     993             :    #:endfor
     994             : 
     995             : ! **************************************************************************************************
     996             : !> \brief Compute trace(i,j) = \sum_k Tr (matrix_a(k,i)^T matrix_b(k,j)).
     997             : !> \param matrix_a list of A matrices
     998             : !> \param matrix_b list of B matrices
     999             : !> \param trace    computed traces
    1000             : !> \param accurate ...
    1001             : ! **************************************************************************************************
    1002             :    #:for longname1, shortname1, appendix1 in types
    1003             :       #:for longname2, shortname2, appendix2 in types
    1004       13816 :          SUBROUTINE cp_fm_contracted_trace_a2b2t2_${shortname1}$${shortname2}$ (matrix_a, matrix_b, trace, accurate)
    1005             :             TYPE(${longname1}$), DIMENSION(:, :), INTENT(in)       :: matrix_a
    1006             :             TYPE(${longname2}$), DIMENSION(:, :), INTENT(in)       :: matrix_b
    1007             :             REAL(kind=dp), DIMENSION(:, :), INTENT(out)        :: trace
    1008             :             LOGICAL, INTENT(IN), OPTIONAL                      :: accurate
    1009             : 
    1010             :             CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_contracted_trace_a2b2t2_${shortname1}$${shortname2}$'
    1011             : 
    1012             :             INTEGER                                            :: handle, ia, ib, iz, na, nb, ncols_local, &
    1013             :                                                                   nrows_local, nz
    1014             :             INTEGER(kind=int_8)                                :: ib8, itrace, na8, ntraces
    1015             :             LOGICAL                                            :: use_accurate_sum, use_sp_a, use_sp_b
    1016             :             REAL(kind=dp)                                      :: t
    1017       13816 :             REAL(kind=dp), DIMENSION(:, :), POINTER            :: ldata_a, ldata_b
    1018       13816 :             REAL(kind=sp), DIMENSION(:, :), POINTER            :: ldata_a_sp, ldata_b_sp
    1019             :             TYPE(mp_comm_type)                                 :: group
    1020             : 
    1021       13816 :             CALL timeset(routineN, handle)
    1022             : 
    1023       13816 :             nz = SIZE(matrix_a, 1)
    1024       13816 :             CPASSERT(SIZE(matrix_b, 1) == nz)
    1025             : 
    1026       13816 :             na = SIZE(matrix_a, 2)
    1027       13816 :             nb = SIZE(matrix_b, 2)
    1028       13816 :             CPASSERT(SIZE(trace, 1) == na)
    1029       13816 :             CPASSERT(SIZE(trace, 2) == nb)
    1030             : 
    1031       13816 :             use_accurate_sum = .TRUE.
    1032       13816 :             IF (PRESENT(accurate)) use_accurate_sum = accurate
    1033             : 
    1034             :             ! here we use one running index (itrace) instead of two (ia, ib) in order to
    1035             :             ! improve load balance between shared-memory threads
    1036       13816 :             ntraces = na*nb
    1037       13816 :             na8 = INT(na, kind=int_8)
    1038             : 
    1039             : !$OMP PARALLEL DO DEFAULT(NONE), &
    1040             : !$OMP             PRIVATE(ia, ib, ib8, itrace, iz, ldata_a, ldata_a_sp, ldata_b, ldata_b_sp, ncols_local), &
    1041             : !$OMP             PRIVATE(nrows_local, t, use_sp_a, use_sp_b), &
    1042       13816 : !$OMP             SHARED(matrix_a, matrix_b, na, na8, nb, ntraces, nz, trace, use_accurate_sum)
    1043             :             DO itrace = 1, ntraces
    1044             :                ib8 = (itrace - 1)/na8
    1045             :                ia = INT(itrace - ib8*na8)
    1046             :                ib = INT(ib8) + 1
    1047             : 
    1048             :                t = 0.0_dp
    1049             :                DO iz = 1, nz
    1050             :                   CALL cp_fm_get_info(matrix_a(iz, ia) ${appendix1}$, nrow_local=nrows_local, ncol_local=ncols_local)
    1051             :                   use_sp_a = matrix_a(iz, ia) ${appendix1}$%use_sp
    1052             :                   use_sp_b = matrix_b(iz, ib) ${appendix2}$%use_sp
    1053             : 
    1054             :                   ! assume that the matrices A(iz, ia) and B(iz, ib) have identical shapes and distribution schemes
    1055             :                   IF (.NOT. use_sp_a .AND. .NOT. use_sp_b) THEN
    1056             :                      ldata_a => matrix_a(iz, ia) ${appendix1}$%local_data(1:nrows_local, 1:ncols_local)
    1057             :                      ldata_b => matrix_b(iz, ib) ${appendix2}$%local_data(1:nrows_local, 1:ncols_local)
    1058             :                      IF (use_accurate_sum) THEN
    1059             :                         t = t + accurate_dot_product(ldata_a, ldata_b)
    1060             :                      ELSE
    1061             :                         t = t + SUM(ldata_a*ldata_b)
    1062             :                      END IF
    1063             :                   ELSE IF (use_sp_a .AND. use_sp_b) THEN
    1064             :                      ldata_a_sp => matrix_a(iz, ia) ${appendix1}$%local_data_sp(1:nrows_local, 1:ncols_local)
    1065             :                      ldata_b_sp => matrix_b(iz, ib) ${appendix2}$%local_data_sp(1:nrows_local, 1:ncols_local)
    1066             :                      IF (use_accurate_sum) THEN
    1067             :                         t = t + accurate_dot_product(ldata_a_sp, ldata_b_sp)
    1068             :                      ELSE
    1069             :                         t = t + SUM(ldata_a_sp*ldata_b_sp)
    1070             :                      END IF
    1071             :                   ELSE
    1072             :                      CPABORT("Matrices A and B are of different types")
    1073             :                   END IF
    1074             :                END DO
    1075             :                trace(ia, ib) = t
    1076             :             END DO
    1077             : !$OMP END PARALLEL DO
    1078             : 
    1079       13816 :             group = matrix_a(1, 1) ${appendix1}$%matrix_struct%para_env
    1080      617500 :             CALL group%sum(trace)
    1081             : 
    1082       13816 :             CALL timestop(handle)
    1083       13816 :          END SUBROUTINE cp_fm_contracted_trace_a2b2t2_${shortname1}$${shortname2}$
    1084             :       #:endfor
    1085             :    #:endfor
    1086             : 
    1087             : ! **************************************************************************************************
    1088             : !> \brief multiplies in place by a triangular matrix:
    1089             : !>       matrix_b = alpha op(triangular_matrix) matrix_b
    1090             : !>      or (if side='R')
    1091             : !>       matrix_b = alpha matrix_b op(triangular_matrix)
    1092             : !>      op(triangular_matrix) is:
    1093             : !>       triangular_matrix (if transpose_tr=.false. and invert_tr=.false.)
    1094             : !>       triangular_matrix^T (if transpose_tr=.true. and invert_tr=.false.)
    1095             : !>       triangular_matrix^(-1) (if transpose_tr=.false. and invert_tr=.true.)
    1096             : !>       triangular_matrix^(-T) (if transpose_tr=.true. and invert_tr=.true.)
    1097             : !> \param triangular_matrix the triangular matrix that multiplies the other
    1098             : !> \param matrix_b the matrix that gets multiplied and stores the result
    1099             : !> \param side on which side of matrix_b stays op(triangular_matrix)
    1100             : !>        (defaults to 'L')
    1101             : !> \param transpose_tr if the triangular matrix should be transposed
    1102             : !>        (defaults to false)
    1103             : !> \param invert_tr if the triangular matrix should be inverted
    1104             : !>        (defaults to false)
    1105             : !> \param uplo_tr if triangular_matrix is stored in the upper ('U') or
    1106             : !>        lower ('L') triangle (defaults to 'U')
    1107             : !> \param unit_diag_tr if the diagonal elements of triangular_matrix should
    1108             : !>        be assumed to be 1 (defaults to false)
    1109             : !> \param n_rows the number of rows of the result (defaults to
    1110             : !>        size(matrix_b,1))
    1111             : !> \param n_cols the number of columns of the result (defaults to
    1112             : !>        size(matrix_b,2))
    1113             : !> \param alpha ...
    1114             : !> \par History
    1115             : !>      08.2002 created [fawzi]
    1116             : !> \author Fawzi Mohamed
    1117             : !> \note
    1118             : !>      needs an mpi env
    1119             : ! **************************************************************************************************
    1120      101622 :    SUBROUTINE cp_fm_triangular_multiply(triangular_matrix, matrix_b, side, &
    1121             :                                         transpose_tr, invert_tr, uplo_tr, unit_diag_tr, n_rows, n_cols, &
    1122             :                                         alpha)
    1123             :       TYPE(cp_fm_type), INTENT(IN)                       :: triangular_matrix, matrix_b
    1124             :       CHARACTER, INTENT(in), OPTIONAL                    :: side
    1125             :       LOGICAL, INTENT(in), OPTIONAL                      :: transpose_tr, invert_tr
    1126             :       CHARACTER, INTENT(in), OPTIONAL                    :: uplo_tr
    1127             :       LOGICAL, INTENT(in), OPTIONAL                      :: unit_diag_tr
    1128             :       INTEGER, INTENT(in), OPTIONAL                      :: n_rows, n_cols
    1129             :       REAL(KIND=dp), INTENT(in), OPTIONAL                :: alpha
    1130             : 
    1131             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_triangular_multiply'
    1132             : 
    1133             :       CHARACTER                                          :: side_char, transa, unit_diag, uplo
    1134             :       INTEGER                                            :: handle, m, n
    1135             :       LOGICAL                                            :: invert
    1136             :       REAL(KIND=dp)                                      :: al
    1137             : 
    1138       50811 :       CALL timeset(routineN, handle)
    1139       50811 :       side_char = 'L'
    1140       50811 :       unit_diag = 'N'
    1141       50811 :       uplo = 'U'
    1142       50811 :       transa = 'N'
    1143       50811 :       invert = .FALSE.
    1144       50811 :       al = 1.0_dp
    1145       50811 :       CALL cp_fm_get_info(matrix_b, nrow_global=m, ncol_global=n)
    1146       50811 :       IF (PRESENT(side)) side_char = side
    1147       50811 :       IF (PRESENT(invert_tr)) invert = invert_tr
    1148       50811 :       IF (PRESENT(uplo_tr)) uplo = uplo_tr
    1149       50811 :       IF (PRESENT(unit_diag_tr)) THEN
    1150           0 :          IF (unit_diag_tr) THEN
    1151           0 :             unit_diag = 'U'
    1152             :          ELSE
    1153             :             unit_diag = 'N'
    1154             :          END IF
    1155             :       END IF
    1156       50811 :       IF (PRESENT(transpose_tr)) THEN
    1157        3438 :          IF (transpose_tr) THEN
    1158        1246 :             transa = 'T'
    1159             :          ELSE
    1160             :             transa = 'N'
    1161             :          END IF
    1162             :       END IF
    1163       50811 :       IF (PRESENT(alpha)) al = alpha
    1164       50811 :       IF (PRESENT(n_rows)) m = n_rows
    1165       50811 :       IF (PRESENT(n_cols)) n = n_cols
    1166             : 
    1167       50811 :       IF (invert) THEN
    1168             : 
    1169             : #if defined(__parallel)
    1170             :          CALL pdtrsm(side_char, uplo, transa, unit_diag, m, n, al, &
    1171             :                      triangular_matrix%local_data(1, 1), 1, 1, &
    1172             :                      triangular_matrix%matrix_struct%descriptor, &
    1173             :                      matrix_b%local_data(1, 1), 1, 1, &
    1174       41659 :                      matrix_b%matrix_struct%descriptor(1))
    1175             : #else
    1176             :          CALL dtrsm(side_char, uplo, transa, unit_diag, m, n, al, &
    1177             :                     triangular_matrix%local_data(1, 1), &
    1178             :                     SIZE(triangular_matrix%local_data, 1), &
    1179             :                     matrix_b%local_data(1, 1), SIZE(matrix_b%local_data, 1))
    1180             : #endif
    1181             : 
    1182             :       ELSE
    1183             : 
    1184             : #if defined(__parallel)
    1185             :          CALL pdtrmm(side_char, uplo, transa, unit_diag, m, n, al, &
    1186             :                      triangular_matrix%local_data(1, 1), 1, 1, &
    1187             :                      triangular_matrix%matrix_struct%descriptor, &
    1188             :                      matrix_b%local_data(1, 1), 1, 1, &
    1189        9152 :                      matrix_b%matrix_struct%descriptor(1))
    1190             : #else
    1191             :          CALL dtrmm(side_char, uplo, transa, unit_diag, m, n, al, &
    1192             :                     triangular_matrix%local_data(1, 1), &
    1193             :                     SIZE(triangular_matrix%local_data, 1), &
    1194             :                     matrix_b%local_data(1, 1), SIZE(matrix_b%local_data, 1))
    1195             : #endif
    1196             : 
    1197             :       END IF
    1198             : 
    1199       50811 :       CALL timestop(handle)
    1200       50811 :    END SUBROUTINE cp_fm_triangular_multiply
    1201             : 
    1202             : ! **************************************************************************************************
    1203             : !> \brief scales a matrix
    1204             : !>      matrix_a = alpha * matrix_b
    1205             : !> \param alpha ...
    1206             : !> \param matrix_a ...
    1207             : !> \note
    1208             : !>      use cp_fm_set_all to zero (avoids problems with nan)
    1209             : ! **************************************************************************************************
    1210       83415 :    SUBROUTINE cp_fm_scale(alpha, matrix_a)
    1211             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
    1212             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a
    1213             : 
    1214             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'cp_fm_scale'
    1215             : 
    1216             :       INTEGER                                            :: handle, size_a
    1217             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a
    1218             : 
    1219       83415 :       CALL timeset(routineN, handle)
    1220             : 
    1221             :       NULLIFY (a)
    1222             : 
    1223       83415 :       a => matrix_a%local_data
    1224       83415 :       size_a = SIZE(a, 1)*SIZE(a, 2)
    1225             : 
    1226       83415 :       CALL DSCAL(size_a, alpha, a, 1)
    1227             : 
    1228       83415 :       CALL timestop(handle)
    1229             : 
    1230       83415 :    END SUBROUTINE cp_fm_scale
    1231             : 
    1232             : ! **************************************************************************************************
    1233             : !> \brief transposes a matrix
    1234             : !>      matrixt = matrix ^ T
    1235             : !> \param matrix ...
    1236             : !> \param matrixt ...
    1237             : !> \note
    1238             : !>      all matrix elements are transposed (see cp_fm_upper_to_half to symmetrise a matrix)
    1239             : ! **************************************************************************************************
    1240       19106 :    SUBROUTINE cp_fm_transpose(matrix, matrixt)
    1241             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix, matrixt
    1242             : 
    1243             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_transpose'
    1244             : 
    1245             :       INTEGER                                  :: handle, ncol_global, &
    1246             :                                                   nrow_global, ncol_globalt, nrow_globalt
    1247        9553 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, c
    1248             : #if defined(__parallel)
    1249             :       INTEGER, DIMENSION(9)                    :: desca, descc
    1250             : #else
    1251             :       INTEGER                                  :: i, j
    1252             : #endif
    1253             : 
    1254        9553 :       nrow_global = matrix%matrix_struct%nrow_global
    1255        9553 :       ncol_global = matrix%matrix_struct%ncol_global
    1256        9553 :       nrow_globalt = matrixt%matrix_struct%nrow_global
    1257        9553 :       ncol_globalt = matrixt%matrix_struct%ncol_global
    1258           0 :       CPASSERT(nrow_global == ncol_globalt)
    1259        9553 :       CPASSERT(nrow_globalt == ncol_global)
    1260             : 
    1261        9553 :       CALL timeset(routineN, handle)
    1262             : 
    1263        9553 :       a => matrix%local_data
    1264        9553 :       c => matrixt%local_data
    1265             : 
    1266             : #if defined(__parallel)
    1267       95530 :       desca(:) = matrix%matrix_struct%descriptor(:)
    1268       95530 :       descc(:) = matrixt%matrix_struct%descriptor(:)
    1269        9553 :       CALL pdtran(ncol_global, nrow_global, 1.0_dp, a(1, 1), 1, 1, desca, 0.0_dp, c(1, 1), 1, 1, descc)
    1270             : #else
    1271             :       DO j = 1, ncol_global
    1272             :          DO i = 1, nrow_global
    1273             :             c(j, i) = a(i, j)
    1274             :          END DO
    1275             :       END DO
    1276             : #endif
    1277        9553 :       CALL timestop(handle)
    1278             : 
    1279        9553 :    END SUBROUTINE cp_fm_transpose
    1280             : 
    1281             : ! **************************************************************************************************
    1282             : !> \brief given an upper triangular matrix computes the corresponding full matrix
    1283             : !> \param matrix the upper triangular matrix as input, the full matrix as output
    1284             : !> \param work a matrix of the same size as matrix
    1285             : !> \author Matthias Krack
    1286             : !> \note
    1287             : !>       the lower triangular part is irrelevant
    1288             : ! **************************************************************************************************
    1289      302006 :    SUBROUTINE cp_fm_upper_to_full(matrix, work)
    1290             : 
    1291             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix, work
    1292             : 
    1293             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_upper_to_full'
    1294             : 
    1295             :       INTEGER                                  :: handle, icol_global, irow_global, &
    1296             :                                                   mypcol, myprow, ncol_global, &
    1297             :                                                   npcol, nprow, nrow_global
    1298      151003 :       REAL(KIND=dp), DIMENSION(:, :), POINTER   :: a
    1299      151003 :       REAL(KIND=sp), DIMENSION(:, :), POINTER   :: a_sp
    1300             :       TYPE(cp_blacs_env_type), POINTER         :: context
    1301             : 
    1302             : #if defined(__parallel)
    1303             :       INTEGER                                  :: icol_local, irow_local, &
    1304             :                                                   ncol_block, ncol_local, &
    1305             :                                                   nrow_block, nrow_local
    1306             :       INTEGER, DIMENSION(9)                    :: desca, descc
    1307      151003 :       REAL(KIND=dp), DIMENSION(:, :), POINTER   :: c
    1308      151003 :       REAL(KIND=sp), DIMENSION(:, :), POINTER   :: c_sp
    1309             : #endif
    1310             : 
    1311      151003 :       nrow_global = matrix%matrix_struct%nrow_global
    1312      151003 :       ncol_global = matrix%matrix_struct%ncol_global
    1313           0 :       CPASSERT(nrow_global == ncol_global)
    1314      151003 :       nrow_global = work%matrix_struct%nrow_global
    1315      151003 :       ncol_global = work%matrix_struct%ncol_global
    1316      151003 :       CPASSERT(nrow_global == ncol_global)
    1317      151003 :       CPASSERT(matrix%use_sp .EQV. work%use_sp)
    1318             : 
    1319      151003 :       CALL timeset(routineN, handle)
    1320             : 
    1321      151003 :       context => matrix%matrix_struct%context
    1322      151003 :       myprow = context%mepos(1)
    1323      151003 :       mypcol = context%mepos(2)
    1324      151003 :       nprow = context%num_pe(1)
    1325      151003 :       npcol = context%num_pe(2)
    1326             : 
    1327             : #if defined(__parallel)
    1328             : 
    1329      151003 :       nrow_block = matrix%matrix_struct%nrow_block
    1330      151003 :       ncol_block = matrix%matrix_struct%ncol_block
    1331             : 
    1332      151003 :       nrow_local = matrix%matrix_struct%nrow_locals(myprow)
    1333      151003 :       ncol_local = matrix%matrix_struct%ncol_locals(mypcol)
    1334             : 
    1335      151003 :       a => work%local_data
    1336      151003 :       a_sp => work%local_data_sp
    1337     1510030 :       desca(:) = work%matrix_struct%descriptor(:)
    1338      151003 :       c => matrix%local_data
    1339      151003 :       c_sp => matrix%local_data_sp
    1340     1510030 :       descc(:) = matrix%matrix_struct%descriptor(:)
    1341             : 
    1342     4926454 :       DO icol_local = 1, ncol_local
    1343     4775451 :          icol_global = matrix%matrix_struct%col_indices(icol_local)
    1344   226632643 :          DO irow_local = 1, nrow_local
    1345   221706189 :             irow_global = matrix%matrix_struct%row_indices(irow_local)
    1346   226481640 :             IF (irow_global > icol_global) THEN
    1347   109410969 :                IF (matrix%use_sp) THEN
    1348           0 :                   c_sp(irow_local, icol_local) = 0.0_sp
    1349             :                ELSE
    1350   109410969 :                   c(irow_local, icol_local) = 0.0_dp
    1351             :                END IF
    1352   112295220 :             ELSE IF (irow_global == icol_global) THEN
    1353     2884251 :                IF (matrix%use_sp) THEN
    1354           0 :                   c_sp(irow_local, icol_local) = 0.5_sp*c_sp(irow_local, icol_local)
    1355             :                ELSE
    1356     2884251 :                   c(irow_local, icol_local) = 0.5_dp*c(irow_local, icol_local)
    1357             :                END IF
    1358             :             END IF
    1359             :          END DO
    1360             :       END DO
    1361             : 
    1362     4926454 :       DO icol_local = 1, ncol_local
    1363   226632643 :       DO irow_local = 1, nrow_local
    1364   226481640 :          IF (matrix%use_sp) THEN
    1365           0 :             a_sp(irow_local, icol_local) = c_sp(irow_local, icol_local)
    1366             :          ELSE
    1367   221706189 :             a(irow_local, icol_local) = c(irow_local, icol_local)
    1368             :          END IF
    1369             :       END DO
    1370             :       END DO
    1371             : 
    1372      151003 :       IF (matrix%use_sp) THEN
    1373           0 :          CALL pstran(nrow_global, ncol_global, 1.0_sp, a_sp(1, 1), 1, 1, desca, 1.0_sp, c_sp(1, 1), 1, 1, descc)
    1374             :       ELSE
    1375      151003 :          CALL pdtran(nrow_global, ncol_global, 1.0_dp, a(1, 1), 1, 1, desca, 1.0_dp, c(1, 1), 1, 1, descc)
    1376             :       END IF
    1377             : 
    1378             : #else
    1379             : 
    1380             :       a => matrix%local_data
    1381             :       a_sp => matrix%local_data_sp
    1382             :       DO irow_global = 1, nrow_global
    1383             :          DO icol_global = irow_global + 1, ncol_global
    1384             :             IF (matrix%use_sp) THEN
    1385             :                a_sp(icol_global, irow_global) = a_sp(irow_global, icol_global)
    1386             :             ELSE
    1387             :                a(icol_global, irow_global) = a(irow_global, icol_global)
    1388             :             END IF
    1389             :          END DO
    1390             :       END DO
    1391             : 
    1392             : #endif
    1393      151003 :       CALL timestop(handle)
    1394             : 
    1395      151003 :    END SUBROUTINE cp_fm_upper_to_full
    1396             : 
    1397             : ! **************************************************************************************************
    1398             : !> \brief scales column i of matrix a with scaling(i)
    1399             : !> \param matrixa ...
    1400             : !> \param scaling : an array used for scaling the columns,
    1401             : !>                  SIZE(scaling) determines the number of columns to be scaled
    1402             : !> \author Joost VandeVondele
    1403             : !> \note
    1404             : !>      this is very useful as a first step in the computation of C = sum_i alpha_i A_i transpose (A_i)
    1405             : !>      that is a rank-k update (cp_fm_syrk , cp_sm_plus_fm_fm_t)
    1406             : !>      this procedure can be up to 20 times faster than calling cp_fm_syrk n times
    1407             : !>      where every vector has a different prefactor
    1408             : ! **************************************************************************************************
    1409      125568 :    SUBROUTINE cp_fm_column_scale(matrixa, scaling)
    1410             :       TYPE(cp_fm_type), INTENT(IN)          :: matrixa
    1411             :       REAL(KIND=dp), DIMENSION(:), INTENT(in)  :: scaling
    1412             : 
    1413             :       INTEGER                                  :: k, mypcol, myprow, n, ncol_global, &
    1414             :                                                   npcol, nprow
    1415      125568 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1416      125568 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
    1417             : #if defined(__parallel)
    1418             :       INTEGER                                  :: icol_global, icol_local, &
    1419             :                                                   ipcol, iprow, irow_local
    1420             : #else
    1421             :       INTEGER                                  :: i
    1422             : #endif
    1423             : 
    1424      125568 :       myprow = matrixa%matrix_struct%context%mepos(1)
    1425      125568 :       mypcol = matrixa%matrix_struct%context%mepos(2)
    1426      125568 :       nprow = matrixa%matrix_struct%context%num_pe(1)
    1427      125568 :       npcol = matrixa%matrix_struct%context%num_pe(2)
    1428             : 
    1429      125568 :       ncol_global = matrixa%matrix_struct%ncol_global
    1430             : 
    1431      125568 :       a => matrixa%local_data
    1432      125568 :       a_sp => matrixa%local_data_sp
    1433      125568 :       IF (matrixa%use_sp) THEN
    1434           0 :          n = SIZE(a_sp, 1)
    1435             :       ELSE
    1436      125568 :          n = SIZE(a, 1)
    1437             :       END IF
    1438      125568 :       k = MIN(SIZE(scaling), ncol_global)
    1439             : 
    1440             : #if defined(__parallel)
    1441             : 
    1442     1878749 :       DO icol_global = 1, k
    1443             :          CALL infog2l(1, icol_global, matrixa%matrix_struct%descriptor, &
    1444             :                       nprow, npcol, myprow, mypcol, &
    1445     1753181 :                       irow_local, icol_local, iprow, ipcol)
    1446     1878749 :          IF ((ipcol == mypcol)) THEN
    1447     1753181 :             IF (matrixa%use_sp) THEN
    1448           0 :                CALL SSCAL(n, REAL(scaling(icol_global), sp), a_sp(:, icol_local), 1)
    1449             :             ELSE
    1450     1753181 :                CALL DSCAL(n, scaling(icol_global), a(:, icol_local), 1)
    1451             :             END IF
    1452             :          END IF
    1453             :       END DO
    1454             : #else
    1455             :       DO i = 1, k
    1456             :          IF (matrixa%use_sp) THEN
    1457             :             CALL SSCAL(n, REAL(scaling(i), sp), a_sp(:, i), 1)
    1458             :          ELSE
    1459             :             CALL DSCAL(n, scaling(i), a(:, i), 1)
    1460             :          END IF
    1461             :       END DO
    1462             : #endif
    1463      125568 :    END SUBROUTINE cp_fm_column_scale
    1464             : 
    1465             : ! **************************************************************************************************
    1466             : !> \brief scales row i of matrix a with scaling(i)
    1467             : !> \param matrixa ...
    1468             : !> \param scaling : an array used for scaling the columns,
    1469             : !> \author JGH
    1470             : !> \note
    1471             : ! **************************************************************************************************
    1472        6564 :    SUBROUTINE cp_fm_row_scale(matrixa, scaling)
    1473             :       TYPE(cp_fm_type), INTENT(IN)          :: matrixa
    1474             :       REAL(KIND=dp), DIMENSION(:), INTENT(in)  :: scaling
    1475             : 
    1476             :       INTEGER                                  :: n, m, nrow_global, nrow_local, ncol_local
    1477        6564 :       INTEGER, DIMENSION(:), POINTER           :: row_indices
    1478        6564 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1479        6564 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
    1480             : #if defined(__parallel)
    1481             :       INTEGER                                  :: irow_global, icol, irow
    1482             : #else
    1483             :       INTEGER                                  :: j
    1484             : #endif
    1485             : 
    1486             :       CALL cp_fm_get_info(matrixa, row_indices=row_indices, nrow_global=nrow_global, &
    1487        6564 :                           nrow_local=nrow_local, ncol_local=ncol_local)
    1488        6564 :       CPASSERT(SIZE(scaling) == nrow_global)
    1489             : 
    1490        6564 :       a => matrixa%local_data
    1491        6564 :       a_sp => matrixa%local_data_sp
    1492        6564 :       IF (matrixa%use_sp) THEN
    1493        6564 :          n = SIZE(a_sp, 1)
    1494        6564 :          m = SIZE(a_sp, 2)
    1495             :       ELSE
    1496        6564 :          n = SIZE(a, 1)
    1497        6564 :          m = SIZE(a, 2)
    1498             :       END IF
    1499             : 
    1500             : #if defined(__parallel)
    1501       81426 :       DO icol = 1, ncol_local
    1502       81426 :          IF (matrixa%use_sp) THEN
    1503           0 :             DO irow = 1, nrow_local
    1504           0 :                irow_global = row_indices(irow)
    1505           0 :                a(irow, icol) = REAL(scaling(irow_global), dp)*a(irow, icol)
    1506             :             END DO
    1507             :          ELSE
    1508     6667421 :             DO irow = 1, nrow_local
    1509     6592559 :                irow_global = row_indices(irow)
    1510     6667421 :                a(irow, icol) = scaling(irow_global)*a(irow, icol)
    1511             :             END DO
    1512             :          END IF
    1513             :       END DO
    1514             : #else
    1515             :       IF (matrixa%use_sp) THEN
    1516             :          DO j = 1, m
    1517             :             a_sp(1:n, j) = REAL(scaling(1:n), sp)*a_sp(1:n, j)
    1518             :          END DO
    1519             :       ELSE
    1520             :          DO j = 1, m
    1521             :             a(1:n, j) = scaling(1:n)*a(1:n, j)
    1522             :          END DO
    1523             :       END IF
    1524             : #endif
    1525        6564 :    END SUBROUTINE cp_fm_row_scale
! **************************************************************************************************
!> \brief Inverts a cp_fm_type matrix, optionally returning the determinant of the input matrix
!> \param matrix_a the matrix to invert; with SCALAPACK it is first copied into an internal work
!>                 matrix (matrix_lu), and that copy is what gets factorized
!> \param matrix_inverse the inverse of matrix_a (the screened pseudo-inverse when eps_svd is used)
!> \param det_a the determinant of matrix_a; on the SVD path it is the product of the singular
!>              values that survived the eps_svd screening (a pseudo-determinant)
!> \param eps_svd optional parameter to activate SVD based inversion, singular values below eps_svd
!>                are screened; absent (or 0.0_dp) selects LU based inversion
!> \param eigval optionally return matrix eigenvalues/singular values. On the SCALAPACK paths it
!>               must not be associated on entry and is allocated here with n entries:
!>               - LU path: the diagonal of the LU factor (NOT the eigenvalues of matrix_a)
!>               - SVD path: the singular values, before screening
!>               Without SCALAPACK the LU path aborts if eigval is present.
!> \par History
!>      note of Jan Wilhelm (12.2015)
!>      - computation of determinant corrected
!>      - determinant only computed if det_a is present
!>      12.2016 added option to use SVD instead of LU [Nico Holmberg]
!>      - Use cp_fm_get diag instead of n times cp_fm_get_element (A. Bussy)
!> \author Florian Schiffmann(02.2007)
! **************************************************************************************************
   SUBROUTINE cp_fm_invert(matrix_a, matrix_inverse, det_a, eps_svd, eigval)

      TYPE(cp_fm_type), INTENT(IN)          :: matrix_a, matrix_inverse
      REAL(KIND=dp), INTENT(OUT), OPTIONAL     :: det_a
      REAL(KIND=dp), INTENT(IN), OPTIONAL      :: eps_svd
      REAL(KIND=dp), DIMENSION(:), POINTER, &
         INTENT(INOUT), OPTIONAL               :: eigval

      INTEGER                                  :: n
      INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
      REAL(KIND=dp)                            :: determinant, my_eps_svd
      REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
      TYPE(cp_fm_type)                :: matrix_lu

#if defined(__parallel)
      TYPE(cp_fm_type)                :: u, vt, sigma, inv_sigma_ut
      TYPE(mp_comm_type) :: group
      INTEGER                                  :: i, info, liwork, lwork, exponent_of_minus_one
      INTEGER, DIMENSION(9)                    :: desca
      LOGICAL                                  :: quenched
      REAL(KIND=dp)                            :: alpha, beta
      REAL(KIND=dp), DIMENSION(:), POINTER     :: diag
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: work
#else
      LOGICAL                                  :: sign
      REAL(KIND=dp)                            :: eps1
#endif

      ! my_eps_svd == 0 selects the LU path, any other value the SVD path
      my_eps_svd = 0.0_dp
      IF (PRESENT(eps_svd)) my_eps_svd = eps_svd

      ! work copy of matrix_a that the factorizations overwrite
      CALL cp_fm_create(matrix=matrix_lu, &
                        matrix_struct=matrix_a%matrix_struct, &
                        name="A_lu"//TRIM(ADJUSTL(cp_to_string(1)))//"MATRIX")
      CALL cp_fm_to_fm(matrix_a, matrix_lu)

      a => matrix_lu%local_data
      n = matrix_lu%matrix_struct%nrow_global
      ! ScaLAPACK needs LOCr(M_A)+MB_A pivot entries; n + nrow_block is an upper bound for that
      ALLOCATE (ipivot(n + matrix_a%matrix_struct%nrow_block))
      ipivot(:) = 0
#if defined(__parallel)
      IF (my_eps_svd .EQ. 0.0_dp) THEN
         ! Use LU decomposition
         ! NOTE(review): lwork / liwork are assigned here but not used on the LU path
         lwork = 3*n
         liwork = 3*n
         desca(:) = matrix_lu%matrix_struct%descriptor(:)
         ! NOTE(review): info is not checked; a singular matrix (info > 0) goes unnoticed
         CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)

         IF (PRESENT(det_a) .OR. PRESENT(eigval)) THEN

            ! det(A) = (-1)**(number of row interchanges) * product of the diagonal of U
            ALLOCATE (diag(n))
            diag(:) = 0.0_dp
            CALL cp_fm_get_diag(matrix_lu, diag)

            exponent_of_minus_one = 0
            determinant = 1.0_dp
            ! NOTE(review): per the ScaLAPACK documentation ipivot is indexed by LOCAL row and
            ! holds GLOBAL row numbers, while this loop runs over i = 1..n and compares with i;
            ! entries past the local rows are still 0 from the initialisation above and are
            ! counted as interchanges. Confirm the sign is right on multi-process grids.
            DO i = 1, n
               determinant = determinant*diag(i)
               IF (ipivot(i) .NE. i) THEN
                  exponent_of_minus_one = exponent_of_minus_one + 1
               END IF
            END DO
            IF (PRESENT(eigval)) THEN
               ! returns the diagonal of the LU factor, not eigenvalues
               CPASSERT(.NOT. ASSOCIATED(eigval))
               ALLOCATE (eigval(n))
               eigval(:) = diag
            END IF
            DEALLOCATE (diag)

            ! pivots are distributed, so the interchange count is summed over all processes
            group = matrix_lu%matrix_struct%para_env
            CALL group%sum(exponent_of_minus_one)

            determinant = determinant*(-1.0_dp)**exponent_of_minus_one

         END IF

         ! right-hand side: identity (assumes cp_fm_set_all puts alpha off-diagonal and beta on
         ! the diagonal), so solving A X = I with the LU factors yields X = A^-1
         alpha = 0.0_dp
         beta = 1.0_dp
         CALL cp_fm_set_all(matrix_inverse, alpha, beta)
         CALL pdgetrs('N', n, n, matrix_lu%local_data, 1, 1, desca, ipivot, matrix_inverse%local_data, 1, 1, desca, info)
      ELSE
         ! Use singular value decomposition: A = U * Sigma * V^T, A^-1 = V * Sigma^-1 * U^T
         CALL cp_fm_create(matrix=u, &
                           matrix_struct=matrix_a%matrix_struct, &
                           name="LEFT_SINGULAR_MATRIX")
         CALL cp_fm_set_all(u, alpha=0.0_dp)
         CALL cp_fm_create(matrix=vt, &
                           matrix_struct=matrix_a%matrix_struct, &
                           name="RIGHT_SINGULAR_MATRIX")
         CALL cp_fm_set_all(vt, alpha=0.0_dp)
         ALLOCATE (diag(n))
         diag(:) = 0.0_dp
         desca(:) = matrix_lu%matrix_struct%descriptor(:)
         ALLOCATE (work(1))
         ! Workspace query
         lwork = -1
         CALL pdgesvd('V', 'V', n, n, matrix_lu%local_data, 1, 1, desca, diag, u%local_data, &
                      1, 1, desca, vt%local_data, 1, 1, desca, work, lwork, info)
         lwork = INT(work(1))
         DEALLOCATE (work)
         ALLOCATE (work(lwork))
         ! SVD
         CALL pdgesvd('V', 'V', n, n, matrix_lu%local_data, 1, 1, desca, diag, u%local_data, &
                      1, 1, desca, vt%local_data, 1, 1, desca, work, lwork, info)
         ! info == n+1 implies homogeneity error when the number of procs is large
         ! this likely isnt a problem, but maybe we should handle it separately
         IF (info /= 0 .AND. info /= n + 1) &
            CPABORT("Singular value decomposition of matrix failed.")
         ! (Pseudo)inverse and (pseudo)determinant
         CALL cp_fm_create(matrix=sigma, &
                           matrix_struct=matrix_a%matrix_struct, &
                           name="SINGULAR_VALUE_MATRIX")
         CALL cp_fm_set_all(sigma, alpha=0.0_dp)
         determinant = 1.0_dp
         quenched = .FALSE.
         ! eigval receives the singular values before the eps_svd screening below
         IF (PRESENT(eigval)) THEN
            CPASSERT(.NOT. ASSOCIATED(eigval))
            ALLOCATE (eigval(n))
            eigval(:) = diag
         END IF
         ! Build Sigma^-1: singular values below the absolute threshold eps_svd are set to
         ! zero (screened), the others are inverted and enter the pseudo-determinant
         DO i = 1, n
            IF (diag(i) < my_eps_svd) THEN
               diag(i) = 0.0_dp
               quenched = .TRUE.
            ELSE
               determinant = determinant*diag(i)
               diag(i) = 1.0_dp/diag(i)
            END IF
            CALL cp_fm_set_element(sigma, i, i, diag(i))
         END DO
         DEALLOCATE (diag)
         IF (quenched) &
            CALL cp_warn(__LOCATION__, &
                         "Linear dependencies were detected in the SVD inversion of matrix "//TRIM(ADJUSTL(matrix_a%name))// &
                         ". At least one singular value has been quenched.")
         ! Sigma^-1 * U^T
         CALL cp_fm_create(matrix=inv_sigma_ut, &
                           matrix_struct=matrix_a%matrix_struct, &
                           name="SINGULAR_VALUE_MATRIX")
         CALL cp_fm_set_all(inv_sigma_ut, alpha=0.0_dp)
         CALL pdgemm('N', 'T', n, n, n, 1.0_dp, sigma%local_data, 1, 1, desca, &
                     u%local_data, 1, 1, desca, 0.0_dp, inv_sigma_ut%local_data, 1, 1, desca)
         ! A^-1 = V * (Sigma^-1 * U^T)
         CALL cp_fm_set_all(matrix_inverse, alpha=0.0_dp)
         CALL pdgemm('T', 'N', n, n, n, 1.0_dp, vt%local_data, 1, 1, desca, &
                     inv_sigma_ut%local_data, 1, 1, desca, 0.0_dp, matrix_inverse%local_data, 1, 1, desca)
         ! Clean up
         DEALLOCATE (work)
         CALL cp_fm_release(u)
         CALL cp_fm_release(vt)
         CALL cp_fm_release(sigma)
         CALL cp_fm_release(inv_sigma_ut)
      END IF
#else
      IF (my_eps_svd .EQ. 0.0_dp) THEN
         ! serial LU path: invert_matrix produces the inverse, cp_fm_lu_decompose the
         ! determinant (correct_sign requested); eps1 (eval_error) is not used further
         sign = .TRUE.
         CALL invert_matrix(matrix_a%local_data, matrix_inverse%local_data, &
                            eval_error=eps1)
         CALL cp_fm_lu_decompose(matrix_lu, determinant, correct_sign=sign)
         IF (PRESENT(eigval)) &
            CALL cp_abort(__LOCATION__, &
                          "NYI. Eigenvalues not available for return without SCALAPACK.")
      ELSE
         ! eps_svd is present here because my_eps_svd /= 0
         CALL get_pseudo_inverse_svd(matrix_a%local_data, matrix_inverse%local_data, eps_svd, &
                                     determinant, eigval)
      END IF
#endif
      CALL cp_fm_release(matrix_lu)
      DEALLOCATE (ipivot)
      IF (PRESENT(det_a)) det_a = determinant
   END SUBROUTINE cp_fm_invert
    1713             : 
    1714             : ! **************************************************************************************************
    1715             : !> \brief inverts a triangular matrix
    1716             : !> \param matrix_a ...
    1717             : !> \param uplo_tr ...
    1718             : !> \author MI
    1719             : ! **************************************************************************************************
    1720        4978 :    SUBROUTINE cp_fm_triangular_invert(matrix_a, uplo_tr)
    1721             : 
    1722             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a
    1723             :       CHARACTER, INTENT(IN), OPTIONAL          :: uplo_tr
    1724             : 
    1725             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'cp_fm_triangular_invert'
    1726             : 
    1727             :       CHARACTER                                :: unit_diag, uplo
    1728             :       INTEGER                                  :: handle, info, ncol_global
    1729        4978 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1730             : #if defined(__parallel)
    1731             :       INTEGER, DIMENSION(9)                    :: desca
    1732             : #endif
    1733             : 
    1734        4978 :       CALL timeset(routineN, handle)
    1735             : 
    1736        4978 :       unit_diag = 'N'
    1737        4978 :       uplo = 'U'
    1738        4978 :       IF (PRESENT(uplo_tr)) uplo = uplo_tr
    1739             : 
    1740        4978 :       ncol_global = matrix_a%matrix_struct%ncol_global
    1741             : 
    1742        4978 :       a => matrix_a%local_data
    1743             : 
    1744             : #if defined(__parallel)
    1745             : 
    1746       49780 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1747             : 
    1748        4978 :       CALL pdtrtri(uplo, unit_diag, ncol_global, a(1, 1), 1, 1, desca, info)
    1749             : 
    1750             : #else
    1751             :       CALL dtrtri(uplo, unit_diag, ncol_global, a(1, 1), ncol_global, info)
    1752             : #endif
    1753             : 
    1754        4978 :       CALL timestop(handle)
    1755        4978 :    END SUBROUTINE cp_fm_triangular_invert
    1756             : 
    1757             : ! **************************************************************************************************
    1758             : !> \brief  performs a QR factorization of the input rectangular matrix A or of a submatrix of A
    1759             : !>         the computed upper triangular matrix R is in output in the submatrix sub(A) of size NxN
    1760             : !>         M and M give the dimension of the submatrix that has to be factorized (MxN) with M>N
    1761             : !> \param matrix_a ...
    1762             : !> \param matrix_r ...
    1763             : !> \param nrow_fact ...
    1764             : !> \param ncol_fact ...
    1765             : !> \param first_row ...
    1766             : !> \param first_col ...
    1767             : !> \author MI
    1768             : ! **************************************************************************************************
    1769       19320 :    SUBROUTINE cp_fm_qr_factorization(matrix_a, matrix_r, nrow_fact, ncol_fact, first_row, first_col)
    1770             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a, matrix_r
    1771             :       INTEGER, INTENT(IN), OPTIONAL            :: nrow_fact, ncol_fact, &
    1772             :                                                   first_row, first_col
    1773             : 
    1774             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'cp_fm_qr_factorization'
    1775             : 
    1776             :       INTEGER                                  :: handle, i, icol, info, irow, &
    1777             :                                                   j, lda, lwork, ncol, &
    1778             :                                                   ndim, nrow
    1779       19320 :       REAL(dp), ALLOCATABLE, DIMENSION(:)      :: tau, work
    1780       19320 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)   :: r_mat
    1781             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    1782             : #if defined(__parallel)
    1783             :       INTEGER, DIMENSION(9)                    :: desca
    1784             : #endif
    1785             : 
    1786       19320 :       CALL timeset(routineN, handle)
    1787             : 
    1788       19320 :       ncol = matrix_a%matrix_struct%ncol_global
    1789       19320 :       nrow = matrix_a%matrix_struct%nrow_global
    1790       19320 :       lda = nrow
    1791             : 
    1792       19320 :       a => matrix_a%local_data
    1793             : 
    1794       19320 :       IF (PRESENT(nrow_fact)) nrow = nrow_fact
    1795       19320 :       IF (PRESENT(ncol_fact)) ncol = ncol_fact
    1796       19320 :       irow = 1
    1797       19320 :       IF (PRESENT(first_row)) irow = first_row
    1798       19320 :       icol = 1
    1799       19320 :       IF (PRESENT(first_col)) icol = first_col
    1800             : 
    1801       19320 :       CPASSERT(nrow >= ncol)
    1802       19320 :       ndim = SIZE(a, 2)
    1803             : !    ALLOCATE(ipiv(ndim))
    1804       57960 :       ALLOCATE (tau(ndim))
    1805             : 
    1806             : #if defined(__parallel)
    1807             : 
    1808      193200 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1809             : 
    1810       19320 :       lwork = -1
    1811       57960 :       ALLOCATE (work(2*ndim))
    1812       19320 :       CALL pdgeqrf(nrow, ncol, a, irow, icol, desca, tau, work, lwork, info)
    1813       19320 :       lwork = INT(work(1))
    1814       19320 :       DEALLOCATE (work)
    1815       57960 :       ALLOCATE (work(lwork))
    1816       19320 :       CALL pdgeqrf(nrow, ncol, a, irow, icol, desca, tau, work, lwork, info)
    1817             : 
    1818             : #else
    1819             :       lwork = -1
    1820             :       ALLOCATE (work(2*ndim))
    1821             :       CALL dgeqrf(nrow, ncol, a, lda, tau, work, lwork, info)
    1822             :       lwork = INT(work(1))
    1823             :       DEALLOCATE (work)
    1824             :       ALLOCATE (work(lwork))
    1825             :       CALL dgeqrf(nrow, ncol, a, lda, tau, work, lwork, info)
    1826             : 
    1827             : #endif
    1828             : 
    1829       77280 :       ALLOCATE (r_mat(ncol, ncol))
    1830       19320 :       CALL cp_fm_get_submatrix(matrix_a, r_mat, 1, 1, ncol, ncol)
    1831       38640 :       DO i = 1, ncol
    1832       38640 :          DO j = i + 1, ncol
    1833       19320 :             r_mat(j, i) = 0.0_dp
    1834             :          END DO
    1835             :       END DO
    1836       19320 :       CALL cp_fm_set_submatrix(matrix_r, r_mat, 1, 1, ncol, ncol)
    1837             : 
    1838       19320 :       DEALLOCATE (tau, work, r_mat)
    1839             : 
    1840       19320 :       CALL timestop(handle)
    1841             : 
    1842       19320 :    END SUBROUTINE cp_fm_qr_factorization
    1843             : 
    1844             : ! **************************************************************************************************
    1845             : !> \brief computes the the solution to A*b=A_general using lu decomposition
    1846             : !>        pay attention, both matrices are overwritten, a_general contais the result
    1847             : !> \param matrix_a ...
    1848             : !> \param general_a ...
    1849             : !> \author Florian Schiffmann
    1850             : ! **************************************************************************************************
    1851        4296 :    SUBROUTINE cp_fm_solve(matrix_a, general_a)
    1852             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix_a, general_a
    1853             : 
    1854             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_solve'
    1855             : 
    1856             :       INTEGER                                  :: handle, info, n
    1857        4296 :       INTEGER, ALLOCATABLE, DIMENSION(:)       :: ipivot
    1858             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a, a_general
    1859             : #if defined(__parallel)
    1860             :       INTEGER, DIMENSION(9)                    :: desca, descb
    1861             : #else
    1862             :       INTEGER                                  :: lda, ldb
    1863             : #endif
    1864             : 
    1865        4296 :       CALL timeset(routineN, handle)
    1866             : 
    1867        4296 :       a => matrix_a%local_data
    1868        4296 :       a_general => general_a%local_data
    1869        4296 :       n = matrix_a%matrix_struct%nrow_global
    1870       12888 :       ALLOCATE (ipivot(n + matrix_a%matrix_struct%nrow_block))
    1871             : 
    1872             : #if defined(__parallel)
    1873       42960 :       desca(:) = matrix_a%matrix_struct%descriptor(:)
    1874       42960 :       descb(:) = general_a%matrix_struct%descriptor(:)
    1875        4296 :       CALL pdgetrf(n, n, a, 1, 1, desca, ipivot, info)
    1876             :       CALL pdgetrs("N", n, n, a, 1, 1, desca, ipivot, a_general, &
    1877        4296 :                    1, 1, descb, info)
    1878             : 
    1879             : #else
    1880             :       lda = SIZE(a, 1)
    1881             :       ldb = SIZE(a_general, 1)
    1882             :       CALL dgetrf(n, n, a, lda, ipivot, info)
    1883             :       CALL dgetrs("N", n, n, a, lda, ipivot, a_general, ldb, info)
    1884             : 
    1885             : #endif
    1886             :       ! info is allowed to be zero
    1887             :       ! this does just signal a zero diagonal element
    1888        4296 :       DEALLOCATE (ipivot)
    1889        4296 :       CALL timestop(handle)
    1890        4296 :    END SUBROUTINE
    1891             : 
    1892             : ! **************************************************************************************************
    1893             : !> \brief Convenience function. Computes the matrix multiplications needed
    1894             : !>        for the multiplication of complex matrices.
    1895             : !>        C = beta * C + alpha * ( A  ** transa ) * ( B ** transb )
    1896             : !> \param transa : 'N' -> normal   'T' -> transpose
    1897             : !>      alpha,beta :: can be 0.0_dp and 1.0_dp
    1898             : !> \param transb ...
    1899             : !> \param m ...
    1900             : !> \param n ...
    1901             : !> \param k ...
    1902             : !> \param alpha ...
    1903             : !> \param A_re m x k matrix ( ! for transa = 'N'), real part
    1904             : !> \param A_im m x k matrix ( ! for transa = 'N'), imaginary part
    1905             : !> \param B_re k x n matrix ( ! for transa = 'N'), real part
    1906             : !> \param B_im k x n matrix ( ! for transa = 'N'), imaginary part
    1907             : !> \param beta ...
    1908             : !> \param C_re m x n matrix, real part
    1909             : !> \param C_im m x n matrix, imaginary part
    1910             : !> \param a_first_col ...
    1911             : !> \param a_first_row ...
    1912             : !> \param b_first_col : the k x n matrix starts at col b_first_col of matrix_b (avoid usage)
    1913             : !> \param b_first_row ...
    1914             : !> \param c_first_col ...
    1915             : !> \param c_first_row ...
    1916             : !> \author Samuel Andermatt
    1917             : !> \note
    1918             : !>      C should have no overlap with A, B
    1919             : ! **************************************************************************************************
    1920           0 :    SUBROUTINE cp_complex_fm_gemm(transa, transb, m, n, k, alpha, A_re, A_im, B_re, B_im, beta, &
    1921             :                                  C_re, C_im, a_first_col, a_first_row, b_first_col, b_first_row, c_first_col, &
    1922             :                                  c_first_row)
    1923             :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
    1924             :       INTEGER, INTENT(IN)                                :: m, n, k
    1925             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
    1926             :       TYPE(cp_fm_type), INTENT(IN)                       :: A_re, A_im, B_re, B_im
    1927             :       REAL(KIND=dp), INTENT(IN)                          :: beta
    1928             :       TYPE(cp_fm_type), INTENT(IN)                       :: C_re, C_im
    1929             :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
    1930             :                                                             b_first_row, c_first_col, c_first_row
    1931             : 
    1932             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_complex_fm_gemm'
    1933             : 
    1934             :       INTEGER                                            :: handle
    1935             : 
    1936           0 :       CALL timeset(routineN, handle)
    1937             : 
    1938             :       CALL cp_fm_gemm(transa, transb, m, n, k, alpha, A_re, B_re, beta, C_re, &
    1939             :                       a_first_col=a_first_col, &
    1940             :                       a_first_row=a_first_row, &
    1941             :                       b_first_col=b_first_col, &
    1942             :                       b_first_row=b_first_row, &
    1943             :                       c_first_col=c_first_col, &
    1944           0 :                       c_first_row=c_first_row)
    1945             :       CALL cp_fm_gemm(transa, transb, m, n, k, -alpha, A_im, B_im, 1.0_dp, C_re, &
    1946             :                       a_first_col=a_first_col, &
    1947             :                       a_first_row=a_first_row, &
    1948             :                       b_first_col=b_first_col, &
    1949             :                       b_first_row=b_first_row, &
    1950             :                       c_first_col=c_first_col, &
    1951           0 :                       c_first_row=c_first_row)
    1952             :       CALL cp_fm_gemm(transa, transb, m, n, k, alpha, A_re, B_im, beta, C_im, &
    1953             :                       a_first_col=a_first_col, &
    1954             :                       a_first_row=a_first_row, &
    1955             :                       b_first_col=b_first_col, &
    1956             :                       b_first_row=b_first_row, &
    1957             :                       c_first_col=c_first_col, &
    1958           0 :                       c_first_row=c_first_row)
    1959             :       CALL cp_fm_gemm(transa, transb, m, n, k, alpha, A_im, B_re, 1.0_dp, C_im, &
    1960             :                       a_first_col=a_first_col, &
    1961             :                       a_first_row=a_first_row, &
    1962             :                       b_first_col=b_first_col, &
    1963             :                       b_first_row=b_first_row, &
    1964             :                       c_first_col=c_first_col, &
    1965           0 :                       c_first_row=c_first_row)
    1966             : 
    1967           0 :       CALL timestop(handle)
    1968             : 
    1969           0 :    END SUBROUTINE cp_complex_fm_gemm
    1970             : 
    1971             : ! **************************************************************************************************
    1972             : !> \brief inverts a matrix using LU decomposition
    1973             : !>        the input matrix will be overwritten
    1974             : !> \param matrix   : input a general square non-singular matrix, outputs its inverse
    1975             : !> \param info_out : optional, if present outputs the info from (p)zgetri
    1976             : !> \author Lianheng Tong
    1977             : ! **************************************************************************************************
    1978           0 :    SUBROUTINE cp_fm_lu_invert(matrix, info_out)
    1979             :       TYPE(cp_fm_type), INTENT(IN)          :: matrix
    1980             :       INTEGER, INTENT(OUT), OPTIONAL           :: info_out
    1981             : 
    1982             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_lu_invert'
    1983             : 
    1984             :       INTEGER :: nrows_global, handle, info, lwork
    1985           0 :       INTEGER, DIMENSION(:), ALLOCATABLE       :: ipivot
    1986             :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: mat
    1987             :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: mat_sp
    1988           0 :       REAL(KIND=dp), DIMENSION(:), ALLOCATABLE :: work
    1989           0 :       REAL(KIND=sp), DIMENSION(:), ALLOCATABLE :: work_sp
    1990             : #if defined(__parallel)
    1991             :       INTEGER                                  :: liwork
    1992             :       INTEGER, DIMENSION(9)                    :: desca
    1993           0 :       INTEGER, DIMENSION(:), ALLOCATABLE       :: iwork
    1994             : #else
    1995             :       INTEGER                                  :: lda
    1996             : #endif
    1997             : 
    1998           0 :       CALL timeset(routineN, handle)
    1999             : 
    2000           0 :       mat => matrix%local_data
    2001           0 :       mat_sp => matrix%local_data_sp
    2002           0 :       nrows_global = matrix%matrix_struct%nrow_global
    2003           0 :       CPASSERT(nrows_global .EQ. matrix%matrix_struct%ncol_global)
    2004           0 :       ALLOCATE (ipivot(nrows_global))
    2005             :       ! do LU decomposition
    2006             : #if defined(__parallel)
    2007           0 :       desca = matrix%matrix_struct%descriptor
    2008           0 :       IF (matrix%use_sp) THEN
    2009             :          CALL psgetrf(nrows_global, nrows_global, &
    2010           0 :                       mat_sp, 1, 1, desca, ipivot, info)
    2011             :       ELSE
    2012             :          CALL pdgetrf(nrows_global, nrows_global, &
    2013           0 :                       mat, 1, 1, desca, ipivot, info)
    2014             :       END IF
    2015             : #else
    2016             :       lda = SIZE(mat, 1)
    2017             :       IF (matrix%use_sp) THEN
    2018             :          CALL sgetrf(nrows_global, nrows_global, &
    2019             :                      mat_sp, lda, ipivot, info)
    2020             :       ELSE
    2021             :          CALL dgetrf(nrows_global, nrows_global, &
    2022             :                      mat, lda, ipivot, info)
    2023             :       END IF
    2024             : #endif
    2025           0 :       IF (info /= 0) THEN
    2026           0 :          CALL cp_abort(__LOCATION__, "LU decomposition has failed")
    2027             :       END IF
    2028             :       ! do inversion
    2029           0 :       IF (matrix%use_sp) THEN
    2030           0 :          ALLOCATE (work(1))
    2031             :       ELSE
    2032           0 :          ALLOCATE (work_sp(1))
    2033             :       END IF
    2034             : #if defined(__parallel)
    2035           0 :       ALLOCATE (iwork(1))
    2036           0 :       IF (matrix%use_sp) THEN
    2037             :          CALL psgetri(nrows_global, mat_sp, 1, 1, desca, &
    2038           0 :                       ipivot, work_sp, -1, iwork, -1, info)
    2039           0 :          lwork = INT(work_sp(1))
    2040           0 :          DEALLOCATE (work_sp)
    2041           0 :          ALLOCATE (work_sp(lwork))
    2042             :       ELSE
    2043             :          CALL pdgetri(nrows_global, mat, 1, 1, desca, &
    2044           0 :                       ipivot, work, -1, iwork, -1, info)
    2045           0 :          lwork = INT(work(1))
    2046           0 :          DEALLOCATE (work)
    2047           0 :          ALLOCATE (work(lwork))
    2048             :       END IF
    2049           0 :       liwork = INT(iwork(1))
    2050           0 :       DEALLOCATE (iwork)
    2051           0 :       ALLOCATE (iwork(liwork))
    2052           0 :       IF (matrix%use_sp) THEN
    2053             :          CALL psgetri(nrows_global, mat_sp, 1, 1, desca, &
    2054           0 :                       ipivot, work_sp, lwork, iwork, liwork, info)
    2055             :       ELSE
    2056             :          CALL pdgetri(nrows_global, mat, 1, 1, desca, &
    2057           0 :                       ipivot, work, lwork, iwork, liwork, info)
    2058             :       END IF
    2059           0 :       DEALLOCATE (iwork)
    2060             : #else
    2061             :       IF (matrix%use_sp) THEN
    2062             :          CALL sgetri(nrows_global, mat_sp, lda, &
    2063             :                      ipivot, work_sp, -1, info)
    2064             :          lwork = INT(work_sp(1))
    2065             :          DEALLOCATE (work_sp)
    2066             :          ALLOCATE (work_sp(lwork))
    2067             :          CALL sgetri(nrows_global, mat_sp, lda, &
    2068             :                      ipivot, work_sp, lwork, info)
    2069             :       ELSE
    2070             :          CALL dgetri(nrows_global, mat, lda, &
    2071             :                      ipivot, work, -1, info)
    2072             :          lwork = INT(work(1))
    2073             :          DEALLOCATE (work)
    2074             :          ALLOCATE (work(lwork))
    2075             :          CALL dgetri(nrows_global, mat, lda, &
    2076             :                      ipivot, work, lwork, info)
    2077             :       END IF
    2078             : #endif
    2079           0 :       IF (matrix%use_sp) THEN
    2080           0 :          DEALLOCATE (work_sp)
    2081             :       ELSE
    2082           0 :          DEALLOCATE (work)
    2083             :       END IF
    2084           0 :       DEALLOCATE (ipivot)
    2085             : 
    2086           0 :       IF (PRESENT(info_out)) THEN
    2087           0 :          info_out = info
    2088             :       ELSE
    2089           0 :          IF (info /= 0) &
    2090           0 :             CALL cp_abort(__LOCATION__, "LU inversion has failed")
    2091             :       END IF
    2092             : 
    2093           0 :       CALL timestop(handle)
    2094             : 
    2095           0 :    END SUBROUTINE cp_fm_lu_invert
    2096             : 
    2097             : ! **************************************************************************************************
    2098             : !> \brief norm of matrix using (p)dlange
    2099             : !> \param matrix   : input a general matrix
    2100             : !> \param mode     : 'M' max abs element value,
    2101             : !>                   '1' or 'O' one norm, i.e. maximum column sum
    2102             : !>                   'I' infinity norm, i.e. maximum row sum
    2103             : !>                   'F' or 'E' Frobenius norm, i.e. sqrt of sum of all squares of elements
    2104             : !> \return : the norm according to mode
    2105             : !> \author Lianheng Tong
    2106             : ! **************************************************************************************************
    2107         492 :    FUNCTION cp_fm_norm(matrix, mode) RESULT(res)
    2108             :       TYPE(cp_fm_type), INTENT(IN) :: matrix
    2109             :       CHARACTER, INTENT(IN) :: mode
    2110             :       REAL(KIND=dp) :: res
    2111             : 
    2112             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_norm'
    2113             : 
    2114             :       INTEGER :: nrows, ncols, handle, lwork, nrows_local, ncols_local
    2115             :       REAL(KIND=sp) :: res_sp
    2116             :       REAL(KIND=dp), DIMENSION(:, :), POINTER :: aa
    2117             :       REAL(KIND=sp), DIMENSION(:, :), POINTER :: aa_sp
    2118         492 :       REAL(KIND=dp), DIMENSION(:), ALLOCATABLE :: work
    2119         492 :       REAL(KIND=sp), DIMENSION(:), ALLOCATABLE :: work_sp
    2120             : #if defined(__parallel)
    2121             :       INTEGER, DIMENSION(9) :: desca
    2122             : #else
    2123             :       INTEGER :: lda
    2124             : #endif
    2125             : 
    2126         492 :       CALL timeset(routineN, handle)
    2127             : 
    2128             :       CALL cp_fm_get_info(matrix=matrix, &
    2129             :                           nrow_global=nrows, &
    2130             :                           ncol_global=ncols, &
    2131             :                           nrow_local=nrows_local, &
    2132         492 :                           ncol_local=ncols_local)
    2133         492 :       aa => matrix%local_data
    2134         492 :       aa_sp => matrix%local_data_sp
    2135             : 
    2136             : #if defined(__parallel)
    2137        4920 :       desca = matrix%matrix_struct%descriptor
    2138             :       SELECT CASE (mode)
    2139             :       CASE ('M', 'm')
    2140         492 :          lwork = 1
    2141             :       CASE ('1', 'O', 'o')
    2142         492 :          lwork = ncols_local
    2143             :       CASE ('I', 'i')
    2144           0 :          lwork = nrows_local
    2145             :       CASE ('F', 'f', 'E', 'e')
    2146           0 :          lwork = 1
    2147             :       CASE DEFAULT
    2148         492 :          CPABORT("mode input is not valid")
    2149             :       END SELECT
    2150         492 :       IF (matrix%use_sp) THEN
    2151           0 :          ALLOCATE (work_sp(lwork))
    2152           0 :          res_sp = pslange(mode, nrows, ncols, aa_sp, 1, 1, desca, work_sp)
    2153           0 :          DEALLOCATE (work_sp)
    2154           0 :          res = REAL(res_sp, KIND=dp)
    2155             :       ELSE
    2156        1476 :          ALLOCATE (work(lwork))
    2157         492 :          res = pdlange(mode, nrows, ncols, aa, 1, 1, desca, work)
    2158         492 :          DEALLOCATE (work)
    2159             :       END IF
    2160             : #else
    2161             :       SELECT CASE (mode)
    2162             :       CASE ('M', 'm')
    2163             :          lwork = 1
    2164             :       CASE ('1', 'O', 'o')
    2165             :          lwork = 1
    2166             :       CASE ('I', 'i')
    2167             :          lwork = nrows
    2168             :       CASE ('F', 'f', 'E', 'e')
    2169             :          lwork = 1
    2170             :       CASE DEFAULT
    2171             :          CPABORT("mode input is not valid")
    2172             :       END SELECT
    2173             :       IF (matrix%use_sp) THEN
    2174             :          ALLOCATE (work_sp(lwork))
    2175             :          lda = SIZE(aa_sp, 1)
    2176             :          res_sp = slange(mode, nrows, ncols, aa_sp, lda, work_sp)
    2177             :          DEALLOCATE (work_sp)
    2178             :          res = REAL(res_sp, KIND=dp)
    2179             :       ELSE
    2180             :          ALLOCATE (work(lwork))
    2181             :          lda = SIZE(aa, 1)
    2182             :          res = dlange(mode, nrows, ncols, aa, lda, work)
    2183             :          DEALLOCATE (work)
    2184             :       END IF
    2185             : #endif
    2186             : 
    2187         492 :       CALL timestop(handle)
    2188             : 
    2189         492 :    END FUNCTION cp_fm_norm
    2190             : 
    2191             : ! **************************************************************************************************
    2192             : !> \brief trace of a matrix using pdlatra
    2193             : !> \param matrix   : input a square matrix
    2194             : !> \return : the trace
    2195             : !> \author Lianheng Tong
    2196             : ! **************************************************************************************************
    2197           0 :    FUNCTION cp_fm_latra(matrix) RESULT(res)
    2198             :       TYPE(cp_fm_type), INTENT(IN) :: matrix
    2199             :       REAL(KIND=dp) :: res
    2200             : 
    2201             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_latra'
    2202             : 
    2203             :       INTEGER :: nrows, ncols, handle
    2204             :       REAL(KIND=sp) :: res_sp
    2205             :       REAL(KIND=dp), DIMENSION(:, :), POINTER :: aa
    2206             :       REAL(KIND=sp), DIMENSION(:, :), POINTER :: aa_sp
    2207             : #if defined(__parallel)
    2208             :       INTEGER, DIMENSION(9) :: desca
    2209             : #else
    2210             :       INTEGER :: ii
    2211             : #endif
    2212             : 
    2213           0 :       CALL timeset(routineN, handle)
    2214             : 
    2215           0 :       nrows = matrix%matrix_struct%nrow_global
    2216           0 :       ncols = matrix%matrix_struct%ncol_global
    2217           0 :       CPASSERT(nrows .EQ. ncols)
    2218           0 :       aa => matrix%local_data
    2219           0 :       aa_sp => matrix%local_data_sp
    2220             : 
    2221             : #if defined(__parallel)
    2222           0 :       desca = matrix%matrix_struct%descriptor
    2223           0 :       IF (matrix%use_sp) THEN
    2224           0 :          res_sp = pslatra(nrows, aa_sp, 1, 1, desca)
    2225           0 :          res = REAL(res_sp, KIND=dp)
    2226             :       ELSE
    2227           0 :          res = pdlatra(nrows, aa, 1, 1, desca)
    2228             :       END IF
    2229             : #else
    2230             :       IF (matrix%use_sp) THEN
    2231             :          res_sp = 0.0_sp
    2232             :          DO ii = 1, nrows
    2233             :             res_sp = res_sp + aa_sp(ii, ii)
    2234             :          END DO
    2235             :          res = REAL(res_sp, KIND=dp)
    2236             :       ELSE
    2237             :          res = 0.0_dp
    2238             :          DO ii = 1, nrows
    2239             :             res = res + aa(ii, ii)
    2240             :          END DO
    2241             :       END IF
    2242             : #endif
    2243             : 
    2244           0 :       CALL timestop(handle)
    2245             : 
    2246           0 :    END FUNCTION cp_fm_latra
    2247             : 
    2248             : ! **************************************************************************************************
    2249             : !> \brief compute a QR factorization with column pivoting of a M-by-N distributed matrix
    2250             : !>        sub( A ) = A(IA:IA+M-1,JA:JA+N-1)
    2251             : !> \param matrix   : input M-by-N distributed matrix sub( A ) which is to be factored
    2252             : !> \param tau      :  scalar factors TAU of the elementary reflectors. TAU is tied to the distributed matrix A
    2253             : !> \param nrow ...
    2254             : !> \param ncol ...
    2255             : !> \param first_row ...
    2256             : !> \param first_col ...
    2257             : !> \author MI
    2258             : ! **************************************************************************************************
    2259          36 :    SUBROUTINE cp_fm_pdgeqpf(matrix, tau, nrow, ncol, first_row, first_col)
    2260             : 
    2261             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix
    2262             :       REAL(KIND=dp), DIMENSION(:), POINTER               :: tau
    2263             :       INTEGER, INTENT(IN)                                :: nrow, ncol, first_row, first_col
    2264             : 
    2265             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_pdgeqpf'
    2266             : 
    2267             :       INTEGER                                            :: handle
    2268             :       INTEGER                                            :: info, lwork
    2269          36 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ipiv
    2270             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a
    2271             :       REAL(KIND=dp), DIMENSION(:), POINTER               :: work
    2272             : #if defined(__parallel)
    2273             :       INTEGER, DIMENSION(9) :: descc
    2274             : #else
    2275             :       INTEGER :: lda
    2276             : #endif
    2277             : 
    2278          36 :       CALL timeset(routineN, handle)
    2279             : 
    2280          36 :       a => matrix%local_data
    2281          36 :       lwork = -1
    2282         108 :       ALLOCATE (work(2*nrow))
    2283         108 :       ALLOCATE (ipiv(ncol))
    2284          36 :       info = 0
    2285             : 
    2286             : #if defined(__parallel)
    2287         360 :       descc(:) = matrix%matrix_struct%descriptor(:)
    2288             :       ! Call SCALAPACK routine to get optimal work dimension
    2289          36 :       CALL pdgeqpf(nrow, ncol, a, first_row, first_col, descc, ipiv, tau, work, lwork, info)
    2290          36 :       lwork = INT(work(1))
    2291          36 :       DEALLOCATE (work)
    2292         108 :       ALLOCATE (work(lwork))
    2293         244 :       tau = 0.0_dp
    2294         354 :       ipiv = 0
    2295             : 
    2296             :       ! Call SCALAPACK routine to get QR decomposition of CTs
    2297          36 :       CALL pdgeqpf(nrow, ncol, a, first_row, first_col, descc, ipiv, tau, work, lwork, info)
    2298             : #else
    2299             :       CPASSERT(first_row == 1 .AND. first_col == 1)
    2300             :       lda = SIZE(a, 1)
    2301             :       CALL dgeqp3(nrow, ncol, a, lda, ipiv, tau, work, lwork, info)
    2302             :       lwork = INT(work(1))
    2303             :       DEALLOCATE (work)
    2304             :       ALLOCATE (work(lwork))
    2305             :       tau = 0.0_dp
    2306             :       ipiv = 0
    2307             :       CALL dgeqp3(nrow, ncol, a, lda, ipiv, tau, work, lwork, info)
    2308             : #endif
    2309          36 :       CPASSERT(info == 0)
    2310             : 
    2311          36 :       DEALLOCATE (work)
    2312          36 :       DEALLOCATE (ipiv)
    2313             : 
    2314          36 :       CALL timestop(handle)
    2315             : 
    2316          36 :    END SUBROUTINE cp_fm_pdgeqpf
    2317             : 
    2318             : ! **************************************************************************************************
    2319             : !> \brief generates an M-by-N real distributed matrix Q denoting A(IA:IA+M-1,JA:JA+N-1)
    2320             : !>         with orthonormal columns, which is defined as the first N columns of a product of K
    2321             : !>         elementary reflectors of order M
    2322             : !> \param matrix : On entry, the j-th column must contain the vector which defines the elementary reflector
    2323             : !>                  as returned from PDGEQRF
    2324             : !>                 On exit it contains  the M-by-N distributed matrix Q
    2325             : !> \param tau :   contains the scalar factors TAU of elementary reflectors  as returned by PDGEQRF
    2326             : !> \param nrow ...
    2327             : !> \param first_row ...
    2328             : !> \param first_col ...
    2329             : !> \author MI
    2330             : ! **************************************************************************************************
    2331          36 :    SUBROUTINE cp_fm_pdorgqr(matrix, tau, nrow, first_row, first_col)
    2332             : 
    2333             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix
    2334             :       REAL(KIND=dp), DIMENSION(:), POINTER               :: tau
    2335             :       INTEGER, INTENT(IN)                                :: nrow, first_row, first_col
    2336             : 
    2337             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_pdorgqr'
    2338             : 
    2339             :       INTEGER                                            :: handle
    2340             :       INTEGER                                            :: info, lwork
    2341             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: a
    2342             :       REAL(KIND=dp), DIMENSION(:), POINTER               :: work
    2343             : #if defined(__parallel)
    2344             :       INTEGER, DIMENSION(9) :: descc
    2345             : #else
    2346             :       INTEGER :: lda
    2347             : #endif
    2348             : 
    2349          36 :       CALL timeset(routineN, handle)
    2350             : 
    2351          36 :       a => matrix%local_data
    2352          36 :       lwork = -1
    2353         108 :       ALLOCATE (work(2*nrow))
    2354          36 :       info = 0
    2355             : 
    2356             : #if defined(__parallel)
    2357         360 :       descc(:) = matrix%matrix_struct%descriptor(:)
    2358             : 
    2359          36 :       CALL pdorgqr(nrow, nrow, nrow, a, first_row, first_col, descc, tau, work, lwork, info)
    2360          36 :       CPASSERT(info == 0)
    2361          36 :       lwork = INT(work(1))
    2362          36 :       DEALLOCATE (work)
    2363         108 :       ALLOCATE (work(lwork))
    2364             : 
    2365             :       ! Call SCALAPACK routine to get Q
    2366          36 :       CALL pdorgqr(nrow, nrow, nrow, a, first_row, first_col, descc, tau, work, lwork, info)
    2367             : #else
    2368             :       CPASSERT(first_row == 1 .AND. first_col == 1)
    2369             :       lda = SIZE(a, 1)
    2370             :       CALL dorgqr(nrow, nrow, nrow, a, lda, tau, work, lwork, info)
    2371             :       lwork = INT(work(1))
    2372             :       DEALLOCATE (work)
    2373             :       ALLOCATE (work(lwork))
    2374             :       CALL dorgqr(nrow, nrow, nrow, a, lda, tau, work, lwork, info)
    2375             : #endif
    2376          36 :       CPASSERT(INFO == 0)
    2377             : 
    2378          36 :       DEALLOCATE (work)
    2379          36 :       CALL timestop(handle)
    2380             : 
    2381          36 :    END SUBROUTINE cp_fm_pdorgqr
    2382             : 
    2383             : ! **************************************************************************************************
    2384             : !> \brief Applies a planar rotation defined by cs and sn to the i'th and j'th rows.
    2385             : !> \param cs cosine of the rotation angle
    2386             : !> \param sn sinus of the rotation angle
    2387             : !> \param irow ...
    2388             : !> \param jrow ...
    2389             : !> \author Ole Schuett
    2390             : ! **************************************************************************************************
    2391      543328 :    SUBROUTINE cp_fm_rot_rows(matrix, irow, jrow, cs, sn)
    2392             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix
    2393             :       INTEGER, INTENT(IN)                      :: irow, jrow
    2394             :       REAL(dp), INTENT(IN)                     :: cs, sn
    2395             : 
    2396             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_rot_rows'
    2397             :       INTEGER                                  :: handle, nrow, ncol
    2398             : 
    2399             : #if defined(__parallel)
    2400             :       INTEGER                                  :: info, lwork
    2401             :       INTEGER, DIMENSION(9)                    :: desc
    2402      543328 :       REAL(dp), DIMENSION(:), ALLOCATABLE      :: work
    2403             : #endif
    2404             : 
    2405      543328 :       CALL timeset(routineN, handle)
    2406      543328 :       CALL cp_fm_get_info(matrix, nrow_global=nrow, ncol_global=ncol)
    2407             : 
    2408             : #if defined(__parallel)
    2409      543328 :       lwork = 2*ncol + 1
    2410     1629984 :       ALLOCATE (work(lwork))
    2411     5433280 :       desc(:) = matrix%matrix_struct%descriptor(:)
    2412             :       CALL pdrot(ncol, &
    2413             :                  matrix%local_data(1, 1), irow, 1, desc, ncol, &
    2414             :                  matrix%local_data(1, 1), jrow, 1, desc, ncol, &
    2415      543328 :                  cs, sn, work, lwork, info)
    2416      543328 :       CPASSERT(info == 0)
    2417      543328 :       DEALLOCATE (work)
    2418             : #else
    2419             :       CALL drot(ncol, matrix%local_data(irow, 1), ncol, matrix%local_data(jrow, 1), ncol, cs, sn)
    2420             : #endif
    2421             : 
    2422      543328 :       CALL timestop(handle)
    2423      543328 :    END SUBROUTINE cp_fm_rot_rows
    2424             : 
    2425             : ! **************************************************************************************************
    2426             : !> \brief Applies a planar rotation defined by cs and sn to the i'th and j'th columnns.
    2427             : !> \param cs cosine of the rotation angle
    2428             : !> \param sn sinus of the rotation angle
    2429             : !> \param icol ...
    2430             : !> \param jcol ...
    2431             : !> \author Ole Schuett
    2432             : ! **************************************************************************************************
    2433      612158 :    SUBROUTINE cp_fm_rot_cols(matrix, icol, jcol, cs, sn)
    2434             :       TYPE(cp_fm_type), INTENT(IN)             :: matrix
    2435             :       INTEGER, INTENT(IN)                      :: icol, jcol
    2436             :       REAL(dp), INTENT(IN)                     :: cs, sn
    2437             : 
    2438             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_rot_cols'
    2439             :       INTEGER                                  :: handle, nrow, ncol
    2440             : 
    2441             : #if defined(__parallel)
    2442             :       INTEGER                                  :: info, lwork
    2443             :       INTEGER, DIMENSION(9)                    :: desc
    2444      612158 :       REAL(dp), DIMENSION(:), ALLOCATABLE      :: work
    2445             : #endif
    2446             : 
    2447      612158 :       CALL timeset(routineN, handle)
    2448      612158 :       CALL cp_fm_get_info(matrix, nrow_global=nrow, ncol_global=ncol)
    2449             : 
    2450             : #if defined(__parallel)
    2451      612158 :       lwork = 2*nrow + 1
    2452     1836474 :       ALLOCATE (work(lwork))
    2453     6121580 :       desc(:) = matrix%matrix_struct%descriptor(:)
    2454             :       CALL pdrot(nrow, &
    2455             :                  matrix%local_data(1, 1), 1, icol, desc, 1, &
    2456             :                  matrix%local_data(1, 1), 1, jcol, desc, 1, &
    2457      612158 :                  cs, sn, work, lwork, info)
    2458      612158 :       CPASSERT(info == 0)
    2459      612158 :       DEALLOCATE (work)
    2460             : #else
    2461             :       CALL drot(nrow, matrix%local_data(1, icol), 1, matrix%local_data(1, jcol), 1, cs, sn)
    2462             : #endif
    2463             : 
    2464      612158 :       CALL timestop(handle)
    2465      612158 :    END SUBROUTINE cp_fm_rot_cols
    2466             : 
    2467             : ! **************************************************************************************************
    2468             : !> \brief Orthonormalizes selected rows and columns of a full matrix, matrix_a
    2469             : !> \param matrix_a ...
    2470             : !> \param B ...
    2471             : !> \param nrows number of rows of matrix_a, optional, defaults to size(matrix_a,1)
    2472             : !> \param ncols number of columns of matrix_a, optional, defaults to size(matrix_a, 2)
    2473             : !> \param start_row starting index of rows, optional, defaults to 1
    2474             : !> \param start_col starting index of columns, optional, defaults to 1
    2475             : !> \param do_norm ...
    2476             : !> \param do_print ...
    2477             : ! **************************************************************************************************
    2478           0 :    SUBROUTINE cp_fm_Gram_Schmidt_orthonorm(matrix_a, B, nrows, ncols, start_row, start_col, &
    2479             :                                            do_norm, do_print)
    2480             : 
    2481             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a
    2482             :       REAL(kind=dp), DIMENSION(:, :), INTENT(OUT)        :: B
    2483             :       INTEGER, INTENT(IN), OPTIONAL                      :: nrows, ncols, start_row, start_col
    2484             :       LOGICAL, INTENT(IN), OPTIONAL                      :: do_norm, do_print
    2485             : 
    2486             :       CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_Gram_Schmidt_orthonorm'
    2487             : 
    2488             :       INTEGER :: end_col_global, end_col_local, end_row_global, end_row_local, handle, i, j, &
    2489             :                  j_col, ncol_global, ncol_local, nrow_global, nrow_local, start_col_global, &
    2490             :                  start_col_local, start_row_global, start_row_local, this_col, unit_nr
    2491           0 :       INTEGER, DIMENSION(:), POINTER                     :: col_indices, row_indices
    2492             :       LOGICAL                                            :: my_do_norm, my_do_print
    2493             :       REAL(KIND=dp)                                      :: norm
    2494           0 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: a
    2495             : 
    2496           0 :       CALL timeset(routineN, handle)
    2497             : 
    2498           0 :       my_do_norm = .TRUE.
    2499           0 :       IF (PRESENT(do_norm)) my_do_norm = do_norm
    2500             : 
    2501           0 :       my_do_print = .FALSE.
    2502           0 :       IF (PRESENT(do_print) .AND. (my_do_norm)) my_do_print = do_print
    2503             : 
    2504           0 :       unit_nr = -1
    2505           0 :       IF (my_do_print) THEN
    2506           0 :          unit_nr = cp_logger_get_default_unit_nr()
    2507           0 :          IF (unit_nr < 1) my_do_print = .FALSE.
    2508             :       END IF
    2509             : 
    2510           0 :       IF (SIZE(B) /= 0) THEN
    2511           0 :          IF (PRESENT(nrows)) THEN
    2512           0 :             nrow_global = nrows
    2513             :          ELSE
    2514           0 :             nrow_global = SIZE(B, 1)
    2515             :          END IF
    2516             : 
    2517           0 :          IF (PRESENT(ncols)) THEN
    2518           0 :             ncol_global = ncols
    2519             :          ELSE
    2520           0 :             ncol_global = SIZE(B, 2)
    2521             :          END IF
    2522             : 
    2523           0 :          IF (PRESENT(start_row)) THEN
    2524           0 :             start_row_global = start_row
    2525             :          ELSE
    2526             :             start_row_global = 1
    2527             :          END IF
    2528             : 
    2529           0 :          IF (PRESENT(start_col)) THEN
    2530           0 :             start_col_global = start_col
    2531             :          ELSE
    2532             :             start_col_global = 1
    2533             :          END IF
    2534             : 
    2535           0 :          end_row_global = start_row_global + nrow_global - 1
    2536           0 :          end_col_global = start_col_global + ncol_global - 1
    2537             : 
    2538             :          CALL cp_fm_get_info(matrix=matrix_a, &
    2539             :                              nrow_global=nrow_global, ncol_global=ncol_global, &
    2540             :                              nrow_local=nrow_local, ncol_local=ncol_local, &
    2541           0 :                              row_indices=row_indices, col_indices=col_indices)
    2542           0 :          IF (end_row_global > nrow_global) THEN
    2543             :             end_row_global = nrow_global
    2544             :          END IF
    2545           0 :          IF (end_col_global > ncol_global) THEN
    2546             :             end_col_global = ncol_global
    2547             :          END IF
    2548             : 
    2549             :          ! find out row/column indices of locally stored matrix elements that
    2550             :          ! needs to be copied.
    2551             :          ! Arrays row_indices and col_indices are assumed to be sorted in
    2552             :          ! ascending order
    2553           0 :          DO start_row_local = 1, nrow_local
    2554           0 :             IF (row_indices(start_row_local) >= start_row_global) EXIT
    2555             :          END DO
    2556             : 
    2557           0 :          DO end_row_local = start_row_local, nrow_local
    2558           0 :             IF (row_indices(end_row_local) > end_row_global) EXIT
    2559             :          END DO
    2560           0 :          end_row_local = end_row_local - 1
    2561             : 
    2562           0 :          DO start_col_local = 1, ncol_local
    2563           0 :             IF (col_indices(start_col_local) >= start_col_global) EXIT
    2564             :          END DO
    2565             : 
    2566           0 :          DO end_col_local = start_col_local, ncol_local
    2567           0 :             IF (col_indices(end_col_local) > end_col_global) EXIT
    2568             :          END DO
    2569           0 :          end_col_local = end_col_local - 1
    2570             : 
    2571           0 :          a => matrix_a%local_data
    2572             : 
    2573           0 :          this_col = col_indices(start_col_local) - start_col_global + 1
    2574             : 
    2575           0 :          B(:, this_col) = a(:, start_col_local)
    2576             : 
    2577           0 :          IF (my_do_norm) THEN
    2578           0 :             norm = SQRT(accurate_dot_product(B(:, this_col), B(:, this_col)))
    2579           0 :             B(:, this_col) = B(:, this_col)/norm
    2580           0 :             IF (my_do_print) WRITE (unit_nr, '(I3,F8.3)') this_col, norm
    2581             :          END IF
    2582             : 
    2583           0 :          DO i = start_col_local + 1, end_col_local
    2584           0 :             this_col = col_indices(i) - start_col_global + 1
    2585           0 :             B(:, this_col) = a(:, i)
    2586           0 :             DO j = start_col_local, i - 1
    2587           0 :                j_col = col_indices(j) - start_col_global + 1
    2588             :                B(:, this_col) = B(:, this_col) - &
    2589             :                                 accurate_dot_product(B(:, j_col), B(:, this_col))* &
    2590           0 :                                 B(:, j_col)/accurate_dot_product(B(:, j_col), B(:, j_col))
    2591             :             END DO
    2592             : 
    2593           0 :             IF (my_do_norm) THEN
    2594           0 :                norm = SQRT(accurate_dot_product(B(:, this_col), B(:, this_col)))
    2595           0 :                B(:, this_col) = B(:, this_col)/norm
    2596           0 :                IF (my_do_print) WRITE (unit_nr, '(I3,F8.3)') this_col, norm
    2597             :             END IF
    2598             : 
    2599             :          END DO
    2600           0 :          CALL matrix_a%matrix_struct%para_env%sum(B)
    2601             :       END IF
    2602             : 
    2603           0 :       CALL timestop(handle)
    2604             : 
    2605           0 :    END SUBROUTINE cp_fm_Gram_Schmidt_orthonorm
    2606             : 
    2607             : ! **************************************************************************************************
    2608             : !> \brief Cholesky decomposition
    2609             : !> \param fm_matrix ...
    2610             : !> \param n ...
    2611             : ! **************************************************************************************************
    2612       10069 :    SUBROUTINE cp_fm_potrf(fm_matrix, n)
    2613             :       TYPE(cp_fm_type)                         :: fm_matrix
    2614             :       INTEGER, INTENT(in)                      :: n
    2615             : 
    2616             :       INTEGER                                  :: info
    2617       10069 :       REAL(KIND=dp), DIMENSION(:, :), POINTER  :: a
    2618       10069 :       REAL(KIND=sp), DIMENSION(:, :), POINTER  :: a_sp
    2619             : #if defined(__parallel)
    2620             :       INTEGER, DIMENSION(9)                    :: desca
    2621             : #endif
    2622             : 
    2623       10069 :       a => fm_matrix%local_data
    2624       10069 :       a_sp => fm_matrix%local_data_sp
    2625             : #if defined(__parallel)
    2626      100690 :       desca(:) = fm_matrix%matrix_struct%descriptor(:)
    2627       10069 :       IF (fm_matrix%use_sp) THEN
    2628           0 :          CALL pspotrf('U', n, a_sp(1, 1), 1, 1, desca, info)
    2629             :       ELSE
    2630       10069 :          CALL pdpotrf('U', n, a(1, 1), 1, 1, desca, info)
    2631             :       END IF
    2632             : #else
    2633             :       IF (fm_matrix%use_sp) THEN
    2634             :          CALL spotrf('U', n, a_sp(1, 1), SIZE(a_sp, 1), info)
    2635             :       ELSE
    2636             :          CALL dpotrf('U', n, a(1, 1), SIZE(a, 1), info)
    2637             :       END IF
    2638             : #endif
    2639       10069 :       IF (info /= 0) &
    2640           0 :          CPABORT("Cholesky decomposition failed. Matrix ill conditioned ?")
    2641             : 
    2642       10069 :    END SUBROUTINE cp_fm_potrf
    2643             : 
    2644             : ! **************************************************************************************************
    2645             : !> \brief Invert trianguar matrix
    2646             : !> \param fm_matrix the matrix to invert (must be an upper triangular matrix)
    2647             : !> \param n size of the matrix to invert
    2648             : ! **************************************************************************************************
    2649        9293 :    SUBROUTINE cp_fm_potri(fm_matrix, n)
    2650             :       TYPE(cp_fm_type)                          :: fm_matrix
    2651             :       INTEGER, INTENT(in)                       :: n
    2652             : 
    2653        9293 :       REAL(KIND=dp), DIMENSION(:, :), POINTER   :: a
    2654        9293 :       REAL(KIND=sp), DIMENSION(:, :), POINTER   :: a_sp
    2655             :       INTEGER                                   :: info
    2656             : #if defined(__parallel)
    2657             :       INTEGER, DIMENSION(9)                     :: desca
    2658             : #endif
    2659             : 
    2660        9293 :       a => fm_matrix%local_data
    2661        9293 :       a_sp => fm_matrix%local_data_sp
    2662             : #if defined(__parallel)
    2663       92930 :       desca(:) = fm_matrix%matrix_struct%descriptor(:)
    2664        9293 :       IF (fm_matrix%use_sp) THEN
    2665           0 :          CALL pspotri('U', n, a_sp(1, 1), 1, 1, desca, info)
    2666             :       ELSE
    2667        9293 :          CALL pdpotri('U', n, a(1, 1), 1, 1, desca, info)
    2668             :       END IF
    2669             : #else
    2670             :       IF (fm_matrix%use_sp) THEN
    2671             :          CALL spotri('U', n, a_sp(1, 1), SIZE(a_sp, 1), info)
    2672             :       ELSE
    2673             :          CALL dpotri('U', n, a(1, 1), SIZE(a, 1), info)
    2674             :       END IF
    2675             : #endif
    2676        9293 :       CPASSERT(info == 0)
    2677        9293 :    END SUBROUTINE cp_fm_potri
    2678             : 
    2679             : ! **************************************************************************************************
    2680             : !> \brief ...
    2681             : !> \param fm_matrix ...
    2682             : !> \param neig ...
    2683             : !> \param fm_matrixb ...
    2684             : !> \param fm_matrixout ...
    2685             : !> \param op ...
    2686             : !> \param pos ...
    2687             : !> \param transa ...
    2688             : ! **************************************************************************************************
    2689        1184 :    SUBROUTINE cp_fm_cholesky_restore(fm_matrix, neig, fm_matrixb, fm_matrixout, op, pos, transa)
    2690             :       TYPE(cp_fm_type)                               :: fm_matrix
    2691             :       TYPE(cp_fm_type)                               :: fm_matrixb
    2692             :       TYPE(cp_fm_type)                               :: fm_matrixout
    2693             :       INTEGER, INTENT(IN)                            :: neig
    2694             :       CHARACTER(LEN=*), INTENT(IN)                   :: op
    2695             :       CHARACTER(LEN=*), INTENT(IN)                   :: pos
    2696             :       CHARACTER(LEN=*), INTENT(IN)                   :: transa
    2697             : 
    2698        1184 :       REAL(KIND=dp), DIMENSION(:, :), POINTER        :: a, b, outm
    2699        1184 :       REAL(KIND=sp), DIMENSION(:, :), POINTER        :: a_sp, b_sp, outm_sp
    2700             :       INTEGER                                        :: n, itype
    2701             :       REAL(KIND=dp)                                  :: alpha
    2702             : #if defined(__parallel)
    2703             :       INTEGER                                        :: i
    2704             :       INTEGER, DIMENSION(9)                          :: desca, descb, descout
    2705             : #endif
    2706             : 
    2707             :       ! notice b is the cholesky guy
    2708        1184 :       a => fm_matrix%local_data
    2709        1184 :       b => fm_matrixb%local_data
    2710        1184 :       outm => fm_matrixout%local_data
    2711        1184 :       a_sp => fm_matrix%local_data_sp
    2712        1184 :       b_sp => fm_matrixb%local_data_sp
    2713        1184 :       outm_sp => fm_matrixout%local_data_sp
    2714             : 
    2715        1184 :       n = fm_matrix%matrix_struct%nrow_global
    2716        1184 :       itype = 1
    2717             : 
    2718             : #if defined(__parallel)
    2719       11840 :       desca(:) = fm_matrix%matrix_struct%descriptor(:)
    2720       11840 :       descb(:) = fm_matrixb%matrix_struct%descriptor(:)
    2721       11840 :       descout(:) = fm_matrixout%matrix_struct%descriptor(:)
    2722        1184 :       alpha = 1.0_dp
    2723        5316 :       DO i = 1, neig
    2724        5316 :          IF (fm_matrix%use_sp) THEN
    2725           0 :             CALL pscopy(n, a_sp(1, 1), 1, i, desca, 1, outm_sp(1, 1), 1, i, descout, 1)
    2726             :          ELSE
    2727        4132 :             CALL pdcopy(n, a(1, 1), 1, i, desca, 1, outm(1, 1), 1, i, descout, 1)
    2728             :          END IF
    2729             :       END DO
    2730        1184 :       IF (op .EQ. "SOLVE") THEN
    2731        1184 :          IF (fm_matrix%use_sp) THEN
    2732             :             CALL pstrsm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), 1, 1, descb, &
    2733           0 :                         outm_sp(1, 1), 1, 1, descout)
    2734             :          ELSE
    2735        1184 :             CALL pdtrsm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), 1, 1, descb, outm(1, 1), 1, 1, descout)
    2736             :          END IF
    2737             :       ELSE
    2738           0 :          IF (fm_matrix%use_sp) THEN
    2739             :             CALL pstrmm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), 1, 1, descb, &
    2740           0 :                         outm_sp(1, 1), 1, 1, descout)
    2741             :          ELSE
    2742           0 :             CALL pdtrmm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), 1, 1, descb, outm(1, 1), 1, 1, descout)
    2743             :          END IF
    2744             :       END IF
    2745             : #else
    2746             :       alpha = 1.0_dp
    2747             :       IF (fm_matrix%use_sp) THEN
    2748             :          CALL scopy(neig*n, a_sp(1, 1), 1, outm_sp(1, 1), 1)
    2749             :       ELSE
    2750             :          CALL dcopy(neig*n, a(1, 1), 1, outm(1, 1), 1)
    2751             :       END IF
    2752             :       IF (op .EQ. "SOLVE") THEN
    2753             :          IF (fm_matrix%use_sp) THEN
    2754             :             CALL strsm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), SIZE(b_sp, 1), outm_sp(1, 1), n)
    2755             :          ELSE
    2756             :             CALL dtrsm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), SIZE(b, 1), outm(1, 1), n)
    2757             :          END IF
    2758             :       ELSE
    2759             :          IF (fm_matrix%use_sp) THEN
    2760             :             CALL strmm(pos, 'U', transa, 'N', n, neig, REAL(alpha, sp), b_sp(1, 1), n, outm_sp(1, 1), n)
    2761             :          ELSE
    2762             :             CALL dtrmm(pos, 'U', transa, 'N', n, neig, alpha, b(1, 1), n, outm(1, 1), n)
    2763             :          END IF
    2764             :       END IF
    2765             : #endif
    2766             : 
    2767        1184 :    END SUBROUTINE cp_fm_cholesky_restore
    2768             : 
    2769             : END MODULE cp_fm_basic_linalg

Generated by: LCOV version 1.15