Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include <assert.h>
9 : #include <limits.h>
10 : #include <stdlib.h>
11 : #include <string.h>
12 :
13 : #include "../offload/offload_runtime.h"
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_internal.h"
16 : #include "dbm_library.h"
17 : #include "dbm_multiply.h"
18 : #include "dbm_multiply_comm.h"
19 : #include "dbm_multiply_cpu.h"
20 : #include "dbm_multiply_gpu.h"
21 :
22 : #if defined(__LIBXSMM)
23 : #include <libxsmm.h>
24 : #endif
25 :
// Developer switch: DBM_VALIDATE_AGAINST_LIBXSMM stays undefined unless it is
// supplied on the compiler command line or the 0 below is changed to 1.
// When it is defined (together with __LIBXSMM) backend_process_batch compares
// the offloaded result of every batch against the CPU implementation.
#if !defined(DBM_VALIDATE_AGAINST_LIBXSMM)
#if 0
#define DBM_VALIDATE_AGAINST_LIBXSMM
#endif
#endif
29 :
/*******************************************************************************
 * \brief Widens a running {min, max} pair so that it also covers value.
 *        result[0] holds the minimum and result[1] the maximum seen so far;
 *        callers seed the pair with {INT_MAX, 0}.
 * \author Hans Pabst
 ******************************************************************************/
static inline void min_max(int result[2], int value) {
  result[0] = (value < result[0]) ? value : result[0];
  result[1] = (result[1] < value) ? value : result[1];
}
42 :
43 : /*******************************************************************************
44 : * \brief Private routine for computing the max filter threshold for each row.
45 : * \author Ole Schuett
46 : ******************************************************************************/
47 211795 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
48 : const double filter_eps) {
49 211795 : const int nrows = (trans) ? matrix->ncols : matrix->nrows;
50 211795 : int *nblocks_per_row = calloc(nrows, sizeof(int));
51 211795 : float *row_max_eps = malloc(nrows * sizeof(float));
52 211795 : assert(row_max_eps != NULL);
53 :
54 211795 : #pragma omp parallel
55 : {
56 : #pragma omp for
57 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
58 : dbm_shard_t *shard = &matrix->shards[ishard];
59 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
60 : const dbm_block_t *blk = &shard->blocks[iblock];
61 : const int row = (trans) ? blk->col : blk->row;
62 : #pragma omp atomic
63 : nblocks_per_row[row]++;
64 : }
65 : }
66 : #pragma omp single
67 : dbm_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
68 : #pragma omp barrier
69 : #pragma omp for
70 : for (int i = 0; i < nrows; i++) {
71 : const float f =
72 : ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
73 : row_max_eps[i] = f * f;
74 : }
75 : } // end of omp parallel region
76 :
77 211795 : free(nblocks_per_row);
78 211795 : return row_max_eps; // Ownership of row_max_eps transfers to caller.
79 : }
80 :
/*******************************************************************************
 * \brief Private struct for storing the context of the multiplication backend.
 *        Only builds with offloading enabled carry actual state (the GPU
 *        context). Other builds get a placeholder member because ISO C does
 *        not allow a struct without members.
 * \author Ole Schuett
 ******************************************************************************/
