LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply.c (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b4bd748) Lines: 64 72 88.9 %
Date: 2025-03-09 07:56:22 Functions: 6 8 75.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <assert.h>
       9             : #include <limits.h>
      10             : #include <stdlib.h>
      11             : #include <string.h>
      12             : 
      13             : #include "../offload/offload_runtime.h"
      14             : #include "dbm_hyperparams.h"
      15             : #include "dbm_internal.h"
      16             : #include "dbm_library.h"
      17             : #include "dbm_multiply.h"
      18             : #include "dbm_multiply_comm.h"
      19             : #include "dbm_multiply_cpu.h"
      20             : #include "dbm_multiply_gpu.h"
      21             : 
      22             : #if defined(__LIBXSMM)
      23             : #include <libxsmm.h>
      24             : #endif
      25             : 
// Compile-time switch for validating GPU results: change "&& 0" to "&& 1"
// (or define DBM_VALIDATE_AGAINST_LIBXSMM on the compiler command line) to
// cross-check each GPU batch against a CPU recomputation in
// backend_process_batch. It only has an effect when GPU offloading for DBM is
// enabled and __LIBXSMM is defined.
#if !defined(DBM_VALIDATE_AGAINST_LIBXSMM) && 0
#define DBM_VALIDATE_AGAINST_LIBXSMM
#endif
      29             : 
      30             : /*******************************************************************************
      31             :  * \brief Updates the min/max of a range of values (initially {INT_MAX, 0}).
      32             :  * \author Hans Pabst
      33             :  ******************************************************************************/
      34             : static inline void min_max(int result[2], int value) {
      35             :   if (value < result[0]) {
      36             :     result[0] = value;
      37             :   }
      38             :   if (result[1] < value) {
      39             :     result[1] = value;
      40             :   }
      41             : }
      42             : 
      43             : /*******************************************************************************
      44             :  * \brief Private routine for computing the max filter threshold for each row.
      45             :  * \author Ole Schuett
      46             :  ******************************************************************************/
      47      211795 : static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
      48             :                                    const double filter_eps) {
      49      211795 :   const int nrows = (trans) ? matrix->ncols : matrix->nrows;
      50      211795 :   int *nblocks_per_row = calloc(nrows, sizeof(int));
      51      211795 :   float *row_max_eps = malloc(nrows * sizeof(float));
      52      211795 :   assert(row_max_eps != NULL);
      53             : 
      54      211795 : #pragma omp parallel
      55             :   {
      56             : #pragma omp for
      57             :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      58             :       dbm_shard_t *shard = &matrix->shards[ishard];
      59             :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
      60             :         const dbm_block_t *blk = &shard->blocks[iblock];
      61             :         const int row = (trans) ? blk->col : blk->row;
      62             : #pragma omp atomic
      63             :         nblocks_per_row[row]++;
      64             :       }
      65             :     }
      66             : #pragma omp single
      67             :     dbm_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
      68             : #pragma omp barrier
      69             : #pragma omp for
      70             :     for (int i = 0; i < nrows; i++) {
      71             :       const float f =
      72             :           ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
      73             :       row_max_eps[i] = f * f;
      74             :     }
      75             :   } // end of omp parallel region
      76             : 
      77      211795 :   free(nblocks_per_row);
      78      211795 :   return row_max_eps; // Ownership of row_max_eps transfers to caller.
      79             : }
      80             : 
      81             : /*******************************************************************************
      82             :  * \brief Private struct for storing the context of the multiplication backend.
      83             :  * \author Ole Schuett
      84             :  ******************************************************************************/
      85             : typedef struct {
      86             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
      87             :   dbm_multiply_gpu_context_t gpu;
      88             : #endif
      89             : } backend_context_t;
      90             : 
      91             : /*******************************************************************************
      92             :  * \brief Private routine for intializing the multiplication backend.
      93             :  * \author Ole Schuett
      94             :  ******************************************************************************/
      95      211795 : static backend_context_t *backend_start(const dbm_matrix_t *matrix_c) {
      96      211795 :   backend_context_t *ctx = calloc(1, sizeof(backend_context_t));
      97             : 
      98             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
      99             :   dbm_multiply_gpu_start(DBM_MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
     100             :                          matrix_c->shards, &ctx->gpu);
     101             : #else
     102      211795 :   (void)matrix_c; // mark as used
     103             : #endif
     104             : 
     105      211795 :   return ctx;
     106             : }
     107             : 
     108             : /*******************************************************************************
     109             :  * \brief Private routine for handing newly arrived packs to the backend.
     110             :  * \author Ole Schuett
     111             :  ******************************************************************************/
     112           0 : static void backend_upload_packs(const dbm_pack_t *pack_a,
     113             :                                  const dbm_pack_t *pack_b,
     114             :                                  backend_context_t *ctx) {
     115             : 
     116             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     117             :   dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
     118             : #else
     119           0 :   (void)pack_a; // mark as used
     120           0 :   (void)pack_b;
     121           0 :   (void)ctx;
     122             : #endif
     123           0 : }
     124             : 
