Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : MODULE local_gemm_api
9 : USE ISO_C_BINDING, ONLY: C_NULL_PTR, &
10 : C_PTR
11 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
12 : USE input_constants, ONLY: do_dgemm_spla
13 : USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
14 : C_LOC
15 : USE spla, ONLY: SPLA_PU_HOST, &
16 : SPLA_PU_GPU, &
17 : SPLA_OP_NONE, &
18 : SPLA_OP_TRANSPOSE, &
19 : SPLA_OP_CONJ_TRANSPOSE, &
20 : spla_ctx_create, &
21 : spla_ctx_destroy, &
22 : spla_dgemm, &
23 : spla_sgemm, &
24 : spla_cgemm, &
25 : spla_zgemm, &
26 : spla_ctx_set_op_threshold_gpu, &
27 : SPLA_SUCCESS
28 : #endif
29 :
30 : USE offload_api, ONLY: offload_activate_chosen_device
31 :
32 : #include "./base/base_uses.f90"
33 :
34 : IMPLICIT NONE
35 :
36 : PRIVATE
37 :
38 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
39 :
40 : PUBLIC :: local_gemm_ctxt_type, &
41 : local_gemm_set_library
42 :
43 : INTEGER, PARAMETER, PUBLIC :: &
44 : LOCAL_GEMM_PU_HOST = 0, &
45 : LOCAL_GEMM_PU_GPU = 1
46 :
47 : INTEGER, PRIVATE :: do_dgemm = 1
48 :
49 : TYPE local_gemm_ctxt_type
50 : TYPE(C_PTR) :: spla_context = C_NULL_PTR
51 : CONTAINS
52 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: create => local_gemm_create
53 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: destroy => local_gemm_destroy
54 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
55 : PROCEDURE, PASS(ctx), NON_OVERRIDABLE :: gemm => local_gemm
56 : END TYPE
57 :
58 : CONTAINS
59 :
60 : ! **************************************************************************************************
61 : !> \brief ...
62 : !> \param opA ...
63 : !> \param opB ...
64 : !> \param m ...
65 : !> \param n ...
66 : !> \param k ...
67 : !> \param alpha ...
68 : !> \param A ...
69 : !> \param lda ...
70 : !> \param B ...
71 : !> \param ldb ...
72 : !> \param beta ...
73 : !> \param C ...
74 : !> \param ldc ...
75 : !> \param ctx ...
76 : ! **************************************************************************************************
77 106576 : SUBROUTINE local_gemm(opA, opB, m, n, k, &
78 53288 : alpha, A, lda, B, ldb, &
79 53288 : beta, C, ldc, ctx)
80 : CHARACTER, INTENT(in) :: opA
81 : CHARACTER, INTENT(in) :: opB
82 : INTEGER, INTENT(in) :: m
83 : INTEGER, INTENT(in) :: n
84 : INTEGER, INTENT(in) :: k
85 : REAL(8), INTENT(in) :: alpha
86 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
87 : REAL(8), DIMENSION(*), INTENT(in), TARGET :: A
88 : #else
89 : REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: A
90 : #endif
91 : INTEGER, INTENT(in) :: lda
92 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
93 : REAL(8), DIMENSION(*), INTENT(in), TARGET :: B
94 : #else
95 : REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: B
96 : #endif
97 :
98 : INTEGER, INTENT(in) :: ldb
99 : REAL(8), INTENT(in) :: beta
100 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
101 : REAL(8), DIMENSION(*), INTENT(inout), TARGET ::C
102 : #else
103 : REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: C
104 : #endif
105 : INTEGER, INTENT(in) :: ldc
106 : CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
107 :
108 : INTEGER :: handle
109 : ! no point of using SPLA offloading on CPU ONLY nodes
110 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
111 : INTEGER :: spla_op_A, spla_op_B, spla_error
112 : #endif
113 : CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'
114 53288 : CALL timeset(routineN, handle)
115 :
116 : ! no point of using SPLA offloading on CPU ONLY nodes
117 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
118 : IF (do_dgemm == do_dgemm_spla) THEN
119 :
120 : IF (opA == 'N') spla_op_A = SPLA_OP_NONE
121 : IF (opA == 'T') spla_op_A = SPLA_OP_TRANSPOSE
122 :
123 : IF (opB == 'N') spla_op_B = SPLA_OP_NONE
124 : IF (opB == 'T') spla_op_B = SPLA_OP_TRANSPOSE
125 :
126 : #if __GNUC__ >= 9
127 : CPASSERT(IS_CONTIGUOUS(A))
128 : CPASSERT(IS_CONTIGUOUS(B))
129 : CPASSERT(IS_CONTIGUOUS(C))
130 : #endif
131 :
132 : CALL offload_activate_chosen_device()
133 : spla_error = spla_dgemm(spla_op_A, spla_op_B, &
134 : m, n, k, alpha, &
135 : c_loc(A), lda, &
136 : c_loc(B), ldb, &
137 : beta, c_loc(C), ldc, ctx%spla_context)
138 : CPASSERT(spla_error == SPLA_SUCCESS)
139 : ELSE
140 : #endif
141 : CALL dgemm(opA, opB, m, n, k, alpha, &
142 : A, lda, &
143 1523838 : B, ldb, beta, C, ldc)
144 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
145 : END IF
146 : #else
147 : MARK_USED(ctx)
148 : #endif
149 53288 : CALL timestop(handle)
150 :
151 53288 : END SUBROUTINE local_gemm
152 :
153 : ! **************************************************************************************************
154 : !> \brief create a context for handling gemm offloading
155 : !> \param ctx newly created context
156 : !> \param pu processing unit to run the (s,d,c,z}dgemm
157 : ! **************************************************************************************************
158 408 : SUBROUTINE local_gemm_create(ctx, pu)
159 : CLASS(local_gemm_ctxt_type), INTENT(out) :: ctx
160 : INTEGER, INTENT(in) :: pu
161 :
162 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
163 : INTEGER :: error_
164 :
165 : IF (.NOT. C_ASSOCIATED(ctx%spla_context)) THEN
166 : IF (do_dgemm == do_dgemm_spla) THEN
167 : CALL offload_activate_chosen_device()
168 :
169 : error_ = spla_ctx_create(ctx%spla_context, pu)
170 : CPASSERT(error_ == SPLA_SUCCESS)
171 : ELSE
172 : ctx%spla_context = C_NULL_PTR
173 : END IF
174 : END IF
175 : #else
176 : MARK_USED(pu)
177 408 : ctx%spla_context = C_NULL_PTR
178 : #endif
179 408 : END SUBROUTINE local_gemm_create
180 :
181 : ! **************************************************************************************************
182 : !> \brief release resources associated to a gemm context
183 : !> \param ctx handle
184 : ! **************************************************************************************************
185 874 : SUBROUTINE local_gemm_destroy(ctx)
186 : CLASS(local_gemm_ctxt_type), INTENT(inout) :: ctx
187 :
188 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
189 : INTEGER :: error_
190 :
191 : IF (do_dgemm == do_dgemm_spla) THEN
192 : CALL offload_activate_chosen_device()
193 :
194 : error_ = spla_ctx_destroy(ctx%spla_context)
195 : CPASSERT(error_ == SPLA_SUCCESS)
196 : END IF
197 : #endif
198 874 : ctx%spla_context = C_NULL_PTR
199 874 : END SUBROUTINE local_gemm_destroy
200 :
201 : ! **************************************************************************************************
202 : !> \brief ...
203 : !> \param ctx ...
204 : !> \param opThresholdGPU ...
205 : ! **************************************************************************************************
206 408 : SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
207 : CLASS(local_gemm_ctxt_type), INTENT(INOUT) :: ctx
208 : INTEGER, INTENT(in) :: opThresholdGPU
209 :
210 : #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
211 : INTEGER :: error__
212 :
213 : CALL offload_activate_chosen_device()
214 : error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opThresholdGPU)
215 : #else
216 : MARK_USED(ctx)
217 : MARK_USED(opThresholdGPU)
218 : #endif
219 408 : END SUBROUTINE local_gemm_set_op_threshold_gpu
220 :
221 : ! **************************************************************************************************
222 : !> \brief ...
223 : !> \param dgemm_library ...
224 : ! **************************************************************************************************
225 9127 : SUBROUTINE local_gemm_set_library(dgemm_library)
226 : INTEGER, INTENT(IN) :: dgemm_library
227 :
228 9127 : do_dgemm = dgemm_library
229 9127 : END SUBROUTINE local_gemm_set_library
230 :
231 0 : END MODULE local_gemm_api
|