Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include <assert.h>
9 : #include <stddef.h>
10 : #include <string.h>
11 :
12 : #if defined(__LIBXSMM)
13 : #include <libxsmm.h>
14 : #if !defined(DBM_LIBXSMM_PREFETCH)
15 : // #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_AL2_AHEAD
16 : #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE
17 : #endif
18 : #if LIBXSMM_VERSION4(1, 17, 0, 3710) > LIBXSMM_VERSION_NUMBER
19 : #define libxsmm_dispatch_gemm libxsmm_dispatch_gemm_v2
20 : #endif
21 : #endif
22 :
23 : #include "dbm_hyperparams.h"
24 : #include "dbm_multiply_cpu.h"
25 :
/*******************************************************************************
 * \brief Declaration of the external BLAS routine dgemm_.
 *        Every argument, including the scalars, is handed over by pointer.
 * \author Ole Schuett
 ******************************************************************************/
void dgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const double *alpha,
            const double *a, const int *lda,
            const double *b, const int *ldb,
            const double *beta,
            double *c, const int *ldc);
34 :
/*******************************************************************************
 * \brief Private by-value front end for dgemm_.
 *        dgemm_ expects a pointer for each argument; this wrapper accepts the
 *        scalars by value and passes the addresses of its own copies.
 * \author Ole Schuett
 ******************************************************************************/
static inline void dbm_dgemm(const char transa, const char transb, const int m,
                             const int n, const int k, const double alpha,
                             const double *a, const int lda, const double *b,
                             const int ldb, const double beta, double *c,
                             const int ldc) {

  // The scalar parameters are local copies, so taking their address is safe
  // for the duration of the call. Matrices a, b and c are forwarded unchanged.
  dgemm_(&transa, // operation applied to A
         &transb, // operation applied to B
         &m,      //
         &n,      //
         &k,      //
         &alpha,  //
         a,       //
         &lda,    //
         b,       //
         &ldb,    //
         &beta,   //
         c,       //
         &ldc);
}
48 :
/*******************************************************************************
 * \brief Private hash of a task's (m, n, k) built from two rounds of Szudzik's
 *        elegant pairing: first m and n are paired, then that result and k.
 *        All arithmetic is done in unsigned int, so an overflow wraps around
 *        and the returned value is never negative.
 *        https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
 *        https://stackoverflow.com/a/13871379
 *        http://szudzik.com/ElegantPairing.pdf
 * \author Ole Schuett
 ******************************************************************************/