/*******************************************************************************
 * \brief Private routine for sending a batch to the multiplication backend.
 *
 *        Each task in batch describes one block product (sizes m, n, k and
 *        offsets into pack_a, pack_b and shard_c), scaled by alpha.
 *        - GPU backend: the batch is handed to dbm_multiply_gpu_process_batch()
 *          for shard kshard. If DBM_VALIDATE_AGAINST_LIBXSMM and __LIBXSMM are
 *          defined, the device data of that shard is additionally copied back
 *          into shard_c and compared against a CPU recomputation of the same
 *          batch. Differences with epsilon above 1E-15 are reported on stderr.
 *        - CPU backend: the batch is processed directly on shard_c.
 *
 * \param ntasks    Number of tasks in batch (0 is passed for an empty flush).
 * \param batch     Tasks to process.
 * \param mnk_range {min, max} of m, n and k over the batch; passed to the GPU
 *                  backend and printed (max values) by the validation.
 * \param alpha     Scaling factor of the block products.
 * \param pack_a    Pack holding the blocks of matrix A.
 * \param pack_b    Pack holding the blocks of matrix B.
 * \param kshard    Index of the shard of matrix C that the batch targets.
 * \param shard_c   Host-side shard of matrix C with index kshard.
 * \param ctx       Backend context obtained from backend_start().
 *
 * NOTE(review): batch is a variably modified parameter. Its size expression is
 * evaluated on function entry (C11 6.9.1p10) and must then be > 0
 * (C11 6.7.6.2p5), but multiply_packs calls this with ntasks == 0 for empty
 * final flushes. Consider declaring it as a plain pointer.
 * \author Ole Schuett
 ******************************************************************************/
static void backend_process_batch(const int ntasks, dbm_task_t batch[ntasks],
                                  const int mnk_range[3][2], const double alpha,
                                  const dbm_pack_t *pack_a,
                                  const dbm_pack_t *pack_b, const int kshard,
                                  dbm_shard_t *shard_c,
                                  backend_context_t *ctx) {
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
  dbm_multiply_gpu_process_batch(ntasks, batch, mnk_range, alpha, kshard,
                                 &ctx->gpu);
#if defined(DBM_VALIDATE_AGAINST_LIBXSMM) && defined(__LIBXSMM)
  // Validation: recompute the batch on the CPU into shard_r (a snapshot of the
  // host shard taken before the device data is copied over shard_c), then
  // compare every task's C block between the two.
  dbm_shard_gpu_t *const shard_g = &ctx->gpu.shards_c_dev[kshard];
  dbm_shard_t shard_r;
  dbm_shard_allocate_promised_blocks(shard_c);
  /* start transferring GPU result to host */
  assert(shard_c->data_size == shard_g->data_size);
  dbm_shard_init(&shard_r);
  dbm_shard_copy(&shard_r, shard_c);
  offloadMemcpyAsyncDtoH(shard_c->data, shard_g->data,
                         shard_c->data_size * sizeof(double), shard_g->stream);
  // The CPU reference computation overlaps with the asynchronous transfer.
  dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
                                 &shard_r);
  /* finish transferring GPU result to host */
  offloadStreamSynchronize(shard_g->stream);
  libxsmm_matdiff_info diff;
  libxsmm_matdiff_clear(&diff);
  for (int itask = 0; itask < ntasks; ++itask) {
    const dbm_task_t task = batch[itask];
    const double *const tst = &shard_c->data[task.offset_c];
    const double *const ref = &shard_r.data[task.offset_c];
    libxsmm_matdiff_info d;
    if (EXIT_SUCCESS == libxsmm_matdiff(&d, LIBXSMM_DATATYPE(double), task.m,
                                        task.n, ref, tst, NULL /*ldref*/,
                                        NULL /*ldtst*/)) {
      libxsmm_matdiff_reduce(&diff, &d);
    }
  }
  const double epsilon = libxsmm_matdiff_epsilon(&diff);
  if (1E-15 < epsilon) {
    fprintf(stderr, "INFO ACC/LIBDBM: mnk=%ix%ix%i ntasks=%i diff=%g\n",
            mnk_range[0][1], mnk_range[1][1], mnk_range[2][1], ntasks, epsilon);
  }
  dbm_shard_release(&shard_r);
