LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply_cpu.c (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b4bd748) Lines: 48 48 100.0 %
Date: 2025-03-09 07:56:22 Functions: 2 2 100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <assert.h>
       9             : #include <stddef.h>
      10             : #include <string.h>
      11             : 
      12             : #if defined(__LIBXSMM)
      13             : #include <libxsmm.h>
      14             : #if !defined(DBM_LIBXSMM_PREFETCH)
      15             : // #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_AL2_AHEAD
      16             : #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE
      17             : #endif
      18             : #if LIBXSMM_VERSION4(1, 17, 0, 3710) > LIBXSMM_VERSION_NUMBER
      19             : #define libxsmm_dispatch_gemm libxsmm_dispatch_gemm_v2
      20             : #endif
      21             : #endif
      22             : 
      23             : #include "dbm_hyperparams.h"
      24             : #include "dbm_multiply_cpu.h"
      25             : 
      26             : /*******************************************************************************
      27             :  * \brief Prototype for BLAS dgemm.
      28             :  * \author Ole Schuett
      29             :  ******************************************************************************/
      30             : void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
      31             :             const int *k, const double *alpha, const double *a, const int *lda,
      32             :             const double *b, const int *ldb, const double *beta, double *c,
      33             :             const int *ldc);
      34             : 
      35             : /*******************************************************************************
      36             :  * \brief Private convenient wrapper to hide Fortran nature of dgemm_.
      37             :  * \author Ole Schuett
      38             :  ******************************************************************************/
      39     3490357 : static inline void dbm_dgemm(const char transa, const char transb, const int m,
      40             :                              const int n, const int k, const double alpha,
      41             :                              const double *a, const int lda, const double *b,
      42             :                              const int ldb, const double beta, double *c,
      43             :                              const int ldc) {
      44             : 
      45     3490357 :   dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c,
      46             :          &ldc);
      47             : }
      48             : 
      49             : /*******************************************************************************
      50             :  * \brief Private hash function based on Szudzik's elegant pairing.
      51             :  *        Using unsigned int to return a positive number even after overflow.
      52             :  *        https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
      53             :  *        https://stackoverflow.com/a/13871379
      54             :  *        http://szudzik.com/ElegantPairing.pdf
      55             :  * \author Ole Schuett
      56             :  ******************************************************************************/
      57             : #if defined(__LIBXSMM)
      58    43280576 : static inline unsigned int hash(const dbm_task_t task) {
      59    43280576 :   const unsigned int m = task.m, n = task.n, k = task.k;
      60    43280576 :   const unsigned int mn = (m >= n) ? m * m + m + n : m + n * n;
      61    43280576 :   const unsigned int mnk = (mn >= k) ? mn * mn + mn + k : mn + k * k;
      62    43280576 :   return mnk;
      63             : }
      64             : #endif
      65             : 
      66             : /*******************************************************************************
      67             :  * \brief Internal routine for executing the tasks in given batch on the CPU.
      68             :  * \author Ole Schuett
      69             :  ******************************************************************************/
      70      229286 : void dbm_multiply_cpu_process_batch(const int ntasks, dbm_task_t batch[ntasks],
      71             :                                     const double alpha,
      72             :                                     const dbm_pack_t *pack_a,
      73             :                                     const dbm_pack_t *pack_b,
      74      229286 :                                     dbm_shard_t *shard_c) {
      75             : 
      76      229286 :   if (0 >= ntasks) { // nothing to do
      77       50962 :     return;
      78             :   }
      79      178324 :   dbm_shard_allocate_promised_blocks(shard_c);
      80             : 
      81             : #if defined(__LIBXSMM)
      82             : 
      83             :   // Sort tasks approximately by m,n,k via bucket sort.
      84      178324 :   int buckets[DBM_BATCH_NUM_BUCKETS] = {0};
      85    21818612 :   for (int itask = 0; itask < ntasks; ++itask) {
      86    21640288 :     const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
      87    21640288 :     ++buckets[i];
      88             :   }
      89   178324000 :   for (int i = 1; i < DBM_BATCH_NUM_BUCKETS; ++i) {
      90   178145676 :     buckets[i] += buckets[i - 1];
      91             :   }
      92      178324 :   assert(buckets[DBM_BATCH_NUM_BUCKETS - 1] == ntasks);
      93      178324 :   int batch_order[ntasks];
      94    21818612 :   for (int itask = 0; itask < ntasks; ++itask) {
      95    21640288 :     const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
      96    21640288 :     --buckets[i];
      97    21640288 :     batch_order[buckets[i]] = itask;
      98             :   }
      99             : 
     100             :   // Prepare arguments for libxsmm's kernel-dispatch.
     101      178324 :   const int flags = LIBXSMM_GEMM_FLAG_TRANS_B; // transa = "N", transb = "T"
     102      178324 :   const int prefetch = DBM_LIBXSMM_PREFETCH;
     103      178324 :   int kernel_m = 0, kernel_n = 0, kernel_k = 0;
     104      178324 :   dbm_task_t task_next = batch[batch_order[0]];
     105             : 
     106             : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
     107             :   double *data_a_next = NULL, *data_b_next = NULL, *data_c_next = NULL;
     108             : #endif
     109             : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
     110      178324 :   libxsmm_gemmfunction kernel_func = NULL;
     111             : #else
     112             :   libxsmm_dmmfunction kernel_func = NULL;
     113             :   const double beta = 1.0;
     114             : #endif
     115             : 
     116             :   // Loop over tasks.
     117    21818612 :   for (int itask = 0; itask < ntasks; ++itask) {
     118    21640288 :     const dbm_task_t task = task_next;
     119    21640288 :     task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
     120             : 
     121    21640288 :     if (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k) {
     122     1478019 :       if (LIBXSMM_SMM(task.m, task.n, task.m, 1 /*assume in-$, no RFO*/,
     123             :                       sizeof(double))) {
     124             : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
     125     1421397 :         const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(
     126             :             task.m, task.n, task.k, task.m /*lda*/, task.n /*ldb*/,
     127             :             task.m /*ldc*/, LIBXSMM_DATATYPE_F64 /*aprec*/,
     128             :             LIBXSMM_DATATYPE_F64 /*bprec*/, LIBXSMM_DATATYPE_F64 /*cprec*/,
     129             :             LIBXSMM_DATATYPE_F64 /*calcp*/);
     130     1421397 :         kernel_func =
     131             :             (LIBXSMM_FEQ(1.0, alpha)
     132     1147935 :                  ? libxsmm_dispatch_gemm(shape, (libxsmm_bitfield)flags,
     133             :                                          (libxsmm_bitfield)prefetch)
     134     1421397 :                  : NULL);
     135             : #else
     136             :         kernel_func = libxsmm_dmmdispatch(task.m, task.n, task.k, NULL /*lda*/,
     137             :                                           NULL /*ldb*/, NULL /*ldc*/, &alpha,
     138             :                                           &beta, &flags, &prefetch);
     139             : #endif
     140             :       } else {
     141             :         kernel_func = NULL;
     142             :       }
     143             :       kernel_m = task.m;
     144             :       kernel_n = task.n;
     145             :       kernel_k = task.k;
     146             :     }
     147             : 
     148             :     // gemm_param wants non-const data even for A and B
     149    21640288 :     double *const data_a = pack_a->data + task.offset_a;
     150    21640288 :     double *const data_b = pack_b->data + task.offset_b;
     151    21640288 :     double *const data_c = shard_c->data + task.offset_c;
     152             : 
     153    21640288 :     if (kernel_func != NULL) {
     154             : #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
     155    18149931 :       libxsmm_gemm_param gemm_param;
     156    18149931 :       gemm_param.a.primary = data_a;
     157    18149931 :       gemm_param.b.primary = data_b;
     158    18149931 :       gemm_param.c.primary = data_c;
     159             : #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
     160             :       gemm_param.a.quaternary = pack_a->data + task_next.offset_a;
     161             :       gemm_param.b.quaternary = pack_b->data + task_next.offset_b;
     162             :       gemm_param.c.quaternary = shard_c->data + task_next.offset_c;
     163             : #endif
     164    18149931 :       kernel_func(&gemm_param);
     165             : #elif (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
     166             :       kernel_func(data_a, data_b, data_c, pack_a->data + task_next.offset_a,
     167             :                   pack_b->data + task_next.offset_b,
     168             :                   shard_c->data + task_next.offset_c);
     169             : #else
     170             :       kernel_func(data_a, data_b, data_c);
     171             : #endif
     172             :     } else {
     173     3490357 :       dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
     174             :                 task.n, 1.0, data_c, task.m);
     175             :     }
     176             :   }
     177             : #else
     178             :   // Fallback to BLAS when libxsmm is not available.
     179             :   for (int itask = 0; itask < ntasks; ++itask) {
     180             :     const dbm_task_t task = batch[itask];
     181             :     const double *data_a = &pack_a->data[task.offset_a];
     182             :     const double *data_b = &pack_b->data[task.offset_b];
     183             :     double *data_c = &shard_c->data[task.offset_c];
     184             :     dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
     185             :               task.n, 1.0, data_c, task.m);
     186             :   }
     187             : #endif
     188             : }
     189             : 
     190             : // EOF

Generated by: LCOV version 1.15