Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
! **************************************************************************************************
!> \brief Distributed general matrix-matrix multiplication (GEMM) for full matrices.
!>        The generic parallel_gemm dispatches at run time, based on cp_fm_get_mm_type(), either
!>        to the ScaLAPACK-based cp_fm_gemm / cp_cfm_gemm or to the COSMA library.
!> \par History
!>      08.2002 split out of qs_blacs [fawzi]
!> \author Fawzi Mohamed
! **************************************************************************************************
MODULE parallel_gemm_api
   USE ISO_C_BINDING,                   ONLY: C_CHAR,&
                                              C_DOUBLE,&
                                              C_INT,&
                                              C_LOC,&
                                              C_PTR
   USE cp_cfm_basic_linalg,             ONLY: cp_cfm_gemm
   USE cp_cfm_types,                    ONLY: cp_cfm_type
   USE cp_fm_basic_linalg,              ONLY: cp_fm_gemm
   USE cp_fm_types,                     ONLY: cp_fm_get_mm_type,&
                                              cp_fm_type
   USE input_constants,                 ONLY: do_cosma,&
                                              do_scalapack
   USE kinds,                           ONLY: dp
   USE offload_api,                     ONLY: offload_activate_chosen_device
#include "./base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'parallel_gemm_api'

   PUBLIC :: parallel_gemm

   ! Generic entry point: resolves to the real (cp_fm_type) or complex (cp_cfm_type) variant.
   INTERFACE parallel_gemm
      MODULE PROCEDURE parallel_gemm_fm
      MODULE PROCEDURE parallel_gemm_cfm
   END INTERFACE parallel_gemm

CONTAINS

! **************************************************************************************************
!> \brief Real distributed GEMM: matrix_c = alpha*op(matrix_a)*op(matrix_b) + beta*matrix_c.
!>        The backend (ScaLAPACK via cp_fm_gemm, or COSMA) is selected at run time from
!>        cp_fm_get_mm_type(). An unknown backend id aborts instead of silently leaving
!>        matrix_c untouched.
!> \param transa operation flag for matrix_a, forwarded unchanged to the backend
!> \param transb operation flag for matrix_b, forwarded unchanged to the backend
!> \param m number of rows of op(matrix_a) and of the updated part of matrix_c
!> \param n number of columns of op(matrix_b) and of the updated part of matrix_c
!> \param k inner dimension shared by op(matrix_a) and op(matrix_b)
!> \param alpha scalar factor applied to the product
!> \param matrix_a first factor
!> \param matrix_b second factor
!> \param beta scalar factor applied to the previous content of matrix_c
!> \param matrix_c result; NOTE: declared INTENT(IN), but its local data is overwritten by the backend
!> \param a_first_col optional global start column of the sub-matrix of matrix_a
!> \param a_first_row optional global start row of the sub-matrix of matrix_a
!> \param b_first_col optional global start column of the sub-matrix of matrix_b
!> \param b_first_row optional global start row of the sub-matrix of matrix_b
!> \param c_first_col optional global start column of the sub-matrix of matrix_c
!> \param c_first_row optional global start row of the sub-matrix of matrix_c
!>        (the COSMA path uses 1 for every absent start index)
! **************************************************************************************************
   SUBROUTINE parallel_gemm_fm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                               matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                               c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_fm'

      INTEGER                                            :: handle, handle1, my_multi

      CALL timeset(routineN, handle)

      my_multi = cp_fm_get_mm_type()

      SELECT CASE (my_multi)
      CASE (do_scalapack)
         CALL timeset(routineN//"_gemm", handle1)
         CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                         a_first_col=a_first_col, &
                         a_first_row=a_first_row, &
                         b_first_col=b_first_col, &
                         b_first_row=b_first_row, &
                         c_first_col=c_first_col, &
                         c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_cosma)
#if defined(__COSMA)
         CALL timeset(routineN//"_cosma", handle1)
         ! Activate the offload device chosen for this process before handing over to COSMA.
         CALL offload_activate_chosen_device()
         CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
                           matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
                           a_first_col=a_first_col, &
                           a_first_row=a_first_row, &
                           b_first_col=b_first_col, &
                           b_first_row=b_first_row, &
                           c_first_col=c_first_col, &
                           c_first_row=c_first_row)
         CALL timestop(handle1)