#else
  (void)pack_a;
  (void)pack_b;
  (void)shard_c; // mark as used
#endif
#else
  // CPU backend: operate directly on the host shard.
  (void)mnk_range;
  (void)kshard;
  (void)ctx; // mark as used
  dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b, shard_c);
#endif
}
     183             : 
     184             : /*******************************************************************************
     185             :  * \brief Private routine for downloading results of the multiplication backend.
     186             :  * \author Ole Schuett
     187             :  ******************************************************************************/
     188           0 : static void backend_download_results(backend_context_t *ctx) {
     189             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     190             :   dbm_multiply_gpu_download_results(&ctx->gpu);
     191             : #else
     192           0 :   (void)ctx; // mark as used
     193             : #endif
     194           0 : }
     195             : 
     196             : /*******************************************************************************
     197             :  * \brief Private routine for shutting down the multiplication backend.
     198             :  * \author Ole Schuett
     199             :  ******************************************************************************/
     200      211795 : static void backend_stop(backend_context_t *ctx) {
     201             : #if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
     202             :   dbm_multiply_gpu_stop(&ctx->gpu);
     203             : #endif
     204      211795 :   free(ctx);
     205      211795 : }
     206             : 
/*******************************************************************************
 * \brief Private routine for multiplying two packs.
 *
 *        Finds all pairs of blocks from pack_a and pack_b with matching
 *        sum_index. Pairs whose estimated result norm lies below the filter
 *        threshold of their row are skipped. The remaining block products are
 *        handed to the backend in batches of at most DBM_MAX_BATCH_SIZE tasks.
 *        The work is distributed over OpenMP threads by shard of matrix_c.
 *
 * \param transa, transb   Whether matrix_a / matrix_b are used transposed.
 * \param alpha            Scaling factor of the product.
 * \param pack_a, pack_b   Packs of blocks to multiply. Blocks are expected to
 *                         be ordered first by shard and, within a shard, by
 *                         sum_index.
 * \param matrix_a, matrix_b  Source matrices; used for the block sizes.
 * \param matrix_c         Result matrix; new blocks may be promised in it.
 * \param retain_sparsity  If true, no new blocks are created in matrix_c.
 * \param rows_max_eps     Per-row thresholds from compute_rows_max_eps(),
 *                         indexed by the free index of pack_a's blocks.
 * \param flop             Incremented by the flops of all processed products.
 * \param ctx              Backend context obtained from backend_start().
 * \author Ole Schuett
 ******************************************************************************/
