/*----------------------------------------------------------------------------*/
/* CP2K: A general program to perform molecular dynamics simulations */
/* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
/* */
/* SPDX-License-Identifier: BSD-3-Clause */
/*----------------------------------------------------------------------------*/

#include <assert.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>

#include "../offload/offload_runtime.h"
#include "dbm_hyperparams.h"
#include "dbm_library.h"
#include "dbm_multiply.h"
#include "dbm_multiply_comm.h"
#include "dbm_multiply_cpu.h"
#include "dbm_multiply_gpu.h"
#include "dbm_multiply_internal.h"

/*******************************************************************************
 * \brief Returns the larger of two given integers (missing from the C standard).
 * \author Ole Schuett
 ******************************************************************************/
static inline int imax(int x, int y) { return (x > y ? x : y); }

/*******************************************************************************
 * \brief Updates the min/max of a range of values (initially {INT_MAX, 0}).
 * \author Hans Pabst
 ******************************************************************************/
static inline void min_max(int result[2], int value) {
  if (value < result[0]) {
    result[0] = value;
  }
  if (result[1] < value) {
    result[1] = value;
  }
}

/*******************************************************************************
 * \brief Private routine for computing the max filter threshold for each row.
 * \author Ole Schuett
 ******************************************************************************/
static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
                                    const double filter_eps) {
  const int nrows = (trans) ? matrix->ncols : matrix->nrows;
  int *nblocks_per_row = calloc(nrows, sizeof(int));
  float *row_max_eps = malloc(nrows * sizeof(float));
  assert(nblocks_per_row != NULL && row_max_eps != NULL);

#pragma omp parallel
  {
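    // Count the blocks in each row: local counts first, then summed over all
    // MPI ranks, so the filter tolerance can be divided among them.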
#pragma omp for
    for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      dbm_shard_t *shard = &matrix->shards[ishard];
      for (int iblock = 0; iblock < shard->nblocks; iblock++) {
        const dbm_block_t *blk = &shard->blocks[iblock];
        const int row = (trans) ? blk->col : blk->row;
#pragma omp atomic
        nblocks_per_row[row]++;
      }
    }
#pragma omp single
    dbm_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
#pragma omp barrier
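    // Distribute the filter tolerance evenly among the blocks of each row; the
    // squared value is what gets compared against norm products later on.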
#pragma omp for
    for (int i = 0; i < nrows; i++) {
      const float f =
          ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
      row_max_eps[i] = f * f;
    }
  } // end of omp parallel region

  free(nblocks_per_row);
  return row_max_eps; // Ownership of row_max_eps transfers to caller.
}

/*******************************************************************************
 * \brief Private struct for storing the context of the multiplication backend.
 * \author Ole Schuett
 ******************************************************************************/
typedef struct {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_context_t gpu;
#endif
} backend_context_t;

/*******************************************************************************
 * \brief Private routine for initializing the multiplication backend.
 * \author Ole Schuett
 ******************************************************************************/
static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
  backend_context_t *ctx = calloc(1, sizeof(backend_context_t));

#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_start(MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
                         matrix_c->shards, &ctx->gpu);
#else
  (void)matrix_c; // mark as used
#endif

  return ctx;
}

/*******************************************************************************
 * \brief Private routine for handing newly arrived packs to the backend.
 * \author Ole Schuett
 ******************************************************************************/
static void backend_upload_packs(const dbm_pack_t *pack_a,
                                 const dbm_pack_t *pack_b,
                                 backend_context_t *ctx) {

#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
#else
  (void)pack_a; // mark as used
  (void)pack_b;
  (void)ctx;
#endif
}

/*******************************************************************************
 * \brief Private routine for sending a batch to the multiplication backend.
 * \author Ole Schuett
 ******************************************************************************/
static void backend_process_batch(const int ntasks, dbm_task_t batch[ntasks],
                                  const int mnk_range[3][2], const double alpha,
                                  const dbm_pack_t *pack_a,
                                  const dbm_pack_t *pack_b, const int kshard,
                                  dbm_shard_t *shard_c,
                                  backend_context_t *ctx) {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  (void)pack_a; // mark as used
  (void)pack_b;
  (void)shard_c;
  dbm_multiply_gpu_process_batch(ntasks, batch, mnk_range, alpha, kshard,
                                 &ctx->gpu);
#else
  (void)mnk_range; // mark as used
  (void)kshard;
  (void)ctx;
  dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b, shard_c);
#endif
}

/*******************************************************************************
 * \brief Private routine for downloading results from the multiplication backend.
 * \author Ole Schuett
 ******************************************************************************/
static void backend_download_results(backend_context_t *ctx) {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_download_results(&ctx->gpu);
#else
  (void)ctx; // mark as used
#endif
}

/*******************************************************************************
 * \brief Private routine for shutting down the multiplication backend.
 * \author Ole Schuett
 ******************************************************************************/
static void backend_stop(backend_context_t *ctx) {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_stop(&ctx->gpu);
#endif
  free(ctx);
}

/*******************************************************************************
 * \brief Private routine for multiplying two packs.
 * \author Ole Schuett
 ******************************************************************************/
static void multiply_packs(const bool transa, const bool transb,
                           const double alpha, const dbm_pack_t *pack_a,
                           const dbm_pack_t *pack_b,
                           const dbm_matrix_t *matrix_a,
                           const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
                           const bool retain_sparsity,
                           const float *rows_max_eps, int64_t *flop,
                           backend_context_t *ctx) {
  const float alpha2 = alpha * alpha;
  int64_t flop_sum = 0;

  const int nshard_rows = matrix_c->dist->rows.nshards;
  const int nshard_cols = matrix_c->dist->cols.nshards;
  int shard_row_start[nshard_rows], shard_col_start[nshard_cols];
  memset(shard_row_start, 0, nshard_rows * sizeof(int));
  memset(shard_col_start, 0, nshard_cols * sizeof(int));

  const int *sum_index_sizes_a =
      (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
  const int *sum_index_sizes_b =
      (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
  const int *free_index_sizes_a =
      (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
  const int *free_index_sizes_b =
      (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;

#pragma omp parallel reduction(+ : flop_sum)
  {
    // Blocks are ordered first by shard. Create lookup tables for the shard
    // boundaries within each pack.
#pragma omp for nowait
    for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
      const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
      const int prev_shard_row =
          pack_a->blocks[iblock - 1].free_index % nshard_rows;
      if (prev_shard_row != shard_row) {
        shard_row_start[shard_row] = iblock;
      }
    }
#pragma omp for
    for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
      const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
      const int prev_shard_col =
          pack_b->blocks[jblock - 1].free_index % nshard_cols;
      if (prev_shard_col != shard_col) {
        shard_col_start[shard_col] = jblock;
      }
    }

#pragma omp for collapse(2) schedule(dynamic, 1)
    for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
      for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
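        // Shards of matrix_c form a 2d grid stored in row-major order.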
        const int ishard = shard_row * nshard_cols + shard_col;
        dbm_shard_t *shard_c = &matrix_c->shards[ishard];
        dbm_task_t batch[MAX_BATCH_SIZE];
        int mnk_range[][2] = {{INT_MAX, 0}, {INT_MAX, 0}, {INT_MAX, 0}};
        int ntasks = 0;

        // Use a merge-join to find pairs of blocks with matching sum indices.
        // This exploits the fact that blocks within a shard are ordered by
        // sum_index.
        const int iblock_start = shard_row_start[shard_row];
        int jblock_start = shard_col_start[shard_col];
        for (int iblock = iblock_start; iblock < pack_a->nblocks; iblock++) {
          const dbm_pack_block_t *blk_a = &pack_a->blocks[iblock];
          if (blk_a->free_index % nshard_rows != shard_row) {
            break;
          }
          for (int jblock = jblock_start; jblock < pack_b->nblocks; jblock++) {
            const dbm_pack_block_t *blk_b = &pack_b->blocks[jblock];
            if (blk_b->free_index % nshard_cols != shard_col) {
              break;
            }
            if (blk_a->sum_index < blk_b->sum_index) {
              break;
            }
            if (blk_a->sum_index > blk_b->sum_index) {
              jblock_start++;
              continue;
            }
            // Found a block pair with blk_a->sum_index == blk_b->sum_index.

            // Check norms.
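            // Skip the pair if the norm estimate of the scaled block product
            // falls below the row's threshold from compute_rows_max_eps().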
            const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
            if (result_norm < rows_max_eps[blk_a->free_index]) {
              continue;
            }

            // Check block sizes.
            const int m = free_index_sizes_a[blk_a->free_index];
            const int n = free_index_sizes_b[blk_b->free_index];
            const int k = sum_index_sizes_a[blk_a->sum_index];
            assert(m == matrix_c->row_sizes[blk_a->free_index]);
            assert(n == matrix_c->col_sizes[blk_b->free_index]);
            assert(k == sum_index_sizes_b[blk_b->sum_index]);

            // Get C block.
            const int row = blk_a->free_index, col = blk_b->free_index;
            dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
            if (blk_c == NULL && retain_sparsity) {
              continue;
            } else if (blk_c == NULL) {
              assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
              assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
                     matrix_c->dist->my_rank);
              blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
            }

            // Count flops.
            const int64_t task_flops = 2LL * m * n * k;
            if (task_flops == 0) {
              continue;
            }
            flop_sum += task_flops;
            dbm_library_counter_increment(m, n, k);

            // Add the block multiplication to the batch.
            batch[ntasks].m = m;
            batch[ntasks].n = n;
            batch[ntasks].k = k;
            batch[ntasks].offset_a = blk_a->offset;
            batch[ntasks].offset_b = blk_b->offset;
            batch[ntasks].offset_c = blk_c->offset;
            ntasks++;

            // Track the m/n/k ranges covering the entire batch.
            min_max(mnk_range[0], m);
            min_max(mnk_range[1], n);
            min_max(mnk_range[2], k);

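            // Flush the batch to the backend once it is full, then start a
            // fresh batch.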
            if (ntasks == MAX_BATCH_SIZE) {
              backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a,
                                    pack_b, ishard, shard_c, ctx);
              mnk_range[0][0] = mnk_range[1][0] = mnk_range[2][0] = INT_MAX;
              mnk_range[0][1] = mnk_range[1][1] = mnk_range[2][1] = 0;
              ntasks = 0;
            }
          }
        }
        backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a, pack_b,
                              ishard, shard_c, ctx);
      }
    }
  }
  *flop += flop_sum;
}

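/*******************************************************************************
 * \brief Minimal usage sketch for dbm_multiply() below. It assumes matrix_a,
 *        matrix_b and matrix_c were created beforehand with compatible block
 *        sizes and distributions (see dbm_matrix.h):
 *
 *          int64_t flop = 0;
 *          // matrix_c := 1.0 * matrix_a * matrix_b + 0.0 * matrix_c,
 *          // with the filter threshold set to 1e-10.
 *          dbm_multiply(false, false, 1.0, matrix_a, matrix_b, 0.0, matrix_c,
 *                       false, 1e-10, &flop);
 ******************************************************************************/
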
/*******************************************************************************
 * \brief Performs a multiplication of two dbm_matrix_t matrices.
 *        See dbm_matrix.h for details.
 * \author Ole Schuett
 ******************************************************************************/
void dbm_multiply(const bool transa, const bool transb, const double alpha,
                  const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
                  const double beta, dbm_matrix_t *matrix_c,
                  const bool retain_sparsity, const double filter_eps,
                  int64_t *flop) {

  assert(omp_get_num_threads() == 1);

  // Throughout the matrix multiplication code, "sum_index" and "free_index"
  // denote the summation (aka dummy) and free indices from Einstein notation.
  const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
  const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
  const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
  const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;

  // Sanity check matrix dimensions.
  assert(num_sum_index_a == num_sum_index_b);
  assert(num_free_index_a == matrix_c->nrows);
  assert(num_free_index_b == matrix_c->ncols);

  // Prepare matrix_c.
  dbm_scale(matrix_c, beta);

  // Start uploading matrix_c to the GPU.
  backend_context_t *ctx = backend_start(matrix_c);

  // Compute filter thresholds for each row.
  float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);

  // Redistribute matrix_a and matrix_b across MPI ranks.
  dbm_comm_iterator_t *iter =
      dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);

  // Main loop: multiply the packs as they arrive from the iterator.
  *flop = 0;
  dbm_pack_t *pack_a, *pack_b;
  while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
    backend_upload_packs(pack_a, pack_b, ctx);
    multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
                   matrix_c, retain_sparsity, rows_max_eps, flop, ctx);
  }

  // Start downloading matrix_c from the GPU.
  backend_download_results(ctx);

  // Wait for all other MPI ranks to complete, then release resources.
  dbm_comm_iterator_stop(iter);
  free(rows_max_eps);
  backend_stop(ctx);

  // Compute average flops per rank.
  dbm_mpi_sum_int64(flop, 1, matrix_c->dist->comm);
  *flop = (*flop + matrix_c->dist->nranks - 1) / matrix_c->dist->nranks;

  // Final filter pass.
  dbm_filter(matrix_c, filter_eps);
}

// EOF