LCOV - code coverage report
Current view: top level - src/dbm - dbm_distribution.c (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b4bd748) Lines: 87 90 96.7 %
Date: 2025-03-09 07:56:22 Functions: 9 9 100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <assert.h>
       9             : #include <math.h>
      10             : #include <omp.h>
      11             : #include <stdbool.h>
      12             : #include <stddef.h>
      13             : #include <stdlib.h>
      14             : #include <string.h>
      15             : 
      16             : #include "dbm_distribution.h"
      17             : #include "dbm_hyperparams.h"
      18             : #include "dbm_internal.h"
      19             : 
      20             : /*******************************************************************************
      21             :  * \brief Private routine for creating a new one dimensional distribution.
      22             :  * \author Ole Schuett
      23             :  ******************************************************************************/
      24     1632484 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
      25             :                             const int coords[length], const dbm_mpi_comm_t comm,
      26             :                             const int nshards) {
      27     1632484 :   dist->comm = comm;
      28     1632484 :   dist->nshards = nshards;
      29     1632484 :   dist->my_rank = dbm_mpi_comm_rank(comm);
      30     1632484 :   dist->nranks = dbm_mpi_comm_size(comm);
      31     1632484 :   dist->length = length;
      32     1632484 :   dist->index2coord = malloc(length * sizeof(int));
      33     1632484 :   assert(dist->index2coord != NULL);
      34     1632484 :   memcpy(dist->index2coord, coords, length * sizeof(int));
      35             : 
      36             :   // Check that cart coordinates and ranks are equivalent.
      37     1632484 :   int cart_dims[1], cart_periods[1], cart_coords[1];
      38     1632484 :   dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
      39     1632484 :   assert(dist->nranks == cart_dims[0]);
      40     1632484 :   assert(dist->my_rank == cart_coords[0]);
      41             : 
      42             :   // Count local rows/columns.
      43    20571438 :   for (int i = 0; i < length; i++) {
      44    18938954 :     assert(0 <= coords[i] && coords[i] < dist->nranks);
      45    18938954 :     if (coords[i] == dist->my_rank) {
      46    18045128 :       dist->nlocals++;
      47             :     }
      48             :   }
      49             : 
      50             :   // Store local rows/columns.
      51     1632484 :   dist->local_indicies = malloc(dist->nlocals * sizeof(int));
      52     1632484 :   assert(dist->local_indicies != NULL);
      53             :   int j = 0;
      54    20571438 :   for (int i = 0; i < length; i++) {
      55    18938954 :     if (coords[i] == dist->my_rank) {
      56    18045128 :       dist->local_indicies[j++] = i;
      57             :     }
      58             :   }
      59     1632484 :   assert(j == dist->nlocals);
      60     1632484 : }
      61             : 
      62             : /*******************************************************************************
      63             :  * \brief Private routine for releasing a one dimensional distribution.
      64             :  * \author Ole Schuett
      65             :  ******************************************************************************/
      66     1632484 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
      67     1632484 :   free(dist->index2coord);
      68     1632484 :   free(dist->local_indicies);
      69     1632484 :   dbm_mpi_comm_free(&dist->comm);
      70     1632484 : }
      71             : 
      72             : /*******************************************************************************
      73             :  * \brief Private routine for finding the optimal number of shard rows.
      74             :  * \author Ole Schuett
      75             :  ******************************************************************************/
      76      816242 : static int find_best_nrow_shards(const int nshards, const int nrows,
      77             :                                  const int ncols) {
      78      816242 :   const double target = imax(nrows, 1) / (double)imax(ncols, 1);
      79      816242 :   int best_nrow_shards = nshards;
      80      816242 :   double best_error = fabs(log(target / (double)nshards));
      81             : 
      82     1632484 :   for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) {
      83      816242 :     const int ncol_shards = nshards / nrow_shards;
      84      816242 :     if (nrow_shards * ncol_shards != nshards) {
      85           0 :       continue; // Not a factor of nshards.
      86             :     }
      87      816242 :     const double ratio = (double)nrow_shards / (double)ncol_shards;
      88      816242 :     const double error = fabs(log(target / ratio));
      89      816242 :     if (error < best_error) {
      90           0 :       best_error = error;
      91           0 :       best_nrow_shards = nrow_shards;
      92             :     }
      93             :   }
      94      816242 :   return best_nrow_shards;
      95             : }
      96             : 
      97             : /*******************************************************************************
      98             :  * \brief Creates a new two dimensional distribution.
      99             :  * \author Ole Schuett
     100             :  ******************************************************************************/
     101      816242 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
     102             :                           const int nrows, const int ncols,
     103             :                           const int row_dist[nrows],
     104             :                           const int col_dist[ncols]) {
     105      816242 :   assert(omp_get_num_threads() == 1);
     106      816242 :   dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
     107      816242 :   dist->ref_count = 1;
     108             : 
     109      816242 :   dist->comm = dbm_mpi_comm_f2c(fortran_comm);
     110      816242 :   dist->my_rank = dbm_mpi_comm_rank(dist->comm);
     111      816242 :   dist->nranks = dbm_mpi_comm_size(dist->comm);
     112             : 
     113      816242 :   const int row_dim_remains[2] = {1, 0};
     114      816242 :   const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains);
     115             : 
     116      816242 :   const int col_dim_remains[2] = {0, 1};
     117      816242 :   const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);
     118             : 
     119      816242 :   const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads();
     120      816242 :   const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
     121      816242 :   const int ncol_shards = nshards / nrow_shards;
     122             : 
     123      816242 :   dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
     124      816242 :   dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
     125             : 
     126      816242 :   assert(*dist_out == NULL);
     127      816242 :   *dist_out = dist;
     128      816242 : }
     129             : 
     130             : /*******************************************************************************
     131             :  * \brief Increases the reference counter of the given distribution.
     132             :  * \author Ole Schuett
     133             :  ******************************************************************************/
     134     2379397 : void dbm_distribution_hold(dbm_distribution_t *dist) {
     135     2379397 :   assert(dist->ref_count > 0);
     136     2379397 :   dist->ref_count++;
     137     2379397 : }
     138             : 
     139             : /*******************************************************************************
     140             :  * \brief Decreases the reference counter of the given distribution.
     141             :  * \author Ole Schuett
     142             :  ******************************************************************************/
     143     3195639 : void dbm_distribution_release(dbm_distribution_t *dist) {
     144     3195639 :   assert(dist->ref_count > 0);
     145     3195639 :   dist->ref_count--;
     146     3195639 :   if (dist->ref_count == 0) {
     147      816242 :     dbm_dist_1d_free(&dist->rows);
     148      816242 :     dbm_dist_1d_free(&dist->cols);
     149      816242 :     free(dist);
     150             :   }
     151     3195639 : }
     152             : 
     153             : /*******************************************************************************
     154             :  * \brief Returns the rows of the given distribution.
     155             :  * \author Ole Schuett
     156             :  ******************************************************************************/
     157      327686 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
     158             :                                const int **row_dist) {
     159      327686 :   assert(dist->ref_count > 0);
     160      327686 :   *nrows = dist->rows.length;
     161      327686 :   *row_dist = dist->rows.index2coord;
     162      327686 : }
     163             : 
     164             : /*******************************************************************************
     165             :  * \brief Returns the columns of the given distribution.
     166             :  * \author Ole Schuett
     167             :  ******************************************************************************/
     168      327686 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
     169             :                                const int **col_dist) {
     170      327686 :   assert(dist->ref_count > 0);
     171      327686 :   *ncols = dist->cols.length;
     172      327686 :   *col_dist = dist->cols.index2coord;
     173      327686 : }
     174             : 
     175             : /*******************************************************************************
     176             :  * \brief Returns the MPI rank on which the given block should be stored.
     177             :  * \author Ole Schuett
     178             :  ******************************************************************************/
     179    92074402 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist,
     180             :                                    const int row, const int col) {
     181    92074402 :   assert(dist->ref_count > 0);
     182    92074402 :   assert(0 <= row && row < dist->rows.length);
     183    92074402 :   assert(0 <= col && col < dist->cols.length);
     184    92074402 :   int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
     185    92074402 :   return dbm_mpi_cart_rank(dist->comm, coords);
     186             : }
     187             : 
     188             : // EOF

Generated by: LCOV version 1.15