static void multiply_packs(const bool transa, const bool transb,
                           const double alpha, const dbm_pack_t *pack_a,
                           const dbm_pack_t *pack_b,
                           const dbm_matrix_t *matrix_a,
                           const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
                           const bool retain_sparsity,
                           const float *rows_max_eps, int64_t *flop,
                           backend_context_t *ctx) {
  // Squared, to be comparable with rows_max_eps (which stores (eps/n)^2);
  // presumably the block norms are squared norms as well - TODO confirm.
  const float alpha2 = alpha * alpha;
  int64_t flop_sum = 0;

  // Index of the first block of each shard row (pack_a) / shard column
  // (pack_b). Shards without blocks keep 0; the merge loop then stops at once,
  // because block 0 belongs to a different shard.
  const int nshard_rows = matrix_c->dist->rows.nshards;
  const int nshard_cols = matrix_c->dist->cols.nshards;
  int shard_row_start[nshard_rows], shard_col_start[nshard_cols];
  memset(shard_row_start, 0, nshard_rows * sizeof(int));
  memset(shard_col_start, 0, nshard_cols * sizeof(int));

  // Block sizes along the summation and free dimensions, honoring transa/b.
  const int *sum_index_sizes_a =
      (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
  const int *sum_index_sizes_b =
      (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
  const int *free_index_sizes_a =
      (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
  const int *free_index_sizes_b =
      (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;

#pragma omp parallel reduction(+ : flop_sum)
  {
    // Blocks are ordered first by shard. Creating lookup tables of boundaries.
    // Given that ordering, each boundary is written by exactly one iteration.
    // The nowait is safe: the implicit barrier after the second loop
    // completes both tables before the merge loop reads them.
#pragma omp for nowait
    for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
      const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
      const int prev_shard_row =
          pack_a->blocks[iblock - 1].free_index % nshard_rows;
      if (prev_shard_row != shard_row) {
        shard_row_start[shard_row] = iblock;
      }
    }
#pragma omp for
    for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
      const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
      const int prev_shard_col =
          pack_b->blocks[jblock - 1].free_index % nshard_cols;
      if (prev_shard_col != shard_col) {
        shard_col_start[shard_col] = jblock;
      }
    }

    // One iteration per shard of matrix_c, so every shard_c is only touched
    // by a single thread.
#pragma omp for collapse(2) DBM_OMP_SCHEDULE
    for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
      for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
        const int ishard = shard_row * nshard_cols + shard_col;
        dbm_shard_t *shard_c = &matrix_c->shards[ishard];
        dbm_task_t batch[DBM_MAX_BATCH_SIZE];
        // {min, max} of m, n and k over the tasks of the current batch.
        int mnk_range[][2] = {{INT_MAX, 0}, {INT_MAX, 0}, {INT_MAX, 0}};
        int ntasks = 0;

        // Use a merge-join to find pairs of blocks with matching sum indices.
        // This utilizes that blocks within a shard are ordered by sum_index.
        const int iblock_start = shard_row_start[shard_row];
        int jblock_start = shard_col_start[shard_col];
        for (int iblock = iblock_start; iblock < pack_a->nblocks; iblock++) {
          const dbm_pack_block_t *blk_a = &pack_a->blocks[iblock];
          if (blk_a->free_index % nshard_rows != shard_row) {
            break; // Reached the end of this shard row.
          }
          for (int jblock = jblock_start; jblock < pack_b->nblocks; jblock++) {
            const dbm_pack_block_t *blk_b = &pack_b->blocks[jblock];
            if (blk_b->free_index % nshard_cols != shard_col) {
              break; // Reached the end of this shard column.
            }
            if (blk_a->sum_index < blk_b->sum_index) {
              break; // No further matches for blk_a in this shard column.
            }
            if (blk_a->sum_index > blk_b->sum_index) {
              // blk_b cannot match blk_a or any later A block, since sum
              // indices only grow; skip it for good.
              jblock_start++;
              continue;
            }
            // Found block pair with blk_a->sum_index == blk_b->sum_index.

            // Check norms.
            const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
            if (result_norm < rows_max_eps[blk_a->free_index]) {
              continue;
            }

            // Check block sizes.
            const int m = free_index_sizes_a[blk_a->free_index];
            const int n = free_index_sizes_b[blk_b->free_index];
            const int k = sum_index_sizes_a[blk_a->sum_index];
            assert(m == matrix_c->row_sizes[blk_a->free_index]);
            assert(n == matrix_c->col_sizes[blk_b->free_index]);
            assert(k == sum_index_sizes_b[blk_b->sum_index]);

            // Get C block.
            const int row = blk_a->free_index, col = blk_b->free_index;
            dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
            if (blk_c == NULL && retain_sparsity) {
              continue; // Keep the sparsity pattern of matrix_c unchanged.
            } else if (blk_c == NULL) {
              assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
              assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
                     matrix_c->dist->my_rank);
              // Reserve a new m*n block; presumably its data is allocated later
              // (cf. dbm_shard_allocate_promised_blocks) - TODO confirm.
              blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
            }

            // Count flops. The 2LL factor widens the product to 64 bits before
            // m * n * k is formed, avoiding int overflow.
            // NOTE(review): blk_c may already have been promised above even
            // when task_flops == 0.
            const int64_t task_flops = 2LL * m * n * k;
            if (task_flops == 0) {
              continue;
            }
            flop_sum += task_flops;
            dbm_library_counter_increment(m, n, k);

            // Add block multiplication to batch.
            batch[ntasks].m = m;
            batch[ntasks].n = n;
            batch[ntasks].k = k;
            batch[ntasks].offset_a = blk_a->offset;
            batch[ntasks].offset_b = blk_b->offset;
            batch[ntasks].offset_c = blk_c->offset;
            ntasks++;

            // track MxN-shape covering an entire batch
            min_max(mnk_range[0], m);
            min_max(mnk_range[1], n);
            min_max(mnk_range[2], k);

            // Flush a full batch and reset the per-batch state.
            if (ntasks == DBM_MAX_BATCH_SIZE) {
              backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a,
                                    pack_b, ishard, shard_c, ctx);
              mnk_range[0][0] = mnk_range[1][0] = mnk_range[2][0] = INT_MAX;
              mnk_range[0][1] = mnk_range[1][1] = mnk_range[2][1] = 0;
              ntasks = 0;
            }
          }
        }
        // Flush the remaining tasks of this shard (ntasks may be 0 here).
        backend_process_batch(ntasks, batch, mnk_range, alpha, pack_a, pack_b,
                              ishard, shard_c, ctx);
      }
    }
  }
  *flop += flop_sum;
}
     355             : 
     356             : /*******************************************************************************
     357             :  * \brief Performs a multiplication of two dbm_matrix_t matrices.
     358             :  *        See dbm_matrix.h for details.
     359             :  * \author Ole Schuett
     360             :  ******************************************************************************/
     361      211795 : void dbm_multiply(const bool transa, const bool transb, const double alpha,
     362             :                   const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
     363             :                   const double beta, dbm_matrix_t *matrix_c,
     364             :                   const bool retain_sparsity, const double filter_eps,
     365             :                   int64_t *flop) {
     366             : 
     367      211795 :   assert(omp_get_num_threads() == 1);
     368             : 
     369             :   // Throughout the matrix multiplication code the "sum_index" and "free_index"
     370             :   // denote the summation (aka dummy) and free index from the Einstein notation.
     371      211795 :   const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
     372      211795 :   const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
     373      211795 :   const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
     374      211795 :   const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
     375             : 
     376             :   // Sanity check matrix dimensions.
     377      211795 :   assert(num_sum_index_a == num_sum_index_b);
     378      211795 :   assert(num_free_index_a == matrix_c->nrows);
     379      211795 :   assert(num_free_index_b == matrix_c->ncols);
     380             : 
     381             :   // Prepare matrix_c.
     382      211795 :   dbm_scale(matrix_c, beta);
     383             : 
     384             :   // Start uploading matrix_c to the GPU.
     385      211795 :   backend_context_t *ctx = backend_start(matrix_c);
     386             : 
     387             :   // Compute filter thresholds for each row.
     388      211795 :   float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
     389             : 
     390             :   // Redistribute matrix_a and matrix_b across MPI ranks.
     391      211795 :   dbm_comm_iterator_t *iter =
     392      211795 :       dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
     393             : 
     394             :   // Main loop.
     395      211795 :   *flop = 0;
     396      211795 :   dbm_pack_t *pack_a, *pack_b;
     397      441026 :   while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
     398      229231 :     backend_upload_packs(pack_a, pack_b, ctx);
     399      229231 :     multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
     400             :                    matrix_c, retain_sparsity, rows_max_eps, flop, ctx);
     401             :   }
     402             : 
     403             :   // Start downloading matrix_c from the GPU.
     404      211795 :   backend_download_results(ctx);
     405             : 
     406             :   // Wait for all other MPI ranks to complete, then release ressources.
     407      211795 :   dbm_comm_iterator_stop(iter);
     408      211795 :   free(rows_max_eps);
     409      211795 :   backend_stop(ctx);
     410             : 
     411             :   // Compute average flops per rank.
     412      211795 :   dbm_mpi_sum_int64(flop, 1, matrix_c->dist->comm);
     413      211795 :   *flop = (*flop + matrix_c->dist->nranks - 1) / matrix_c->dist->nranks;
     414             : 
     415             :   // Final filter pass.
     416      211795 :   dbm_filter(matrix_c, filter_eps);
     417      211795 : }
     418             : 
     419             : // EOF

Generated by: LCOV version 1.15