Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include <assert.h>
9 : #include <limits.h>
10 : #include <stdlib.h>
11 : #include <string.h>
12 :
13 : #include "../offload/offload_runtime.h"
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_library.h"
16 : #include "dbm_multiply.h"
17 : #include "dbm_multiply_comm.h"
18 : #include "dbm_multiply_cpu.h"
19 : #include "dbm_multiply_gpu.h"
20 : #include "dbm_multiply_internal.h"
21 :
/*******************************************************************************
 * \brief Returns the larger of two given integers (absent from standard C).
 * \author Ole Schuett
 ******************************************************************************/
static inline int imax(int x, int y) {
  if (x > y) {
    return x;
  }
  return y;
}
27 :
/*******************************************************************************
 * \brief Widens a {min, max} pair such that it also covers the given value.
 *        A fresh pair is expected to be seeded with {INT_MAX, 0}.
 * \author Hans Pabst
 ******************************************************************************/
static inline void min_max(int result[2], int value) {
  int *const lower = &result[0];
  int *const upper = &result[1];
  // Both bounds are tested independently against the incoming value.
  *lower = (value < *lower) ? value : *lower;
  *upper = (*upper < value) ? value : *upper;
}
40 :
41 : /*******************************************************************************
42 : * \brief Private routine for computing the max filter threshold for each row.
43 : * \author Ole Schuett
44 : ******************************************************************************/
45 203043 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
46 : const double filter_eps) {
47 203043 : const int nrows = (trans) ? matrix->ncols : matrix->nrows;
48 203043 : int *nblocks_per_row = calloc(nrows, sizeof(int));
49 203043 : float *row_max_eps = malloc(nrows * sizeof(float));
50 203043 : assert(row_max_eps != NULL);
51 :
52 203043 : #pragma omp parallel
53 : {
54 : #pragma omp for
55 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
56 : dbm_shard_t *shard = &matrix->shards[ishard];
57 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
58 : const dbm_block_t *blk = &shard->blocks[iblock];
59 : const int row = (trans) ? blk->col : blk->row;
60 : #pragma omp atomic
61 : nblocks_per_row[row]++;
62 : }
63 : }
64 : #pragma omp single
65 : dbm_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
66 : #pragma omp barrier
67 : #pragma omp for
68 : for (int i = 0; i < nrows; i++) {
69 : const float f =
70 : ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
71 : row_max_eps[i] = f * f;
72 : }
73 : } // end of omp parallel region
74 :
75 203043 : free(nblocks_per_row);
76 203043 : return row_max_eps; // Ownership of row_max_eps transfers to caller.
77 : }
78 :
/*******************************************************************************
 * \brief Private struct for storing the context of the multiplication backend.
 *        Only builds with GPU offloading enabled carry actual state.
 * \author Ole Schuett
 ******************************************************************************/
