LCOV - code coverage report
Current view: top level - src - parallel_gemm_api.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 58 72 80.6 %
Date: 2024-12-21 06:28:57 Functions: 4 4 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief basic linear algebra operations for full matrices
      10             : !> \par History
      11             : !>      08.2002 split out of qs_blacs [fawzi]
      12             : !> \author Fawzi Mohamed
      13             : ! **************************************************************************************************
      14             : MODULE parallel_gemm_api
      15             :    USE ISO_C_BINDING,                   ONLY: C_CHAR,&
      16             :                                               C_DOUBLE,&
      17             :                                               C_INT,&
      18             :                                               C_LOC,&
      19             :                                               C_PTR
      20             :    USE cp_cfm_basic_linalg,             ONLY: cp_cfm_gemm
      21             :    USE cp_cfm_types,                    ONLY: cp_cfm_type
      22             :    USE cp_fm_basic_linalg,              ONLY: cp_fm_gemm
      23             :    USE cp_fm_types,                     ONLY: cp_fm_get_mm_type,&
      24             :                                               cp_fm_type
      25             :    USE input_constants,                 ONLY: do_cosma,&
      26             :                                               do_scalapack
      27             :    USE kinds,                           ONLY: dp
      28             :    USE offload_api,                     ONLY: offload_activate_chosen_device
      29             : #include "./base/base_uses.f90"
      30             : 
      31             :    IMPLICIT NONE
      32             :    PRIVATE
      33             : 
      34             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'parallel_gemm_api'
      35             : 
      36             :    PUBLIC :: parallel_gemm
      37             : 
      38             :    INTERFACE parallel_gemm
      39             :       MODULE PROCEDURE parallel_gemm_fm
      40             :       MODULE PROCEDURE parallel_gemm_cfm
      41             :    END INTERFACE parallel_gemm
      42             : 
      43             : CONTAINS
      44             : 
      45             : ! **************************************************************************************************
      46             : !> \brief ...
      47             : !> \param transa ...
      48             : !> \param transb ...
      49             : !> \param m ...
      50             : !> \param n ...
      51             : !> \param k ...
      52             : !> \param alpha ...
      53             : !> \param matrix_a ...
      54             : !> \param matrix_b ...
      55             : !> \param beta ...
      56             : !> \param matrix_c ...
      57             : !> \param a_first_col ...
      58             : !> \param a_first_row ...
      59             : !> \param b_first_col ...
      60             : !> \param b_first_row ...
      61             : !> \param c_first_col ...
      62             : !> \param c_first_row ...
      63             : ! **************************************************************************************************
      64     1088264 :    SUBROUTINE parallel_gemm_fm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
      65             :                                matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
      66             :                                c_first_col, c_first_row)
      67             :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      68             :       INTEGER, INTENT(IN)                                :: m, n, k
      69             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
      70             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      71             :       REAL(KIND=dp), INTENT(IN)                          :: beta
      72             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      73             :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
      74             :                                                             b_first_row, c_first_col, c_first_row
      75             : 
      76             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_fm'
      77             : 
      78             :       INTEGER                                            :: handle, handle1, my_multi
      79             : 
      80     1088264 :       CALL timeset(routineN, handle)
      81             : 
      82     1088264 :       my_multi = cp_fm_get_mm_type()
      83             : 
      84           0 :       SELECT CASE (my_multi)
      85             :       CASE (do_scalapack)
      86           0 :          CALL timeset(routineN//"_gemm", handle1)
      87             :          CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
      88             :                          a_first_col=a_first_col, &
      89             :                          a_first_row=a_first_row, &
      90             :                          b_first_col=b_first_col, &
      91             :                          b_first_row=b_first_row, &
      92             :                          c_first_col=c_first_col, &
      93           0 :                          c_first_row=c_first_row)
      94           0 :          CALL timestop(handle1)
      95             :       CASE (do_cosma)
      96             : #if defined(__COSMA)
      97     1088264 :          CALL timeset(routineN//"_cosma", handle1)
      98     1088264 :          CALL offload_activate_chosen_device()
      99             :          CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
     100             :                            matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
     101             :                            a_first_col=a_first_col, &
     102             :                            a_first_row=a_first_row, &
     103             :                            b_first_col=b_first_col, &
     104             :                            b_first_row=b_first_row, &
     105             :                            c_first_col=c_first_col, &
     106     1088264 :                            c_first_row=c_first_row)
     107     2176528 :          CALL timestop(handle1)
     108             : #else
     109             :          CPABORT("CP2K compiled without the COSMA library.")
     110             : #endif
     111             :       END SELECT
     112     1088264 :       CALL timestop(handle)
     113             : 
     114     1088264 :    END SUBROUTINE parallel_gemm_fm
     115             : 
     116             : ! **************************************************************************************************
     117             : !> \brief ...
     118             : !> \param transa ...
     119             : !> \param transb ...
     120             : !> \param m ...
     121             : !> \param n ...
     122             : !> \param k ...
     123             : !> \param alpha ...
     124             : !> \param matrix_a ...
     125             : !> \param matrix_b ...
     126             : !> \param beta ...
     127             : !> \param matrix_c ...
     128             : !> \param a_first_col ...
     129             : !> \param a_first_row ...
     130             : !> \param b_first_col ...
     131             : !> \param b_first_row ...
     132             : !> \param c_first_col ...
     133             : !> \param c_first_row ...
     134             : ! **************************************************************************************************
     135      303384 :    SUBROUTINE parallel_gemm_cfm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
     136             :                                 matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
     137             :                                 c_first_col, c_first_row)
     138             :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
     139             :       INTEGER, INTENT(IN)                                :: m, n, k
     140             :       COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
     141             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
     142             :       COMPLEX(KIND=dp), INTENT(IN)                       :: beta
     143             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
     144             :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     145             :                                                             b_first_row, c_first_col, c_first_row
     146             : 
     147             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_cfm'
     148             : 
     149             :       INTEGER                                            :: handle, handle1, my_multi
     150             : 
     151      303384 :       CALL timeset(routineN, handle)
     152             : 
     153      303384 :       my_multi = cp_fm_get_mm_type()
     154             : 
     155           0 :       SELECT CASE (my_multi)
     156             :       CASE (do_scalapack)
     157           0 :          CALL timeset(routineN//"_gemm", handle1)
     158             :          CALL cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
     159             :                           a_first_col=a_first_col, &
     160             :                           a_first_row=a_first_row, &
     161             :                           b_first_col=b_first_col, &
     162             :                           b_first_row=b_first_row, &
     163             :                           c_first_col=c_first_col, &
     164           0 :                           c_first_row=c_first_row)
     165           0 :          CALL timestop(handle1)
     166             :       CASE (do_cosma)
     167             : #if defined(__COSMA)
     168      303384 :          CALL timeset(routineN//"_cosma", handle1)
     169      303384 :          CALL offload_activate_chosen_device()
     170             :          CALL cosma_pzgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
     171             :                            matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
     172             :                            a_first_col=a_first_col, &
     173             :                            a_first_row=a_first_row, &
     174             :                            b_first_col=b_first_col, &
     175             :                            b_first_row=b_first_row, &
     176             :                            c_first_col=c_first_col, &
     177      303384 :                            c_first_row=c_first_row)
     178      606768 :          CALL timestop(handle1)
     179             : #else
     180             :          CPABORT("CP2K compiled without the COSMA library.")
     181             : #endif
     182             :       END SELECT
     183      303384 :       CALL timestop(handle)
     184             : 
     185      303384 :    END SUBROUTINE parallel_gemm_cfm
     186             : 
     187             : #if defined(__COSMA)
     188             : ! **************************************************************************************************
     189             : !> \brief Fortran wrapper for cosma_pdgemm.
     190             : !> \param transa ...
     191             : !> \param transb ...
     192             : !> \param m ...
     193             : !> \param n ...
     194             : !> \param k ...
     195             : !> \param alpha ...
     196             : !> \param matrix_a ...
     197             : !> \param matrix_b ...
     198             : !> \param beta ...
     199             : !> \param matrix_c ...
     200             : !> \param a_first_col ...
     201             : !> \param a_first_row ...
     202             : !> \param b_first_col ...
     203             : !> \param b_first_row ...
     204             : !> \param c_first_col ...
     205             : !> \param c_first_row ...
     206             : !> \author Ole Schuett
     207             : ! **************************************************************************************************
     208     1088264 :    SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
     209             :                            a_first_col, a_first_row, b_first_col, b_first_row, &
     210             :                            c_first_col, c_first_row)
     211             :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
     212             :       INTEGER, INTENT(IN)                                :: m, n, k
     213             :       REAL(KIND=dp), INTENT(IN)                          :: alpha
     214             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
     215             :       REAL(KIND=dp), INTENT(IN)                          :: beta
     216             :       TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
     217             :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     218             :                                                             b_first_row, c_first_col, c_first_row
     219             : 
     220             :       INTEGER                                            :: i_a, i_b, i_c, j_a, j_b, j_c
     221             :       INTERFACE
     222             :          SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
     223             :                                    b, ib, jb, descb, beta, c, ic, jc, descc) &
     224             :             BIND(C, name="cosma_pdgemm")
     225             :             IMPORT :: C_PTR, C_INT, C_DOUBLE, C_CHAR
     226             :             CHARACTER(KIND=C_CHAR)                    :: transa
     227             :             CHARACTER(KIND=C_CHAR)                    :: transb
     228             :             INTEGER(KIND=C_INT)                       :: m
     229             :             INTEGER(KIND=C_INT)                       :: n
     230             :             INTEGER(KIND=C_INT)                       :: k
     231             :             REAL(KIND=C_DOUBLE)                       :: alpha
     232             :             TYPE(C_PTR), VALUE                        :: a
     233             :             INTEGER(KIND=C_INT)                       :: ia
     234             :             INTEGER(KIND=C_INT)                       :: ja
     235             :             TYPE(C_PTR), VALUE                        :: desca
     236             :             TYPE(C_PTR), VALUE                        :: b
     237             :             INTEGER(KIND=C_INT)                       :: ib
     238             :             INTEGER(KIND=C_INT)                       :: jb
     239             :             TYPE(C_PTR), VALUE                        :: descb
     240             :             REAL(KIND=C_DOUBLE)                       :: beta
     241             :             TYPE(C_PTR), VALUE                        :: c
     242             :             INTEGER(KIND=C_INT)                       :: ic
     243             :             INTEGER(KIND=C_INT)                       :: jc
     244             :             TYPE(C_PTR), VALUE                        :: descc
     245             :          END SUBROUTINE cosma_pdgemm_c
     246             :       END INTERFACE
     247             : 
     248     1088264 :       IF (PRESENT(a_first_row)) THEN
     249        2694 :          i_a = a_first_row
     250             :       ELSE
     251     1085570 :          i_a = 1
     252             :       END IF
     253     1088264 :       IF (PRESENT(a_first_col)) THEN
     254        2694 :          j_a = a_first_col
     255             :       ELSE
     256     1085570 :          j_a = 1
     257             :       END IF
     258     1088264 :       IF (PRESENT(b_first_row)) THEN
     259        3044 :          i_b = b_first_row
     260             :       ELSE
     261     1085220 :          i_b = 1
     262             :       END IF
     263     1088264 :       IF (PRESENT(b_first_col)) THEN
     264        3928 :          j_b = b_first_col
     265             :       ELSE
     266     1084336 :          j_b = 1
     267             :       END IF
     268     1088264 :       IF (PRESENT(c_first_row)) THEN
     269        2450 :          i_c = c_first_row
     270             :       ELSE
     271     1085814 :          i_c = 1
     272             :       END IF
     273     1088264 :       IF (PRESENT(c_first_col)) THEN
     274        2468 :          j_c = c_first_col
     275             :       ELSE
     276     1085796 :          j_c = 1
     277             :       END IF
     278             : 
     279             :       CALL cosma_pdgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
     280             :                           alpha=alpha, &
     281             :                           a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
     282             :                           desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
     283             :                           b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
     284             :                           descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
     285             :                           beta=beta, &
     286             :                           c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
     287     1088264 :                           descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))
     288             : 
     289     1088264 :    END SUBROUTINE cosma_pdgemm
     290             : 
     291             : ! **************************************************************************************************
     292             : !> \brief Fortran wrapper for cosma_pdgemm.
     293             : !> \param transa ...
     294             : !> \param transb ...
     295             : !> \param m ...
     296             : !> \param n ...
     297             : !> \param k ...
     298             : !> \param alpha ...
     299             : !> \param matrix_a ...
     300             : !> \param matrix_b ...
     301             : !> \param beta ...
     302             : !> \param matrix_c ...
     303             : !> \param a_first_col ...
     304             : !> \param a_first_row ...
     305             : !> \param b_first_col ...
     306             : !> \param b_first_row ...
     307             : !> \param c_first_col ...
     308             : !> \param c_first_row ...
     309             : !> \author Ole Schuett
     310             : ! **************************************************************************************************
     311      303384 :    SUBROUTINE cosma_pzgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
     312             :                            a_first_col, a_first_row, b_first_col, b_first_row, &
     313             :                            c_first_col, c_first_row)
     314             :       CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
     315             :       INTEGER, INTENT(IN)                                :: m, n, k
     316             :       COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
     317             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
     318             :       COMPLEX(KIND=dp), INTENT(IN)                       :: beta
     319             :       TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
     320             :       INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
     321             :                                                             b_first_row, c_first_col, c_first_row
     322             : 
     323             :       INTEGER                                            :: i_a, i_b, i_c, j_a, j_b, j_c
     324             :       REAL(KIND=dp), DIMENSION(2), TARGET                :: alpha_t, beta_t
     325             :       INTERFACE
     326             :          SUBROUTINE cosma_pzgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
     327             :                                    b, ib, jb, descb, beta, c, ic, jc, descc) &
     328             :             BIND(C, name="cosma_pzgemm")
     329             :             IMPORT :: C_PTR, C_INT, C_CHAR
     330             :             CHARACTER(KIND=C_CHAR)                    :: transa
     331             :             CHARACTER(KIND=C_CHAR)                    :: transb
     332             :             INTEGER(KIND=C_INT)                       :: m
     333             :             INTEGER(KIND=C_INT)                       :: n
     334             :             INTEGER(KIND=C_INT)                       :: k
     335             :             TYPE(C_PTR), VALUE                        :: alpha
     336             :             TYPE(C_PTR), VALUE                        :: a
     337             :             INTEGER(KIND=C_INT)                       :: ia
     338             :             INTEGER(KIND=C_INT)                       :: ja
     339             :             TYPE(C_PTR), VALUE                        :: desca
     340             :             TYPE(C_PTR), VALUE                        :: b
     341             :             INTEGER(KIND=C_INT)                       :: ib
     342             :             INTEGER(KIND=C_INT)                       :: jb
     343             :             TYPE(C_PTR), VALUE                        :: descb
     344             :             TYPE(C_PTR), VALUE                        :: beta
     345             :             TYPE(C_PTR), VALUE                        :: c
     346             :             INTEGER(KIND=C_INT)                       :: ic
     347             :             INTEGER(KIND=C_INT)                       :: jc
     348             :             TYPE(C_PTR), VALUE                        :: descc
     349             :          END SUBROUTINE cosma_pzgemm_c
     350             :       END INTERFACE
     351             : 
     352      303384 :       IF (PRESENT(a_first_row)) THEN
     353           0 :          i_a = a_first_row
     354             :       ELSE
     355      303384 :          i_a = 1
     356             :       END IF
     357      303384 :       IF (PRESENT(a_first_col)) THEN
     358           0 :          j_a = a_first_col
     359             :       ELSE
     360      303384 :          j_a = 1
     361             :       END IF
     362      303384 :       IF (PRESENT(b_first_row)) THEN
     363           0 :          i_b = b_first_row
     364             :       ELSE
     365      303384 :          i_b = 1
     366             :       END IF
     367      303384 :       IF (PRESENT(b_first_col)) THEN
     368           0 :          j_b = b_first_col
     369             :       ELSE
     370      303384 :          j_b = 1
     371             :       END IF
     372      303384 :       IF (PRESENT(c_first_row)) THEN
     373           0 :          i_c = c_first_row
     374             :       ELSE
     375      303384 :          i_c = 1
     376             :       END IF
     377      303384 :       IF (PRESENT(c_first_col)) THEN
     378           0 :          j_c = c_first_col
     379             :       ELSE
     380      303384 :          j_c = 1
     381             :       END IF
     382             : 
     383      303384 :       alpha_t(1) = REAL(alpha, KIND=dp)
     384      303384 :       alpha_t(2) = REAL(AIMAG(alpha), KIND=dp)
     385      303384 :       beta_t(1) = REAL(beta, KIND=dp)
     386      303384 :       beta_t(2) = REAL(AIMAG(beta), KIND=dp)
     387             : 
     388             :       CALL cosma_pzgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
     389             :                           alpha=C_LOC(alpha_t), &
     390             :                           a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
     391             :                           desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
     392             :                           b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
     393             :                           descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
     394             :                           beta=C_LOC(beta_t), &
     395             :                           c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
     396      303384 :                           descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))
     397             : 
     398      303384 :    END SUBROUTINE cosma_pzgemm
     399             : #endif
     400             : 
     401             : END MODULE parallel_gemm_api

Generated by: LCOV version 1.15