LCOV - code coverage report
Current view: top level - src/dbm - dbm_distribution.c (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:262480d) Lines: 88 91 96.7 %
Date: 2024-11-22 07:00:40 Functions: 9 9 100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2024 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <assert.h>
       9             : #include <math.h>
      10             : #include <omp.h>
      11             : #include <stdbool.h>
      12             : #include <stddef.h>
      13             : #include <stdlib.h>
      14             : #include <string.h>
      15             : 
      16             : #include "dbm_distribution.h"
      17             : #include "dbm_hyperparams.h"
      18             : 
      19             : /*******************************************************************************
      20             :  * \brief Private routine for creating a new one dimensional distribution.
      21             :  * \author Ole Schuett
      22             :  ******************************************************************************/
      23     1617892 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
      24             :                             const int coords[length], const dbm_mpi_comm_t comm,
      25             :                             const int nshards) {
      26     1617892 :   dist->comm = comm;
      27     1617892 :   dist->nshards = nshards;
      28     1617892 :   dist->my_rank = dbm_mpi_comm_rank(comm);
      29     1617892 :   dist->nranks = dbm_mpi_comm_size(comm);
      30     1617892 :   dist->length = length;
      31     1617892 :   dist->index2coord = malloc(length * sizeof(int));
      32     1617892 :   assert(dist->index2coord != NULL);
      33     1617892 :   memcpy(dist->index2coord, coords, length * sizeof(int));
      34             : 
      35             :   // Check that cart coordinates and ranks are equivalent.
      36     1617892 :   int cart_dims[1], cart_periods[1], cart_coords[1];
      37     1617892 :   dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
      38     1617892 :   assert(dist->nranks == cart_dims[0]);
      39     1617892 :   assert(dist->my_rank == cart_coords[0]);
      40             : 
      41             :   // Count local rows/columns.
      42    20385804 :   for (int i = 0; i < length; i++) {
      43    18767912 :     assert(0 <= coords[i] && coords[i] < dist->nranks);
      44    18767912 :     if (coords[i] == dist->my_rank) {
      45    17848369 :       dist->nlocals++;
      46             :     }
      47             :   }
      48             : 
      49             :   // Store local rows/columns.
      50     1617892 :   dist->local_indicies = malloc(dist->nlocals * sizeof(int));
      51     1617892 :   assert(dist->local_indicies != NULL);
      52             :   int j = 0;
      53    20385804 :   for (int i = 0; i < length; i++) {
      54    18767912 :     if (coords[i] == dist->my_rank) {
      55    17848369 :       dist->local_indicies[j++] = i;
      56             :     }
      57             :   }
      58     1617892 :   assert(j == dist->nlocals);
      59     1617892 : }
      60             : 
      61             : /*******************************************************************************
      62             :  * \brief Private routine for releasing a one dimensional distribution.
      63             :  * \author Ole Schuett
      64             :  ******************************************************************************/
      65     1617892 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
      66     1617892 :   free(dist->index2coord);
      67     1617892 :   free(dist->local_indicies);
      68     1617892 :   dbm_mpi_comm_free(&dist->comm);
      69     1617892 : }
      70             : 
      71             : /*******************************************************************************
      72             :  * \brief Returns the larger of two given integer (missing from the C standard)
      73             :  * \author Ole Schuett
      74             :  ******************************************************************************/
      75      808946 : static inline int imax(int x, int y) { return (x > y ? x : y); }
      76             : 
      77             : /*******************************************************************************
      78             :  * \brief Private routine for finding the optimal number of shard rows.
      79             :  * \author Ole Schuett
      80             :  ******************************************************************************/
      81      808946 : static int find_best_nrow_shards(const int nshards, const int nrows,
      82             :                                  const int ncols) {
      83      808946 :   const double target = (double)imax(nrows, 1) / (double)imax(ncols, 1);
      84      808946 :   int best_nrow_shards = nshards;
      85      808946 :   double best_error = fabs(log(target / (double)nshards));
      86             : 
      87     1617892 :   for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) {
      88      808946 :     const int ncol_shards = nshards / nrow_shards;
      89      808946 :     if (nrow_shards * ncol_shards != nshards)
      90           0 :       continue; // Not a factor of nshards.
      91      808946 :     const double ratio = (double)nrow_shards / (double)ncol_shards;
      92      808946 :     const double error = fabs(log(target / ratio));
      93      808946 :     if (error < best_error) {
      94           0 :       best_error = error;
      95           0 :       best_nrow_shards = nrow_shards;
      96             :     }
      97             :   }
      98      808946 :   return best_nrow_shards;
      99             : }
     100             : 
     101             : /*******************************************************************************
     102             :  * \brief Creates a new two dimensional distribution.
     103             :  * \author Ole Schuett
     104             :  ******************************************************************************/
     105      808946 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
     106             :                           const int nrows, const int ncols,
     107             :                           const int row_dist[nrows],
     108             :                           const int col_dist[ncols]) {
     109      808946 :   assert(omp_get_num_threads() == 1);
     110      808946 :   dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
     111      808946 :   dist->ref_count = 1;
     112             : 
     113      808946 :   dist->comm = dbm_mpi_comm_f2c(fortran_comm);
     114      808946 :   dist->my_rank = dbm_mpi_comm_rank(dist->comm);
     115      808946 :   dist->nranks = dbm_mpi_comm_size(dist->comm);
     116             : 
     117      808946 :   const int row_dim_remains[2] = {1, 0};
     118      808946 :   const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains);
     119             : 
     120      808946 :   const int col_dim_remains[2] = {0, 1};
     121      808946 :   const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);
     122             : 
     123      808946 :   const int nshards = SHARDS_PER_THREAD * omp_get_max_threads();
     124      808946 :   const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
     125      808946 :   const int ncol_shards = nshards / nrow_shards;
     126             : 
     127      808946 :   dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
     128      808946 :   dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
     129             : 
     130      808946 :   assert(*dist_out == NULL);
     131      808946 :   *dist_out = dist;
     132      808946 : }
     133             : 
     134             : /*******************************************************************************
     135             :  * \brief Increases the reference counter of the given distribution.
     136             :  * \author Ole Schuett
     137             :  ******************************************************************************/
     138     2319033 : void dbm_distribution_hold(dbm_distribution_t *dist) {
     139     2319033 :   assert(dist->ref_count > 0);
     140     2319033 :   dist->ref_count++;
     141     2319033 : }
     142             : 
     143             : /*******************************************************************************
     144             :  * \brief Decreases the reference counter of the given distribution.
     145             :  * \author Ole Schuett
     146             :  ******************************************************************************/
     147     3127979 : void dbm_distribution_release(dbm_distribution_t *dist) {
     148     3127979 :   assert(dist->ref_count > 0);
     149     3127979 :   dist->ref_count--;
     150     3127979 :   if (dist->ref_count == 0) {
     151      808946 :     dbm_dist_1d_free(&dist->rows);
     152      808946 :     dbm_dist_1d_free(&dist->cols);
     153      808946 :     free(dist);
     154             :   }
     155     3127979 : }
     156             : 
     157             : /*******************************************************************************
     158             :  * \brief Returns the rows of the given distribution.
     159             :  * \author Ole Schuett
     160             :  ******************************************************************************/
     161      320714 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
     162             :                                const int **row_dist) {
     163      320714 :   assert(dist->ref_count > 0);
     164      320714 :   *nrows = dist->rows.length;
     165      320714 :   *row_dist = dist->rows.index2coord;
     166      320714 : }
     167             : 
     168             : /*******************************************************************************
     169             :  * \brief Returns the columns of the given distribution.
     170             :  * \author Ole Schuett
     171             :  ******************************************************************************/
     172      320714 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
     173             :                                const int **col_dist) {
     174      320714 :   assert(dist->ref_count > 0);
     175      320714 :   *ncols = dist->cols.length;
     176      320714 :   *col_dist = dist->cols.index2coord;
     177      320714 : }
     178             : 
     179             : /*******************************************************************************
     180             :  * \brief Returns the MPI rank on which the given block should be stored.
     181             :  * \author Ole Schuett
     182             :  ******************************************************************************/
     183    91788356 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist,
     184             :                                    const int row, const int col) {
     185    91788356 :   assert(dist->ref_count > 0);
     186    91788356 :   assert(0 <= row && row < dist->rows.length);
     187    91788356 :   assert(0 <= col && col < dist->cols.length);
     188    91788356 :   int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
     189    91788356 :   return dbm_mpi_cart_rank(dist->comm, coords);
     190             : }
     191             : 
     192             : // EOF

Generated by: LCOV version 1.15