typedef struct {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_context_t gpu;
#else
  // ISO C does not permit structs without members (a zero-sized struct is a
  // compiler extension), hence a placeholder keeps the type portable.
  char unused;
#endif
} backend_context_t;
88 :
89 : /*******************************************************************************
90 : * \brief Private routine for intializing the multiplication backend.
91 : * \author Ole Schuett
92 : ******************************************************************************/
93 203043 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
94 203043 : backend_context_t *ctx = calloc(1, sizeof(backend_context_t));
95 :
96 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
97 : dbm_multiply_gpu_start(MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
98 : matrix_c->shards, &ctx->gpu);
99 : #else
100 203043 : (void)matrix_c; // mark as used
101 : #endif
102 :
103 203043 : return ctx;
104 : }
105 :
106 : /*******************************************************************************
107 : * \brief Private routine for handing newly arrived packs to the backend.
108 : * \author Ole Schuett
109 : ******************************************************************************/
110 0 : static void backend_upload_packs(const dbm_pack_t *pack_a,
111 : const dbm_pack_t *pack_b,
112 : backend_context_t *ctx) {
113 :
114 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
115 : dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
116 : #else
117 0 : (void)pack_a; // mark as used
118 0 : (void)pack_b;
119 0 : (void)ctx;
120 : #endif
121 0 : }
122 :
123 : /*******************************************************************************
124 : * \brief Private routine for sending a batch to the multiplication backend.
125 : * \author Ole Schuett
126 : ******************************************************************************/
127 225210 : static void backend_process_batch(const int ntasks, dbm_task_t batch[ntasks],
128 : const int mnk_range[3][2], const double alpha,
129 : const dbm_pack_t *pack_a,
130 : const dbm_pack_t *pack_b, const int kshard,
131 : dbm_shard_t *shard_c,
132 : backend_context_t *ctx) {
133 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
134 : (void)pack_a; // mark as used
135 : (void)pack_b;
136 : (void)shard_c;
137 : dbm_multiply_gpu_process_batch(ntasks, batch, mnk_range, alpha, kshard,
138 : &ctx->gpu);
139 : #else
140 225210 : (void)mnk_range; // mark as used
141 225210 : (void)kshard;
142 225210 : (void)ctx;
143 225210 : dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b, shard_c);
144 : #endif
145 225210 : }
146 :
147 : /*******************************************************************************
148 : * \brief Private routine for downloading results of the multiplication backend.
149 : * \author Ole Schuett
150 : ******************************************************************************/
151 0 : static void backend_download_results(backend_context_t *ctx) {
152 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
153 : dbm_multiply_gpu_download_results(&ctx->gpu);
154 : #else
155 0 : (void)ctx; // mark as used
156 : #endif
157 0 : }
158 :
159 : /*******************************************************************************
160 : * \brief Private routine for shutting down the multiplication backend.
161 : * \author Ole Schuett
162 : ******************************************************************************/
163 203043 : static void backend_stop(backend_context_t *ctx) {
164 : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
165 : dbm_multiply_gpu_stop(&ctx->gpu);
166 : #endif
167 203043 : free(ctx);
168 203043 : }
169 :
170 : /*******************************************************************************
171 : * \brief Private routine for multipling two packs.
172 : * \author Ole Schuett
173 : ******************************************************************************/
174 224843 : static void multiply_packs(const bool transa, const bool transb,
175 : const double alpha, const dbm_pack_t *pack_a,
176 : const dbm_pack_t *pack_b,
177 : const dbm_matrix_t *matrix_a,
178 : const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
179 : const bool retain_sparsity,
180 : const float *rows_max_eps, int64_t *flop,
181 224843 : backend_context_t *ctx) {
182 224843 : const float alpha2 = alpha * alpha;
183 224843 : int64_t flop_sum = 0;
184 :
185 224843 : const int nshard_rows = matrix_c->dist->rows.nshards;
186 224843 : const int nshard_cols = matrix_c->dist->cols.nshards;
187 224843 : int shard_row_start[nshard_rows], shard_col_start[nshard_cols];
188 224843 : memset(shard_row_start, 0, nshard_rows * sizeof(int));
189 224843 : memset(shard_col_start, 0, nshard_cols * sizeof(int));
190 :
191 224843 : const int *sum_index_sizes_a =
192 : (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
193 224843 : const int *sum_index_sizes_b =
194 : (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
195 224843 : const int *free_index_sizes_a =
196 : (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
197 224843 : const int *free_index_sizes_b =
198 : (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
199 :
200 224843 : #pragma omp parallel reduction(+ : flop_sum)
201 : {
202 :
203 : // Blocks are ordered first by shard. Creating lookup tables of boundaries.
204 : #pragma omp for
205 : for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
206 : const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
207 : const int prev_shard_row =
208 : pack_a->blocks[iblock - 1].free_index % nshard_rows;
209 : if (prev_shard_row != shard_row) {
210 : shard_row_start[shard_row] = iblock;
211 : }
212 : }
213 : #pragma omp for
214 : for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
215 : const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
216 : const int prev_shard_col =
217 : pack_b->blocks[jblock - 1].free_index % nshard_cols;
218 : if (prev_shard_col != shard_col) {
219 : shard_col_start[shard_col] = jblock;
220 : }
221 : }
222 :
223 : #pragma omp for collapse(2) schedule(dynamic)
224 : for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
225 : for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
226 : const int ishard = shard_row * nshard_cols + shard_col;
227 : dbm_shard_t *shard_c = &matrix_c->shards[ishard];
228 : dbm_task_t batch[MAX_BATCH_SIZE];
229 : int mnk_range[][2] = {{INT_MAX, 0}, {INT_MAX, 0}, {INT_MAX, 0}};
230 : int ntasks = 0;
231 :
232 : // Use a merge-join to find pairs of blocks with matching sum indices.
233 : // This utilizes that blocks within a shard are ordered by sum_index.
234 : const int iblock_start = shard_row_start[shard_row];
235 : int jblock_start = shard_col_start[shard_col];
236 : for (int iblock = iblock_start; iblock < pack_a->nblocks; iblock++) {
237 : const dbm_pack_block_t *blk_a = &pack_a->blocks[iblock];
238 : if (blk_a->free_index % nshard_rows != shard_row) {
239 : break;
240 : }
241 : for (int jblock = jblock_start; jblock < pack_b->nblocks; jblock++) {
242 : const dbm_pack_block_t *blk_b = &pack_b->blocks[jblock];
243 : if (blk_b->free_index % nshard_cols != shard_col) {
244 : break;
245 : }
246 : if (blk_a->sum_index < blk_b->sum_index) {
247 : break;
248 : }
249 : if (blk_a->sum_index > blk_b->sum_index) {
250 : jblock_start++;
251 : continue;
252 : }
253 : // Found block pair with blk_a->sum_index == blk_b->sum_index.
254 :
255 : // Check norms.
256 : const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
257 : if (result_norm < rows_max_eps[blk_a->free_index]) {
258 : continue;
259 : }
260 :
261 : // Check block sizes.
262 : const int m = free_index_sizes_a[blk_a->free_index];
263 : const int n = free_index_sizes_b[blk_b->free_index];
264 : const int k = sum_index_sizes_a[blk_a->sum_index];
265 : assert(m == matrix_c->row_sizes[blk_a->free_index]);
266 : assert(n == matrix_c->col_sizes[blk_b->free_index]);
267 : assert(k == sum_index_sizes_b[blk_b->sum_index]);
268 :
269 : // Get C block.
270 : const int row = blk_a->free_index, col = blk_b->free_index;
271 : dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
272 : if (blk_c == NULL && retain_sparsity) {
273 : continue;
274 : } else if (blk_c == NULL) {
275 : assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
276 : assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
277 : matrix_c->dist->my_rank);
278 : blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
279 : }
280 :
281 : // Count flops.
282 : dbm_library_counter_increment(m, n, k);
283 : const int64_t task_flops = 2LL * m * n * k;
284 : flop_sum += task_flops;
285 : if (task_flops == 0) {
286 : continue;
287 : }
288 :
289 : // Add block multiplication to batch.
290 : batch[ntasks].m = m;
291 : batch[ntasks].n = n;
292 : batch[ntasks].k = k;
293 : batch[ntasks].offset_a = blk_a->offset;
294 : batch[ntasks].offset_b = blk_b->offset;
295 : batch[ntasks].offset_c = blk_c->offset;
296 : ntasks++;
297 :
298 : // track MxN-shape covering an entire batch
299 : min_max(mnk_range[0], m);
300 : min_max(mnk_range[1], n);
301 : min_max(mnk_range[2], k);
302 :
303 : if (ntasks == MAX_BATCH_SIZE) {
304 : backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a,
305 : pack_b, ishard, shard_c, ctx);
306 : mnk_range[0][0] = mnk_range[1][0] = mnk_range[2][0] = INT_MAX;
307 : mnk_range[0][1] = mnk_range[1][1] = mnk_range[2][1] = 0;
308 : ntasks = 0;
309 : }
310 : }
311 : }
312 : backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a, pack_b,
313 : ishard, shard_c, ctx);
314 : }
315 : }
316 : }
317 224843 : *flop += flop_sum;
318 224843 : }
319 :
320 : /*******************************************************************************
321 : * \brief Performs a multiplication of two dbm_matrix_t matrices.
322 : * See dbm_matrix.h for details.
323 : * \author Ole Schuett
324 : ******************************************************************************/
325 203043 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
326 : const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
327 : const double beta, dbm_matrix_t *matrix_c,
328 : const bool retain_sparsity, const double filter_eps,
329 : int64_t *flop) {
330 :
331 203043 : assert(omp_get_num_threads() == 1);
332 :
333 : // Throughout the matrix multiplication code the "sum_index" and "free_index"
334 : // denote the summation (aka dummy) and free index from the Einstein notation.
335 203043 : const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
336 203043 : const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
337 203043 : const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
338 203043 : const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
339 :
340 : // Sanity check matrix dimensions.
341 203043 : assert(num_sum_index_a == num_sum_index_b);
342 203043 : assert(num_free_index_a == matrix_c->nrows);
343 203043 : assert(num_free_index_b == matrix_c->ncols);
344 :
345 : // Prepare matrix_c.
346 203043 : dbm_scale(matrix_c, beta);
347 :
348 : // Start uploading matrix_c to the GPU.
349 203043 : backend_context_t *ctx = backend_start(matrix_c);
350 :
351 : // Compute filter thresholds for each row.
352 203043 : float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
353 :
354 : // Redistribute matrix_a and matrix_b across MPI ranks.
355 203043 : dbm_comm_iterator_t *iter =
356 203043 : dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
357 :
358 : // Main loop.
359 203043 : *flop = 0;
360 203043 : dbm_pack_t *pack_a, *pack_b;
361 427886 : while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
362 224843 : backend_upload_packs(pack_a, pack_b, ctx);
363 224843 : multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
364 : matrix_c, retain_sparsity, rows_max_eps, flop, ctx);
365 : }
366 :
367 : // Start downloading matrix_c from the GPU.
368 203043 : backend_download_results(ctx);
369 :
370 : // Wait for all other MPI ranks to complete, then release ressources.
371 203043 : dbm_comm_iterator_stop(iter);
372 203043 : free(rows_max_eps);
373 203043 : backend_stop(ctx);
374 :
375 : // Compute average flops per rank.
376 203043 : dbm_mpi_sum_int64(flop, 1, matrix_c->dist->comm);
377 203043 : *flop = (*flop + matrix_c->dist->nranks - 1) / matrix_c->dist->nranks;
378 :
379 : // Final filter pass.
380 203043 : dbm_filter(matrix_c, filter_eps);
381 203043 : }
382 :
383 : // EOF
|