#if defined(__LIBXSMM)
static inline unsigned int hash(const dbm_task_t task) {
  const unsigned int um = task.m;
  const unsigned int un = task.n;
  const unsigned int uk = task.k;

  // Round one: pair m with n.
  unsigned int paired_mn;
  if (um >= un) {
    paired_mn = um * um + um + un;
  } else {
    paired_mn = um + un * un;
  }

  // Round two: pair the intermediate result with k.
  unsigned int paired_mnk;
  if (paired_mn >= uk) {
    paired_mnk = paired_mn * paired_mn + paired_mn + uk;
  } else {
    paired_mnk = paired_mn + uk * uk;
  }

  return paired_mnk;
}
#endif
65 :
66 : /*******************************************************************************
67 : * \brief Internal routine for executing the tasks in given batch on the CPU.
68 : * \author Ole Schuett
69 : ******************************************************************************/
70 223818 : void dbm_multiply_cpu_process_batch(const int ntasks, dbm_task_t batch[ntasks],
71 : const double alpha,
72 : const dbm_pack_t *pack_a,
73 : const dbm_pack_t *pack_b,
74 223818 : dbm_shard_t *shard_c) {
75 :
76 223818 : if (0 >= ntasks) { // nothing to do
77 38765 : return;
78 : }
79 185053 : dbm_shard_allocate_promised_blocks(shard_c);
80 :
81 : #if defined(__LIBXSMM)
82 :
83 : // Sort tasks approximately by m,n,k via bucket sort.
84 185053 : int buckets[BATCH_NUM_BUCKETS];
85 185053 : memset(buckets, 0, BATCH_NUM_BUCKETS * sizeof(int));
86 21855312 : for (int itask = 0; itask < ntasks; ++itask) {
87 21670259 : const int i = hash(batch[itask]) % BATCH_NUM_BUCKETS;
88 21670259 : ++buckets[i];
89 : }
90 185053000 : for (int i = 1; i < BATCH_NUM_BUCKETS; ++i) {
91 184867947 : buckets[i] += buckets[i - 1];
92 : }
93 185053 : assert(buckets[BATCH_NUM_BUCKETS - 1] == ntasks);
94 185053 : int batch_order[ntasks];
95 21855312 : for (int itask = 0; itask < ntasks; ++itask) {
96 21670259 : const int i = hash(batch[itask]) % BATCH_NUM_BUCKETS;
97 21670259 : --buckets[i];
98 21670259 : batch_order[buckets[i]] = itask;
99 : }
100 :
101 : // Prepare arguments for libxsmm's kernel-dispatch.
102 185053 : const int flags = LIBXSMM_GEMM_FLAG_TRANS_B; // transa = "N", transb = "T"
103 185053 : const int prefetch = DBM_LIBXSMM_PREFETCH;
104 185053 : int kernel_m = 0, kernel_n = 0, kernel_k = 0;
105 185053 : dbm_task_t task_next = batch[batch_order[0]];
106 :
107 : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
108 : double *data_a_next = NULL, *data_b_next = NULL, *data_c_next = NULL;
109 : #endif
110 : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
111 185053 : libxsmm_gemmfunction kernel_func = NULL;
112 : #else
113 : libxsmm_dmmfunction kernel_func = NULL;
114 : const double beta = 1.0;
115 : #endif
116 :
117 : // Loop over tasks.
118 21855312 : for (int itask = 0; itask < ntasks; ++itask) {
119 21670259 : const dbm_task_t task = task_next;
120 21670259 : task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
121 :
122 21670259 : if (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k) {
123 : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
124 1483651 : const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(
125 : task.m, task.n, task.k, task.m /*lda*/, task.n /*ldb*/,
126 : task.m /*ldc*/, LIBXSMM_DATATYPE_F64 /*aprec*/,
127 : LIBXSMM_DATATYPE_F64 /*bprec*/, LIBXSMM_DATATYPE_F64 /*cprec*/,
128 : LIBXSMM_DATATYPE_F64 /*calcp*/);
129 2967302 : kernel_func = (LIBXSMM_FEQ(1.0, alpha)
130 1209963 : ? libxsmm_dispatch_gemm(shape, (libxsmm_bitfield)flags,
131 : (libxsmm_bitfield)prefetch)
132 1483651 : : NULL);
133 : #else
134 : kernel_func = libxsmm_dmmdispatch(task.m, task.n, task.k, NULL /*lda*/,
135 : NULL /*ldb*/, NULL /*ldc*/, &alpha,
136 : &beta, &flags, &prefetch);
137 : #endif
138 1483651 : kernel_m = task.m;
139 1483651 : kernel_n = task.n;
140 1483651 : kernel_k = task.k;
141 : }
142 :
143 : // gemm_param wants non-const data even for A and B
144 21670259 : double *const data_a = pack_a->data + task.offset_a;
145 21670259 : double *const data_b = pack_b->data + task.offset_b;
146 21670259 : double *const data_c = shard_c->data + task.offset_c;
147 :
148 21670259 : if (kernel_func != NULL) {
149 : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
150 19004673 : libxsmm_gemm_param gemm_param;
151 19004673 : gemm_param.a.primary = data_a;
152 19004673 : gemm_param.b.primary = data_b;
153 19004673 : gemm_param.c.primary = data_c;
154 : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
155 : gemm_param.a.quaternary = pack_a->data + task_next.offset_a;
156 : gemm_param.b.quaternary = pack_b->data + task_next.offset_b;
157 : gemm_param.c.quaternary = shard_c->data + task_next.offset_c;
158 : #endif
159 19004673 : kernel_func(&gemm_param);
160 : #elif (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
161 : kernel_func(data_a, data_b, data_c, pack_a->data + task_next.offset_a,
162 : pack_b->data + task_next.offset_b,
163 : shard_c->data + task_next.offset_c);
164 : #else
165 : kernel_func(data_a, data_b, data_c);
166 : #endif
167 : } else {
168 2665586 : dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
169 : task.n, 1.0, data_c, task.m);
170 : }
171 : }
172 : #else
173 : // Fallback to BLAS when libxsmm is not available.
174 : for (int itask = 0; itask < ntasks; ++itask) {
175 : const dbm_task_t task = batch[itask];
176 : const double *data_a = &pack_a->data[task.offset_a];
177 : const double *data_b = &pack_b->data[task.offset_b];
178 : double *data_c = &shard_c->data[task.offset_c];
179 : dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
180 : task.n, 1.0, data_c, task.m);
181 : }
182 : #endif
183 : }
184 :
185 : // EOF
|