#else
         CPABORT("CP2K compiled without the COSMA library.")
#endif
      CASE DEFAULT
         ! Without this branch an unexpected id would return with matrix_c never computed.
         CPABORT("Unknown matrix multiplication type.")
      END SELECT

      CALL timestop(handle)

   END SUBROUTINE parallel_gemm_fm

! **************************************************************************************************
!> \brief Complex distributed GEMM: matrix_c = alpha*op(matrix_a)*op(matrix_b) + beta*matrix_c.
!>        Uses the same run-time backend switch (cp_fm_get_mm_type()) as the real variant:
!>        ScaLAPACK via cp_cfm_gemm, or COSMA. An unknown backend id aborts.
!> \param transa operation flag for matrix_a, forwarded unchanged to the backend
!> \param transb operation flag for matrix_b, forwarded unchanged to the backend
!> \param m number of rows of op(matrix_a) and of the updated part of matrix_c
!> \param n number of columns of op(matrix_b) and of the updated part of matrix_c
!> \param k inner dimension shared by op(matrix_a) and op(matrix_b)
!> \param alpha complex scalar factor applied to the product
!> \param matrix_a first factor
!> \param matrix_b second factor
!> \param beta complex scalar factor applied to the previous content of matrix_c
!> \param matrix_c result; NOTE: declared INTENT(IN), but its local data is overwritten by the backend
!> \param a_first_col optional global start column of the sub-matrix of matrix_a
!> \param a_first_row optional global start row of the sub-matrix of matrix_a
!> \param b_first_col optional global start column of the sub-matrix of matrix_b
!> \param b_first_row optional global start row of the sub-matrix of matrix_b
!> \param c_first_col optional global start column of the sub-matrix of matrix_c
!> \param c_first_row optional global start row of the sub-matrix of matrix_c
!>        (the COSMA path uses 1 for every absent start index)
! **************************************************************************************************
   SUBROUTINE parallel_gemm_cfm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                                matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                                c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
      COMPLEX(KIND=dp), INTENT(IN)                       :: beta
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_cfm'

      INTEGER                                            :: handle, handle1, my_multi

      CALL timeset(routineN, handle)

      ! The complex variant shares the backend setting of the real full matrices.
      my_multi = cp_fm_get_mm_type()

      SELECT CASE (my_multi)
      CASE (do_scalapack)
         CALL timeset(routineN//"_gemm", handle1)
         CALL cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                          a_first_col=a_first_col, &
                          a_first_row=a_first_row, &
                          b_first_col=b_first_col, &
                          b_first_row=b_first_row, &
                          c_first_col=c_first_col, &
                          c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_cosma)
#if defined(__COSMA)
         CALL timeset(routineN//"_cosma", handle1)
         ! Activate the offload device chosen for this process before handing over to COSMA.
         CALL offload_activate_chosen_device()
         CALL cosma_pzgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
                           matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
                           a_first_col=a_first_col, &
                           a_first_row=a_first_row, &
                           b_first_col=b_first_col, &
                           b_first_row=b_first_row, &
                           c_first_col=c_first_col, &
                           c_first_row=c_first_row)
         CALL timestop(handle1)
#else
         CPABORT("CP2K compiled without the COSMA library.")
#endif
      CASE DEFAULT
         ! Without this branch an unexpected id would return with matrix_c never computed.
         CPABORT("Unknown matrix multiplication type.")
      END SELECT

      CALL timestop(handle)

   END SUBROUTINE parallel_gemm_cfm

#if defined(__COSMA)
! **************************************************************************************************
!> \brief Returns an optional 1-based start index, or 1 when the argument is absent.
!> \param idx optional start index (row or column)
!> \return idx if present, otherwise 1
! **************************************************************************************************
   PURE FUNCTION first_index_or_one(idx) RESULT(res)
      INTEGER, INTENT(IN), OPTIONAL                      :: idx
      INTEGER                                            :: res

      res = 1
      IF (PRESENT(idx)) res = idx

   END FUNCTION first_index_or_one