typedef struct {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_context_t gpu;
#else
  char unused; // Never accessed; keeps the struct well-formed and non-empty.
#endif
} backend_context_t;
90 :
91 : /*******************************************************************************
92 : * \brief Private routine for intializing the multiplication backend.
93 : * \author Ole Schuett
94 : ******************************************************************************/
95 211795 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
96 211795 : backend_context_t *ctx = calloc(1, sizeof(backend_context_t));
97 :
98 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
99 : dbm_multiply_gpu_start(DBM_MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
100 : matrix_c->shards, &ctx->gpu);
101 : #else
102 211795 : (void)matrix_c; // mark as used
103 : #endif
104 :
105 211795 : return ctx;
106 : }
107 :
108 : /*******************************************************************************
109 : * \brief Private routine for handing newly arrived packs to the backend.
110 : * \author Ole Schuett
111 : ******************************************************************************/
112 0 : static void backend_upload_packs(const dbm_pack_t *pack_a,
113 : const dbm_pack_t *pack_b,
114 : backend_context_t *ctx) {
115 :
116 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
117 : dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
118 : #else
119 0 : (void)pack_a; // mark as used
120 0 : (void)pack_b;
121 0 : (void)ctx;
122 : #endif
123 0 : }
124 :
125 : /*******************************************************************************
126 : * \brief Private routine for sending a batch to the multiplication backend.
127 : * \author Ole Schuett
128 : ******************************************************************************/
129 229286 : static void backend_process_batch(const int ntasks, dbm_task_t batch[ntasks],
130 : const int mnk_range[3][2], const double alpha,
131 : const dbm_pack_t *pack_a,
132 : const dbm_pack_t *pack_b, const int kshard,
133 : dbm_shard_t *shard_c,
134 : backend_context_t *ctx) {
135 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
136 : dbm_multiply_gpu_process_batch(ntasks, batch, mnk_range, alpha, kshard,
137 : &ctx->gpu);
138 : #if defined(DBM_VALIDATE_AGAINST_LIBXSMM) && defined(__LIBXSMM)
139 : dbm_shard_gpu_t *const shard_g = &ctx->gpu.shards_c_dev[kshard];
140 : dbm_shard_t shard_r;
141 : dbm_shard_allocate_promised_blocks(shard_c);
142 : /* start transferring GPU result to host */
143 : assert(shard_c->data_size == shard_g->data_size);
144 : dbm_shard_init(&shard_r);
145 : dbm_shard_copy(&shard_r, shard_c);
146 : offloadMemcpyAsyncDtoH(shard_c->data, shard_g->data,
147 : shard_c->data_size * sizeof(double), shard_g->stream);
148 : dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
149 : &shard_r);
150 : /* finish transferring GPU result to host */
151 : offloadStreamSynchronize(shard_g->stream);
152 : libxsmm_matdiff_info diff;
153 : libxsmm_matdiff_clear(&diff);
154 : for (int itask = 0; itask < ntasks; ++itask) {
155 : const dbm_task_t task = batch[itask];
156 : const double *const tst = &shard_c->data[task.offset_c];
157 : const double *const ref = &shard_r.data[task.offset_c];
158 : libxsmm_matdiff_info d;
159 : if (EXIT_SUCCESS == libxsmm_matdiff(&d, LIBXSMM_DATATYPE(double), task.m,
160 : task.n, ref, tst, NULL /*ldref*/,
161 : NULL /*ldtst*/)) {
162 : libxsmm_matdiff_reduce(&diff, &d);
163 : }
164 : }
165 : const double epsilon = libxsmm_matdiff_epsilon(&diff);
166 : if (1E-15 < epsilon) {
167 : fprintf(stderr, "INFO ACC/LIBDBM: mnk=%ix%ix%i ntasks=%i diff=%g\n",
168 : mnk_range[0][1], mnk_range[1][1], mnk_range[2][1], ntasks, epsilon);
169 : }
170 : dbm_shard_release(&shard_r);
171 : #else
172 : (void)pack_a;
173 : (void)pack_b;
174 : (void)shard_c; // mark as used
175 : #endif
176 : #else
177 229286 : (void)mnk_range;
178 229286 : (void)kshard;
179 229286 : (void)ctx; // mark as used
180 229286 : dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b, shard_c);
181 : #endif
182 229286 : }
183 :
184 : /*******************************************************************************
185 : * \brief Private routine for downloading results of the multiplication backend.
186 : * \author Ole Schuett
187 : ******************************************************************************/
188 0 : static void backend_download_results(backend_context_t *ctx) {
189 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
190 : dbm_multiply_gpu_download_results(&ctx->gpu);
191 : #else
192 0 : (void)ctx; // mark as used
193 : #endif
194 0 : }
195 :
196 : /*******************************************************************************
197 : * \brief Private routine for shutting down the multiplication backend.
198 : * \author Ole Schuett
199 : ******************************************************************************/
200 211795 : static void backend_stop(backend_context_t *ctx) {
201 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
202 : dbm_multiply_gpu_stop(&ctx->gpu);
203 : #endif
204 211795 : free(ctx);
205 211795 : }
206 :
207 : /*******************************************************************************
208 : * \brief Private routine for multipling two packs.
209 : * \author Ole Schuett
210 : ******************************************************************************/
211 229231 : static void multiply_packs(const bool transa, const bool transb,
212 : const double alpha, const dbm_pack_t *pack_a,
213 : const dbm_pack_t *pack_b,
214 : const dbm_matrix_t *matrix_a,
215 : const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
216 : const bool retain_sparsity,
217 : const float *rows_max_eps, int64_t *flop,
218 229231 : backend_context_t *ctx) {
219 229231 : const float alpha2 = alpha * alpha;
220 229231 : int64_t flop_sum = 0;
221 :
222 229231 : const int nshard_rows = matrix_c->dist->rows.nshards;
223 229231 : const int nshard_cols = matrix_c->dist->cols.nshards;
224 229231 : int shard_row_start[nshard_rows], shard_col_start[nshard_cols];
225 229231 : memset(shard_row_start, 0, nshard_rows * sizeof(int));
226 229231 : memset(shard_col_start, 0, nshard_cols * sizeof(int));
227 :
228 229231 : const int *sum_index_sizes_a =
229 : (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
230 229231 : const int *sum_index_sizes_b =
231 : (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
232 229231 : const int *free_index_sizes_a =
233 : (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
234 229231 : const int *free_index_sizes_b =
235 : (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
236 :
237 229231 : #pragma omp parallel reduction(+ : flop_sum)
238 : {
239 : // Blocks are ordered first by shard. Creating lookup tables of boundaries.
240 : #pragma omp for nowait
241 : for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
242 : const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
243 : const int prev_shard_row =
244 : pack_a->blocks[iblock - 1].free_index % nshard_rows;
245 : if (prev_shard_row != shard_row) {
246 : shard_row_start[shard_row] = iblock;
247 : }
248 : }
249 : #pragma omp for
250 : for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
251 : const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
252 : const int prev_shard_col =
253 : pack_b->blocks[jblock - 1].free_index % nshard_cols;
254 : if (prev_shard_col != shard_col) {
255 : shard_col_start[shard_col] = jblock;
256 : }
257 : }
258 :
259 : #pragma omp for collapse(2) DBM_OMP_SCHEDULE
260 : for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
261 : for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
262 : const int ishard = shard_row * nshard_cols + shard_col;
263 : dbm_shard_t *shard_c = &matrix_c->shards[ishard];
264 : dbm_task_t batch[DBM_MAX_BATCH_SIZE];
265 : int mnk_range[][2] = {{INT_MAX, 0}, {INT_MAX, 0}, {INT_MAX, 0}};
266 : int ntasks = 0;
267 :
268 : // Use a merge-join to find pairs of blocks with matching sum indices.
269 : // This utilizes that blocks within a shard are ordered by sum_index.
270 : const int iblock_start = shard_row_start[shard_row];
271 : int jblock_start = shard_col_start[shard_col];
272 : for (int iblock = iblock_start; iblock < pack_a->nblocks; iblock++) {
273 : const dbm_pack_block_t *blk_a = &pack_a->blocks[iblock];
274 : if (blk_a->free_index % nshard_rows != shard_row) {
275 : break;
276 : }
277 : for (int jblock = jblock_start; jblock < pack_b->nblocks; jblock++) {
278 : const dbm_pack_block_t *blk_b = &pack_b->blocks[jblock];
279 : if (blk_b->free_index % nshard_cols != shard_col) {
280 : break;
281 : }
282 : if (blk_a->sum_index < blk_b->sum_index) {
283 : break;
284 : }
285 : if (blk_a->sum_index > blk_b->sum_index) {
286 : jblock_start++;
287 : continue;
288 : }
289 : // Found block pair with blk_a->sum_index == blk_b->sum_index.
290 :
291 : // Check norms.
292 : const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
293 : if (result_norm < rows_max_eps[blk_a->free_index]) {
294 : continue;
295 : }
296 :
297 : // Check block sizes.
298 : const int m = free_index_sizes_a[blk_a->free_index];
299 : const int n = free_index_sizes_b[blk_b->free_index];
300 : const int k = sum_index_sizes_a[blk_a->sum_index];
301 : assert(m == matrix_c->row_sizes[blk_a->free_index]);
302 : assert(n == matrix_c->col_sizes[blk_b->free_index]);
303 : assert(k == sum_index_sizes_b[blk_b->sum_index]);
304 :
305 : // Get C block.
306 : const int row = blk_a->free_index, col = blk_b->free_index;
307 : dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
308 : if (blk_c == NULL && retain_sparsity) {
309 : continue;
310 : } else if (blk_c == NULL) {
311 : assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
312 : assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
313 : matrix_c->dist->my_rank);
314 : blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
315 : }
316 :
317 : // Count flops.
318 : const int64_t task_flops = 2LL * m * n * k;
319 : if (task_flops == 0) {
320 : continue;
321 : }
322 : flop_sum += task_flops;
323 : dbm_library_counter_increment(m, n, k);
324 :
325 : // Add block multiplication to batch.
326 : batch[ntasks].m = m;
327 : batch[ntasks].n = n;
328 : batch[ntasks].k = k;
329 : batch[ntasks].offset_a = blk_a->offset;
330 : batch[ntasks].offset_b = blk_b->offset;
331 : batch[ntasks].offset_c = blk_c->offset;
332 : ntasks++;
333 :
334 : // track MxN-shape covering an entire batch
335 : min_max(mnk_range[0], m);
336 : min_max(mnk_range[1], n);
337 : min_max(mnk_range[2], k);
338 :
339 : if (ntasks == DBM_MAX_BATCH_SIZE) {
340 : backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a,
341 : pack_b, ishard, shard_c, ctx);
342 : mnk_range[0][0] = mnk_range[1][0] = mnk_range[2][0] = INT_MAX;
343 : mnk_range[0][1] = mnk_range[1][1] = mnk_range[2][1] = 0;
344 : ntasks = 0;
345 : }
346 : }
347 : }
348 : backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a, pack_b,
349 : ishard, shard_c, ctx);
350 : }
351 : }
352 : }
353 229231 : *flop += flop_sum;
354 229231 : }
355 :
356 : /*******************************************************************************
357 : * \brief Performs a multiplication of two dbm_matrix_t matrices.
358 : * See dbm_matrix.h for details.
359 : * \author Ole Schuett
360 : ******************************************************************************/
361 211795 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
362 : const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
363 : const double beta, dbm_matrix_t *matrix_c,
364 : const bool retain_sparsity, const double filter_eps,
365 : int64_t *flop) {
366 :
367 211795 : assert(omp_get_num_threads() == 1);
368 :
369 : // Throughout the matrix multiplication code the "sum_index" and "free_index"
370 : // denote the summation (aka dummy) and free index from the Einstein notation.
371 211795 : const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
372 211795 : const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
373 211795 : const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
374 211795 : const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
375 :
376 : // Sanity check matrix dimensions.
377 211795 : assert(num_sum_index_a == num_sum_index_b);
378 211795 : assert(num_free_index_a == matrix_c->nrows);
379 211795 : assert(num_free_index_b == matrix_c->ncols);
380 :
381 : // Prepare matrix_c.
382 211795 : dbm_scale(matrix_c, beta);
383 :
384 : // Start uploading matrix_c to the GPU.
385 211795 : backend_context_t *ctx = backend_start(matrix_c);
386 :
387 : // Compute filter thresholds for each row.
388 211795 : float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
389 :
390 : // Redistribute matrix_a and matrix_b across MPI ranks.
391 211795 : dbm_comm_iterator_t *iter =
392 211795 : dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
393 :
394 : // Main loop.
395 211795 : *flop = 0;
396 211795 : dbm_pack_t *pack_a, *pack_b;
397 441026 : while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
398 229231 : backend_upload_packs(pack_a, pack_b, ctx);
399 229231 : multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
400 : matrix_c, retain_sparsity, rows_max_eps, flop, ctx);
401 : }
402 :
403 : // Start downloading matrix_c from the GPU.
404 211795 : backend_download_results(ctx);
405 :
406 : // Wait for all other MPI ranks to complete, then release ressources.
407 211795 : dbm_comm_iterator_stop(iter);
408 211795 : free(rows_max_eps);
409 211795 : backend_stop(ctx);
410 :
411 : // Compute average flops per rank.
412 211795 : dbm_mpi_sum_int64(flop, 1, matrix_c->dist->comm);
413 211795 : *flop = (*flop + matrix_c->dist->nranks - 1) / matrix_c->dist->nranks;
414 :
415 : // Final filter pass.
416 211795 : dbm_filter(matrix_c, filter_eps);
417 211795 : }
418 :
419 : // EOF
|