LCOV - code coverage report
Current view: top level - src/grid/common - grid_process_vab.h (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:69d170a) Lines: 95 95 100.0 %
Date: 2024-12-23 07:26:16 Functions: 10 10 100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2024 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <stdbool.h>
       9             : 
      10             : #if defined(__CUDACC__) || defined(__HIPCC__)
      11             : #define GRID_DEVICE __device__
      12             : #else
      13             : #define GRID_DEVICE
      14             : #endif
      15             : 
      16             : /*******************************************************************************
      17             :  * \brief Returns matrix element cab[idx(b)][idx(a)].
      18             :  *        This function has to be implemented by the importing compilation unit.
      19             :  *        A simple implementation is just: returns cab[idx(b) * n1 + idx(a)];
      20             :  * \author Ole Schuett
      21             :  ******************************************************************************/
      22             : GRID_DEVICE static inline double cab_get(const cab_store *cab, const orbital a,
      23             :                                          const orbital b);
      24             : 
      25             : /*******************************************************************************
      26             :  * \brief Returns i'th component of force on atom a for compute_tau=false.
      27             :  * \author Ole Schuett
      28             :  ******************************************************************************/
      29             : GRID_DEVICE static inline double
      30   457019442 : get_force_a_normal(const orbital a, const orbital b, const int i,
      31             :                    const double zeta, const cab_store *cab) {
      32   457019442 :   const double aip1 = cab_get(cab, up(i, a), b);
      33   457019442 :   const double aim1 = cab_get(cab, down(i, a), b);
      34   457019442 :   return 2.0 * zeta * aip1 - a.l[i] * aim1;
      35             : }
      36             : 
      37             : /*******************************************************************************
      38             :  * \brief Returns i'th component of force on atom a.
      39             :  * \author Ole Schuett
      40             :  ******************************************************************************/
      41             : GRID_DEVICE static inline double
      42   421071552 : get_force_a(const orbital a, const orbital b, const int i, const double zeta,
      43             :             const double zetb, const cab_store *cab, const bool compute_tau) {
      44   421071552 :   if (!compute_tau) {
      45   417803562 :     return get_force_a_normal(a, b, i, zeta, cab);
      46             :   } else {
      47             :     double force = 0.0;
      48    13071960 :     for (int k = 0; k < 3; k++) {
      49     9803970 :       force += 0.5 * a.l[k] * b.l[k] *
      50     9803970 :                get_force_a_normal(down(k, a), down(k, b), i, zeta, cab);
      51     9803970 :       force -= zeta * b.l[k] *
      52     9803970 :                get_force_a_normal(up(k, a), down(k, b), i, zeta, cab);
      53     9803970 :       force -= a.l[k] * zetb *
      54     9803970 :                get_force_a_normal(down(k, a), up(k, b), i, zeta, cab);
      55     9803970 :       force += 2.0 * zeta * zetb *
      56     9803970 :                get_force_a_normal(up(k, a), up(k, b), i, zeta, cab);
      57             :     }
      58             :     return force;
      59             :   }
      60             : }
      61             : 
      62             : /*******************************************************************************
      63             :  * \brief Returns i'th component of force on atom b for compute_tau=false.
      64             :  * \author Ole Schuett
      65             :  ******************************************************************************/
      66             : GRID_DEVICE static inline double
      67   456993378 : get_force_b_normal(const orbital a, const orbital b, const int i,
      68             :                    const double zetb, const double rab[3],
      69             :                    const cab_store *cab) {
      70   456993378 :   const double axpm0 = cab_get(cab, a, b);
      71   456993378 :   const double aip1 = cab_get(cab, up(i, a), b);
      72   456993378 :   const double bim1 = cab_get(cab, a, down(i, b));
      73   456993378 :   return 2.0 * zetb * (aip1 - rab[i] * axpm0) - b.l[i] * bim1;
      74             : }
      75             : 
      76             : /*******************************************************************************
      77             :  * \brief Returns i'th component of force on atom b.
      78             :  * \author Ole Schuett
      79             :  ******************************************************************************/
      80             : GRID_DEVICE static inline double
      81   421045488 : get_force_b(const orbital a, const orbital b, const int i, const double zeta,
      82             :             const double zetb, const double rab[3], const cab_store *cab,
      83             :             const bool compute_tau) {
      84   421045488 :   if (!compute_tau) {
      85   417777498 :     return get_force_b_normal(a, b, i, zetb, rab, cab);
      86             :   } else {
      87             :     double force = 0.0;
      88    13071960 :     for (int k = 0; k < 3; k++) {
      89     9803970 :       force += 0.5 * a.l[k] * b.l[k] *
      90     9803970 :                get_force_b_normal(down(k, a), down(k, b), i, zetb, rab, cab);
      91     9803970 :       force -= zeta * b.l[k] *
      92     9803970 :                get_force_b_normal(up(k, a), down(k, b), i, zetb, rab, cab);
      93     9803970 :       force -= a.l[k] * zetb *
      94     9803970 :                get_force_b_normal(down(k, a), up(k, b), i, zetb, rab, cab);
      95     9803970 :       force += 2.0 * zeta * zetb *
      96     9803970 :                get_force_b_normal(up(k, a), up(k, b), i, zetb, rab, cab);
      97             :     }
      98             :     return force;
      99             :   }
     100             : }
     101             : 
     102             : /*******************************************************************************
     103             :  * \brief Returns element i,j of virial on atom a for compute_tau=false.
     104             :  * \author Ole Schuett
     105             :  ******************************************************************************/
     106             : GRID_DEVICE static inline double
     107   263129859 : get_virial_a_normal(const orbital a, const orbital b, const int i, const int j,
     108             :                     const double zeta, const cab_store *cab) {
     109   263129859 :   return 2.0 * zeta * cab_get(cab, up(i, up(j, a)), b) -
     110   263129859 :          a.l[j] * cab_get(cab, up(i, down(j, a)), b);
     111             : }
     112             : 
     113             : /*******************************************************************************
     114             :  * \brief Returns element i,j of virial on atom a.
     115             :  * \author Ole Schuett
     116             :  ******************************************************************************/
     117             : GRID_DEVICE static inline double
     118   198313371 : get_virial_a(const orbital a, const orbital b, const int i, const int j,
     119             :              const double zeta, const double zetb, const cab_store *cab,
     120             :              const bool compute_tau) {
     121             : 
     122   198313371 :   if (!compute_tau) {
     123   192420963 :     return get_virial_a_normal(a, b, i, j, zeta, cab);
     124             :   } else {
     125             :     double virial = 0.0;
     126    23569632 :     for (int k = 0; k < 3; k++) {
     127    17677224 :       virial += 0.5 * a.l[k] * b.l[k] *
     128    17677224 :                 get_virial_a_normal(down(k, a), down(k, b), i, j, zeta, cab);
     129    17677224 :       virial -= zeta * b.l[k] *
     130    17677224 :                 get_virial_a_normal(up(k, a), down(k, b), i, j, zeta, cab);
     131    17677224 :       virial -= a.l[k] * zetb *
     132    17677224 :                 get_virial_a_normal(down(k, a), up(k, b), i, j, zeta, cab);
     133    17677224 :       virial += 2.0 * zeta * zetb *
     134    17677224 :                 get_virial_a_normal(up(k, a), up(k, b), i, j, zeta, cab);
     135             :     }
     136             :     return virial;
     137             :   }
     138             : }
     139             : 
     140             : /*******************************************************************************
     141             :  * \brief Returns element i,j of virial on atom b for compute_tau=false.
     142             :  * \author Ole Schuett
     143             :  ******************************************************************************/
     144             : GRID_DEVICE static inline double
     145   263126619 : get_virial_b_normal(const orbital a, const orbital b, const int i, const int j,
     146             :                     const double zetb, const double rab[3],
     147             :                     const cab_store *cab) {
     148             : 
     149   263126619 :   return 2.0 * zetb *
     150   263126619 :              (cab_get(cab, up(i, up(j, a)), b) -
     151   263126619 :               cab_get(cab, up(i, a), b) * rab[j] -
     152   263126619 :               cab_get(cab, up(j, a), b) * rab[i] +
     153   263126619 :               cab_get(cab, a, b) * rab[j] * rab[i]) -
     154   263126619 :          b.l[j] * cab_get(cab, a, up(i, down(j, b)));
     155             : }
     156             : 
     157             : /*******************************************************************************
     158             :  * \brief Returns element i,j of virial on atom b.
     159             :  * \author Ole Schuett
     160             :  ******************************************************************************/
     161             : GRID_DEVICE static inline double
     162   198310131 : get_virial_b(const orbital a, const orbital b, const int i, const int j,
     163             :              const double zeta, const double zetb, const double rab[3],
     164             :              const cab_store *cab, const bool compute_tau) {
     165             : 
     166   198310131 :   if (!compute_tau) {
     167   192417723 :     return get_virial_b_normal(a, b, i, j, zetb, rab, cab);
     168             :   } else {
     169             :     double virial = 0.0;
     170    23569632 :     for (int k = 0; k < 3; k++) {
     171    17677224 :       virial +=
     172    17677224 :           0.5 * a.l[k] * b.l[k] *
     173    17677224 :           get_virial_b_normal(down(k, a), down(k, b), i, j, zetb, rab, cab);
     174    17677224 :       virial -= zeta * b.l[k] *
     175    17677224 :                 get_virial_b_normal(up(k, a), down(k, b), i, j, zetb, rab, cab);
     176    17677224 :       virial -= a.l[k] * zetb *
     177    17677224 :                 get_virial_b_normal(down(k, a), up(k, b), i, j, zetb, rab, cab);
     178    17677224 :       virial += 2.0 * zeta * zetb *
     179    17677224 :                 get_virial_b_normal(up(k, a), up(k, b), i, j, zetb, rab, cab);
     180             :     }
     181             :     return virial;
     182             :   }
     183             : }
     184             : 
     185             : /*******************************************************************************
     186             :  * \brief Returns element i,j of hab matrix.
     187             :  * \author Ole Schuett
     188             :  ******************************************************************************/
     189   933016274 : GRID_DEVICE static inline double get_hab(const orbital a, const orbital b,
     190             :                                          const double zeta, const double zetb,
     191             :                                          const cab_store *cab,
     192             :                                          const bool compute_tau) {
     193   933016274 :   if (!compute_tau) {
     194   922316073 :     return cab_get(cab, a, b);
     195             :   } else {
     196             :     double hab = 0.0;
     197    42800804 :     for (int k = 0; k < 3; k++) {
     198    32100603 :       hab += 0.5 * a.l[k] * b.l[k] * cab_get(cab, down(k, a), down(k, b));
     199    32100603 :       hab -= zeta * b.l[k] * cab_get(cab, up(k, a), down(k, b));
     200    32100603 :       hab -= a.l[k] * zetb * cab_get(cab, down(k, a), up(k, b));
     201    32100603 :       hab += 2.0 * zeta * zetb * cab_get(cab, up(k, a), up(k, b));
     202             :     }
     203             :     return hab;
     204             :   }
     205             : }
     206             : 
     207             : /*******************************************************************************
     208             :  * \brief Differences in angular momentum.
     209             :  * \author Ole Schuett
     210             :  ******************************************************************************/
     211             : typedef struct {
     212             :   int la_max_diff;
     213             :   int la_min_diff;
     214             :   int lb_max_diff;
     215             :   int lb_min_diff;
     216             : } process_ldiffs;
     217             : 
     218             : /*******************************************************************************
     219             :  * \brief Returns difference in angular momentum range for given flags.
     220             :  * \author Ole Schuett
     221             :  ******************************************************************************/
     222    88078096 : static process_ldiffs process_get_ldiffs(bool calculate_forces,
     223             :                                          bool calculate_virial,
     224             :                                          bool compute_tau) {
     225    88078096 :   process_ldiffs ldiffs;
     226             : 
     227    88078096 :   ldiffs.la_max_diff = 0;
     228    88078096 :   ldiffs.lb_max_diff = 0;
     229    88078096 :   ldiffs.la_min_diff = 0;
     230    88078096 :   ldiffs.lb_min_diff = 0;
     231             : 
     232    88078096 :   if (calculate_forces || calculate_virial) {
     233    18043610 :     ldiffs.la_max_diff += 1; // for deriv. of gaussian, unimportant which one
     234    18043610 :     ldiffs.la_min_diff -= 1;
     235    18043610 :     ldiffs.lb_min_diff -= 1;
     236             :   }
     237             : 
     238    18043610 :   if (calculate_virial) {
     239     3016090 :     ldiffs.la_max_diff += 1;
     240     3016090 :     ldiffs.lb_max_diff += 1;
     241             :   }
     242             : 
     243    88078096 :   if (compute_tau) {
     244     1335552 :     ldiffs.la_max_diff += 1;
     245     1335552 :     ldiffs.lb_max_diff += 1;
     246     1335552 :     ldiffs.la_min_diff -= 1;
     247     1335552 :     ldiffs.lb_min_diff -= 1;
     248             :   }
     249             : 
     250    88078096 :   return ldiffs;
     251             : }
     252             : 
     253             : // EOF

Generated by: LCOV version 1.15