! **************************************************************************************************
!> \brief Fortran wrapper for the C function cosma_pdgemm (real double precision).
!>        Absent start indices are replaced by 1. The local data and the descriptors of the
!>        three matrices are passed to C by address.
!> \param transa operation flag for matrix_a
!> \param transb operation flag for matrix_b
!> \param m number of rows of op(matrix_a) and of the updated part of matrix_c
!> \param n number of columns of op(matrix_b) and of the updated part of matrix_c
!> \param k inner dimension shared by op(matrix_a) and op(matrix_b)
!> \param alpha scalar factor applied to the product
!> \param matrix_a first factor
!> \param matrix_b second factor
!> \param beta scalar factor applied to the previous content of matrix_c
!> \param matrix_c result; its local data is written by COSMA
!> \param a_first_col optional global start column in matrix_a (default 1)
!> \param a_first_row optional global start row in matrix_a (default 1)
!> \param b_first_col optional global start column in matrix_b (default 1)
!> \param b_first_row optional global start row in matrix_b (default 1)
!> \param c_first_col optional global start column in matrix_c (default 1)
!> \param c_first_row optional global start row in matrix_c (default 1)
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                           a_first_col, a_first_row, b_first_col, b_first_row, &
                           c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      INTEGER                                            :: i_a, i_b, i_c, j_a, j_b, j_c
      INTERFACE
         ! Scalars are passed by reference and arrays/descriptors by address (C_PTR by VALUE).
         SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
                                   b, ib, jb, descb, beta, c, ic, jc, descc) &
            BIND(C, name="cosma_pdgemm")
            IMPORT :: C_PTR, C_INT, C_DOUBLE, C_CHAR
            CHARACTER(KIND=C_CHAR)                     :: transa
            CHARACTER(KIND=C_CHAR)                     :: transb
            INTEGER(KIND=C_INT)                        :: m
            INTEGER(KIND=C_INT)                        :: n
            INTEGER(KIND=C_INT)                        :: k
            REAL(KIND=C_DOUBLE)                        :: alpha
            TYPE(C_PTR), VALUE                         :: a
            INTEGER(KIND=C_INT)                        :: ia
            INTEGER(KIND=C_INT)                        :: ja
            TYPE(C_PTR), VALUE                         :: desca
            TYPE(C_PTR), VALUE                         :: b
            INTEGER(KIND=C_INT)                        :: ib
            INTEGER(KIND=C_INT)                        :: jb
            TYPE(C_PTR), VALUE                         :: descb
            REAL(KIND=C_DOUBLE)                        :: beta
            TYPE(C_PTR), VALUE                         :: c
            INTEGER(KIND=C_INT)                        :: ic
            INTEGER(KIND=C_INT)                        :: jc
            TYPE(C_PTR), VALUE                         :: descc
         END SUBROUTINE cosma_pdgemm_c
      END INTERFACE

      ! Absent optional arguments may be forwarded to the OPTIONAL dummy of the helper.
      i_a = first_index_or_one(a_first_row)
      j_a = first_index_or_one(a_first_col)
      i_b = first_index_or_one(b_first_row)
      j_b = first_index_or_one(b_first_col)
      i_c = first_index_or_one(c_first_row)
      j_c = first_index_or_one(c_first_col)

      CALL cosma_pdgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
                          alpha=alpha, &
                          a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
                          desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
                          b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
                          descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
                          beta=beta, &
                          c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
                          descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))

   END SUBROUTINE cosma_pdgemm

