LCOV - code coverage report
Current view: top level - src/grid/dgemm - grid_dgemm_utils.h (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 4 4 100.0 %
Date: 2024-11-21 06:45:46 Functions: 1 1 100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2024 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #ifndef GRID_DGEMM_UTILS_H
       9             : #define GRID_DGEMM_UTILS_H
      10             : 
      11             : #include <stdbool.h>
      12             : #include <stdio.h>
      13             : #include <string.h>
      14             : 
      15             : #if defined(__MKL)
      16             : #include <mkl.h>
      17             : #include <mkl_cblas.h>
      18             : #endif
      19             : 
      20             : #if defined(__LIBXSMM)
      21             : #include <libxsmm.h>
      22             : #endif
      23             : 
      24             : #include "../common/grid_common.h"
      25             : #include "grid_dgemm_private_header.h"
      26             : #include "grid_dgemm_tensor_local.h"
      27             : 
      28             : /* inverse of the factorials */
      29             : static const double inv_fac[] = {1.0,
      30             :                                  1.0,
      31             :                                  0.5,
      32             :                                  0.166666666666666666666666666667,
      33             :                                  0.0416666666666666666666666666667,
      34             :                                  0.00833333333333333333333333333333,
      35             :                                  0.00138888888888888888888888888889,
      36             :                                  0.000198412698412698412698412698413,
      37             :                                  0.0000248015873015873015873015873016,
      38             :                                  2.7557319223985890652557319224e-6,
      39             :                                  2.7557319223985890652557319224e-7,
      40             :                                  2.50521083854417187750521083854e-8,
      41             :                                  2.08767569878680989792100903212e-9,
      42             :                                  1.60590438368216145993923771702e-10,
      43             :                                  1.14707455977297247138516979787e-11,
      44             :                                  7.64716373181981647590113198579e-13,
      45             :                                  4.77947733238738529743820749112e-14,
      46             :                                  2.81145725434552076319894558301e-15,
      47             :                                  1.56192069685862264622163643501e-16,
      48             :                                  8.22063524662432971695598123687e-18,
      49             :                                  4.11031762331216485847799061844e-19,
      50             :                                  1.95729410633912612308475743735e-20,
      51             :                                  8.8967913924505732867488974425e-22,
      52             :                                  3.86817017063068403771691193152e-23,
      53             :                                  1.6117375710961183490487133048e-24,
      54             :                                  6.4469502843844733961948532192e-26,
      55             :                                  2.47959626322479746007494354585e-27,
      56             :                                  9.18368986379554614842571683647e-29,
      57             :                                  3.27988923706983791015204172731e-30,
      58             :                                  1.13099628864477169315587645769e-31,
      59             :                                  3.76998762881590564385292152565e-33};
      60             : 
      61             : inline int coset_without_offset(int lx, int ly, int lz) {
      62             :   const int l = lx + ly + lz;
      63             :   if (l == 0) {
      64             :     return 0;
      65             :   } else {
      66             :     return ((l - lx) * (l - lx + 1)) / 2 + lz;
      67             :   }
      68             : }
      69             : 
      70             : typedef struct dgemm_params_ {
      71             :   char storage;
      72             :   char op1;
      73             :   char op2;
      74             :   double alpha;
      75             :   double beta;
      76             :   double *a, *b, *c;
      77             :   int m, n, k, lda, ldb, ldc;
      78             :   int x, y, z;
      79             :   int x1, y1, z1;
      80             :   bool use_libxsmm;
      81             : #if defined(__LIBXSMM)
      82             :   libxsmm_dmmfunction kernel;
      83             :   int prefetch;
      84             :   int flags;
      85             : #endif
      86             : } dgemm_params;
      87             : 
      88             : extern void dgemm_simplified(dgemm_params *const m);
      89             : extern void batched_dgemm_simplified(dgemm_params *const m,
      90             :                                      const int batch_size);
      91             : 
      92             : /*******************************************************************************
      93             :  * \brief Prototype for BLAS dgemm.
      94             :  * \author Ole Schuett
      95             :  ******************************************************************************/
      96             : void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
      97             :             const int *k, const double *alpha, const double *a, const int *lda,
      98             :             const double *b, const int *ldb, const double *beta, double *c,
      99             :             const int *ldc);
     100             : 
     101             : extern void extract_sub_grid(const int *lower_corner, const int *upper_corner,
     102             :                              const int *position, const tensor *const grid,
     103             :                              tensor *const subgrid);
     104             : extern void add_sub_grid(const int *lower_corner, const int *upper_corner,
     105             :                          const int *position, const tensor *subgrid,
     106             :                          tensor *grid);
     107             : extern void return_cube_position(const int *lb_grid, const int *cube_center,
     108             :                                  const int *lower_boundaries_cube,
     109             :                                  const int *period, int *const position);
     110             : 
     111             : extern void verify_orthogonality(const double dh[3][3], bool orthogonal[3]);
     112             : 
     113             : extern int compute_cube_properties(const bool ortho, const double radius,
     114             :                                    const double dh[3][3],
     115             :                                    const double dh_inv[3][3], const double *rp,
     116             :                                    double *disr_radius, double *roffset,
     117             :                                    int *cubecenter, int *lb_cube, int *ub_cube,
     118             :                                    int *cube_size);
     119             : 
     120             : inline int return_offset_l(const int l) {
     121             :   static const int offset_[] = {1,   4,   7,   11,  16,  22,  29,
     122             :                                 37,  46,  56,  67,  79,  92,  106,
     123             :                                 121, 137, 154, 172, 191, 211, 232};
     124             :   return offset_[l];
     125             : }
     126             : 
     127             : inline int return_linear_index_from_exponents(const int alpha, const int beta,
     128             :                                               const int gamma) {
     129             :   const int l = alpha + beta + gamma;
     130             :   return return_offset_l(l) + (l - alpha) * (l - alpha + 1) / 2 + gamma;
     131             : }
     132             : 
     133      183246 : static inline void *grid_allocate_scratch(size_t size) {
     134             : #ifdef __LIBXSMM
     135      183246 :   return libxsmm_aligned_scratch(size, 0 /*auto-alignment*/);
     136             : #else
     137             :   return malloc(size);
     138             : #endif
     139             : }
     140             : 
     141      183246 : static inline void grid_free_scratch(void *ptr) {
     142             : #ifdef __LIBXSMM
     143       91702 :   libxsmm_free(ptr);
     144             : #else
     145             :   free(ptr);
     146             : #endif
     147             : }
     148             : 
     149             : /* even openblas and lapack has cblas versions of lapack and blas. */
     150             : #ifndef __MKL
     151             : enum CBLAS_LAYOUT { CblasRowMajor = 101, CblasColMajor = 102 };
     152             : enum CBLAS_TRANSPOSE {
     153             :   CblasNoTrans = 111,
     154             :   CblasTrans = 112,
     155             :   CblasConjTrans = 113
     156             : };
     157             : enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 };
     158             : enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 };
     159             : enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 };
     160             : 
     161             : typedef enum CBLAS_LAYOUT CBLAS_LAYOUT;
     162             : typedef enum CBLAS_TRANSPOSE CBLAS_TRANSPOSE;
     163             : typedef enum CBLAS_UPLO CBLAS_UPLO;
     164             : typedef enum CBLAS_DIAG CBLAS_DIAG;
     165             : 
     166             : double cblas_ddot(const int N, const double *X, const int incX, const double *Y,
     167             :                   const int incY);
     168             : 
     169             : void cblas_dger(const CBLAS_LAYOUT Layout, const int M, const int N,
     170             :                 const double alpha, const double *X, const int incX,
     171             :                 const double *Y, const int incY, double *A, const int lda);
     172             : 
     173             : void cblas_daxpy(const int N, const double alpha, const double *X,
     174             :                  const int incX, double *Y, const int incY);
     175             : 
     176             : void cblas_dgemv(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA,
     177             :                  const int M, const int N, const double alpha, const double *A,
     178             :                  const int lda, const double *X, const int incX,
     179             :                  const double beta, double *Y, const int incY);
     180             : 
     181             : #endif
     182             : 
     183             : extern void compute_interval(const int *const map, const int full_size,
     184             :                              const int size, const int cube_size, const int x1,
     185             :                              int *x, int *const lower_corner,
     186             :                              int *const upper_corner, Interval window);
     187             : #endif

Generated by: LCOV version 1.15