LCOV - code coverage report
Current view: top level - src - local_gemm_api.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 18 19 94.7 %
Date: 2024-11-21 06:45:46 Functions: 5 6 83.3 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : MODULE local_gemm_api
       9             :    USE ISO_C_BINDING, ONLY: C_NULL_PTR, &
      10             :                             C_PTR
      11             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      12             :    USE input_constants, ONLY: do_dgemm_spla
      13             :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
      14             :                             C_LOC
      15             :    USE spla, ONLY: SPLA_PU_HOST, &
      16             :                    SPLA_PU_GPU, &
      17             :                    SPLA_OP_NONE, &
      18             :                    SPLA_OP_TRANSPOSE, &
      19             :                    SPLA_OP_CONJ_TRANSPOSE, &
      20             :                    spla_ctx_create, &
      21             :                    spla_ctx_destroy, &
      22             :                    spla_dgemm, &
      23             :                    spla_sgemm, &
      24             :                    spla_cgemm, &
      25             :                    spla_zgemm, &
      26             :                    spla_ctx_set_op_threshold_gpu, &
      27             :                    SPLA_SUCCESS
      28             : #endif
      29             : 
      30             :    USE offload_api, ONLY: offload_activate_chosen_device
      31             : 
      32             : #include "./base/base_uses.f90"
      33             : 
      34             :    IMPLICIT NONE
      35             : 
      36             :    PRIVATE
      37             : 
      38             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
      39             : 
      40             :    PUBLIC :: local_gemm_ctxt_type, &
      41             :              local_gemm_set_library
      42             : 
      43             :    INTEGER, PARAMETER, PUBLIC :: &
      44             :       LOCAL_GEMM_PU_HOST = 0, &
      45             :       LOCAL_GEMM_PU_GPU = 1
      46             : 
      47             :    INTEGER, PRIVATE :: do_dgemm = 1
      48             : 
      49             :    TYPE local_gemm_ctxt_type
      50             :       TYPE(C_PTR) :: spla_context = C_NULL_PTR
      51             :    CONTAINS
      52             :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: create => local_gemm_create
      53             :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: destroy => local_gemm_destroy
      54             :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
      55             :       PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: gemm => local_gemm
      56             :    END TYPE
      57             : 
      58             : CONTAINS
      59             : 
      60             : ! **************************************************************************************************
      61             : !> \brief ...
      62             : !> \param opA ...
      63             : !> \param opB ...
      64             : !> \param m ...
      65             : !> \param n ...
      66             : !> \param k ...
      67             : !> \param alpha ...
      68             : !> \param A ...
      69             : !> \param lda ...
      70             : !> \param B ...
      71             : !> \param ldb ...
      72             : !> \param beta ...
      73             : !> \param C ...
      74             : !> \param ldc ...
      75             : !> \param ctx ...
      76             : ! **************************************************************************************************
      77      106576 :    SUBROUTINE local_gemm(opA, opB, m, n, k, &
      78       53288 :                          alpha, A, lda, B, ldb, &
      79       53288 :                          beta, C, ldc, ctx)
      80             :       CHARACTER, INTENT(in) :: opA
      81             :       CHARACTER, INTENT(in) :: opB
      82             :       INTEGER, INTENT(in) :: m
      83             :       INTEGER, INTENT(in) :: n
      84             :       INTEGER, INTENT(in) :: k
      85             :       REAL(8), INTENT(in) :: alpha
      86             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      87             :       REAL(8), DIMENSION(*), INTENT(in), TARGET :: A
      88             : #else
      89             :       REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: A
      90             : #endif
      91             :       INTEGER, INTENT(in) :: lda
      92             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      93             :       REAL(8), DIMENSION(*), INTENT(in), TARGET :: B
      94             : #else
      95             :       REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: B
      96             : #endif
      97             : 
      98             :       INTEGER, INTENT(in) :: ldb
      99             :       REAL(8), INTENT(in) :: beta
     100             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     101             :       REAL(8), DIMENSION(*), INTENT(inout), TARGET ::C
     102             : #else
     103             :       REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: C
     104             : #endif
     105             :       INTEGER, INTENT(in) :: ldc
     106             :       CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
     107             : 
     108             :       INTEGER                                            :: handle
     109             : !     no point of using SPLA offloading on CPU ONLY nodes
     110             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     111             :       INTEGER :: spla_op_A, spla_op_B, spla_error
     112             : #endif
     113             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'
     114       53288 :       CALL timeset(routineN, handle)
     115             : 
     116             : !     no point of using SPLA offloading on CPU ONLY nodes
     117             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     118             :       IF (do_dgemm == do_dgemm_spla) THEN
     119             : 
     120             :          IF (opA == 'N') spla_op_A = SPLA_OP_NONE
     121             :          IF (opA == 'T') spla_op_A = SPLA_OP_TRANSPOSE
     122             : 
     123             :          IF (opB == 'N') spla_op_B = SPLA_OP_NONE
     124             :          IF (opB == 'T') spla_op_B = SPLA_OP_TRANSPOSE
     125             : 
     126             : #if __GNUC__ >= 9
     127             :          CPASSERT(IS_CONTIGUOUS(A))
     128             :          CPASSERT(IS_CONTIGUOUS(B))
     129             :          CPASSERT(IS_CONTIGUOUS(C))
     130             : #endif
     131             : 
     132             :          CALL offload_activate_chosen_device()
     133             :          spla_error = spla_dgemm(spla_op_A, spla_op_B, &
     134             :                                  m, n, k, alpha, &
     135             :                                  c_loc(A), lda, &
     136             :                                  c_loc(B), ldb, &
     137             :                                  beta, c_loc(C), ldc, ctx%spla_context)
     138             :          CPASSERT(spla_error == SPLA_SUCCESS)
     139             :       ELSE
     140             : #endif
     141             :          CALL dgemm(opA, opB, m, n, k, alpha, &
     142             :                     A, lda, &
     143     1523838 :                     B, ldb, beta, C, ldc)
     144             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     145             :       END IF
     146             : #else
     147             :       MARK_USED(ctx)
     148             : #endif
     149       53288 :       CALL timestop(handle)
     150             : 
     151       53288 :    END SUBROUTINE local_gemm
     152             : 
     153             : ! **************************************************************************************************
     154             : !> \brief create a context for handling gemm offloading
     155             : !> \param ctx newly created context
     156             : !> \param pu processing unit to run the (s,d,c,z}dgemm
     157             : ! **************************************************************************************************
     158         408 :    SUBROUTINE local_gemm_create(ctx, pu)
     159             :       CLASS(local_gemm_ctxt_type), INTENT(out) :: ctx
     160             :       INTEGER, INTENT(in) :: pu
     161             : 
     162             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     163             :       INTEGER :: error_
     164             : 
     165             :       IF (.NOT. C_ASSOCIATED(ctx%spla_context)) THEN
     166             :          IF (do_dgemm == do_dgemm_spla) THEN
     167             :             CALL offload_activate_chosen_device()
     168             : 
     169             :             error_ = spla_ctx_create(ctx%spla_context, pu)
     170             :             CPASSERT(error_ == SPLA_SUCCESS)
     171             :          ELSE
     172             :             ctx%spla_context = C_NULL_PTR
     173             :          END IF
     174             :       END IF
     175             : #else
     176             :       MARK_USED(pu)
     177         408 :       ctx%spla_context = C_NULL_PTR
     178             : #endif
     179         408 :    END SUBROUTINE local_gemm_create
     180             : 
     181             : ! **************************************************************************************************
     182             : !> \brief release resources associated to a gemm context
     183             : !> \param ctx handle
     184             : ! **************************************************************************************************
     185         874 :    SUBROUTINE local_gemm_destroy(ctx)
     186             :       CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
     187             : 
     188             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     189             :       INTEGER :: error_
     190             : 
     191             :       IF (do_dgemm == do_dgemm_spla) THEN
     192             :          CALL offload_activate_chosen_device()
     193             : 
     194             :          error_ = spla_ctx_destroy(ctx%spla_context)
     195             :          CPASSERT(error_ == SPLA_SUCCESS)
     196             :       END IF
     197             : #endif
     198         874 :       ctx%spla_context = C_NULL_PTR
     199         874 :    END SUBROUTINE local_gemm_destroy
     200             : 
     201             : ! **************************************************************************************************
     202             : !> \brief ...
     203             : !> \param ctx ...
     204             : !> \param opThresholdGPU ...
     205             : ! **************************************************************************************************
     206         408 :    SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
     207             :       CLASS(local_gemm_ctxt_type), INTENT(INOUT)                                        :: ctx
     208             :       INTEGER, INTENT(in)                                :: opThresholdGPU
     209             : 
     210             : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
     211             :       INTEGER                                            :: error__
     212             : 
     213             :       CALL offload_activate_chosen_device()
     214             :       error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opThresholdGPU)
     215             : #else
     216             :       MARK_USED(ctx)
     217             :       MARK_USED(opThresholdGPU)
     218             : #endif
     219         408 :    END SUBROUTINE local_gemm_set_op_threshold_gpu
     220             : 
     221             : ! **************************************************************************************************
     222             : !> \brief ...
     223             : !> \param dgemm_library ...
     224             : ! **************************************************************************************************
     225        9127 :    SUBROUTINE local_gemm_set_library(dgemm_library)
     226             :       INTEGER, INTENT(IN)                                :: dgemm_library
     227             : 
     228        9127 :       do_dgemm = dgemm_library
     229        9127 :    END SUBROUTINE local_gemm_set_library
     230             : 
     231           0 : END MODULE local_gemm_api

Generated by: LCOV version 1.15