! **************************************************************************************************
!> \brief Fortran wrapper for the C function cosma_pzgemm (complex double precision).
!>        Absent start indices are replaced by 1. The complex scalars alpha and beta are passed
!>        by address as (real, imaginary) pairs. The local data and the descriptors of the three
!>        matrices are passed by address.
!> \param transa operation flag for matrix_a
!> \param transb operation flag for matrix_b
!> \param m number of rows of op(matrix_a) and of the updated part of matrix_c
!> \param n number of columns of op(matrix_b) and of the updated part of matrix_c
!> \param k inner dimension shared by op(matrix_a) and op(matrix_b)
!> \param alpha complex scalar factor applied to the product
!> \param matrix_a first factor
!> \param matrix_b second factor
!> \param beta complex scalar factor applied to the previous content of matrix_c
!> \param matrix_c result; its local data is written by COSMA
!> \param a_first_col optional global start column in matrix_a (default 1)
!> \param a_first_row optional global start row in matrix_a (default 1)
!> \param b_first_col optional global start column in matrix_b (default 1)
!> \param b_first_row optional global start row in matrix_b (default 1)
!> \param c_first_col optional global start column in matrix_c (default 1)
!> \param c_first_row optional global start row in matrix_c (default 1)
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cosma_pzgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                           a_first_col, a_first_row, b_first_col, b_first_row, &
                           c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
      COMPLEX(KIND=dp), INTENT(IN)                       :: beta
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      INTEGER                                            :: i_a, i_b, i_c, j_a, j_b, j_c
      REAL(KIND=dp), DIMENSION(2), TARGET                :: alpha_t, beta_t
      INTERFACE
         ! Complex scalars and arrays/descriptors are passed by address (C_PTR by VALUE).
         SUBROUTINE cosma_pzgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
                                   b, ib, jb, descb, beta, c, ic, jc, descc) &
            BIND(C, name="cosma_pzgemm")
            IMPORT :: C_PTR, C_INT, C_CHAR
            CHARACTER(KIND=C_CHAR)                     :: transa
            CHARACTER(KIND=C_CHAR)                     :: transb
            INTEGER(KIND=C_INT)                        :: m
            INTEGER(KIND=C_INT)                        :: n
            INTEGER(KIND=C_INT)                        :: k
            TYPE(C_PTR), VALUE                         :: alpha
            TYPE(C_PTR), VALUE                         :: a
            INTEGER(KIND=C_INT)                        :: ia
            INTEGER(KIND=C_INT)                        :: ja
            TYPE(C_PTR), VALUE                         :: desca
            TYPE(C_PTR), VALUE                         :: b
            INTEGER(KIND=C_INT)                        :: ib
            INTEGER(KIND=C_INT)                        :: jb
            TYPE(C_PTR), VALUE                         :: descb
            TYPE(C_PTR), VALUE                         :: beta
            TYPE(C_PTR), VALUE                         :: c
            INTEGER(KIND=C_INT)                        :: ic
            INTEGER(KIND=C_INT)                        :: jc
            TYPE(C_PTR), VALUE                         :: descc
         END SUBROUTINE cosma_pzgemm_c
      END INTERFACE

      ! Absent optional arguments may be forwarded to the OPTIONAL dummy of the helper.
      i_a = first_index_or_one(a_first_row)
      j_a = first_index_or_one(a_first_col)
      i_b = first_index_or_one(b_first_row)
      j_b = first_index_or_one(b_first_col)
      i_c = first_index_or_one(c_first_row)
      j_c = first_index_or_one(c_first_col)

      ! Store the complex scalars as contiguous (real, imaginary) pairs in TARGET arrays so their
      ! addresses can be handed to C; this matches the memory layout of a C double complex.
      alpha_t(1) = REAL(alpha, KIND=dp)
      alpha_t(2) = REAL(AIMAG(alpha), KIND=dp)
      beta_t(1) = REAL(beta, KIND=dp)
      beta_t(2) = REAL(AIMAG(beta), KIND=dp)

      CALL cosma_pzgemm_c(transa=transa, transb=transb, m=m, n=n, k=k, &
                          alpha=C_LOC(alpha_t), &
                          a=C_LOC(matrix_a%local_data(1, 1)), ia=i_a, ja=j_a, &
                          desca=C_LOC(matrix_a%matrix_struct%descriptor(1)), &
                          b=C_LOC(matrix_b%local_data(1, 1)), ib=i_b, jb=j_b, &
                          descb=C_LOC(matrix_b%matrix_struct%descriptor(1)), &
                          beta=C_LOC(beta_t), &
                          c=C_LOC(matrix_c%local_data(1, 1)), ic=i_c, jc=j_c, &
                          descc=C_LOC(matrix_c%matrix_struct%descriptor(1)))

   END SUBROUTINE cosma_pzgemm
#endif

END MODULE parallel_gemm_api
|