Line data Source code
1 : /*----------------------------------------------------------------------------*/ 2 : /* CP2K: A general program to perform molecular dynamics simulations */ 3 : /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */ 4 : /* */ 5 : /* SPDX-License-Identifier: BSD-3-Clause */ 6 : /*----------------------------------------------------------------------------*/ 7 : 8 : #include <assert.h> 9 : #include <math.h> 10 : #include <omp.h> 11 : #include <stdbool.h> 12 : #include <stddef.h> 13 : #include <stdlib.h> 14 : #include <string.h> 15 : 16 : #include "dbm_distribution.h" 17 : #include "dbm_hyperparams.h" 18 : 19 : /******************************************************************************* 20 : * \brief Private routine for creating a new one dimensional distribution. 21 : * \author Ole Schuett 22 : ******************************************************************************/ 23 1617892 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, 24 : const int coords[length], const dbm_mpi_comm_t comm, 25 : const int nshards) { 26 1617892 : dist->comm = comm; 27 1617892 : dist->nshards = nshards; 28 1617892 : dist->my_rank = dbm_mpi_comm_rank(comm); 29 1617892 : dist->nranks = dbm_mpi_comm_size(comm); 30 1617892 : dist->length = length; 31 1617892 : dist->index2coord = malloc(length * sizeof(int)); 32 1617892 : assert(dist->index2coord != NULL); 33 1617892 : memcpy(dist->index2coord, coords, length * sizeof(int)); 34 : 35 : // Check that cart coordinates and ranks are equivalent. 36 1617892 : int cart_dims[1], cart_periods[1], cart_coords[1]; 37 1617892 : dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords); 38 1617892 : assert(dist->nranks == cart_dims[0]); 39 1617892 : assert(dist->my_rank == cart_coords[0]); 40 : 41 : // Count local rows/columns. 42 20385804 : for (int i = 0; i < length; i++) { 43 18767912 : assert(0 <= coords[i] && coords[i] < dist->nranks); 44 18767912 : if (coords[i] == dist->my_rank) { 45 17848369 : dist->nlocals++; 46 : } 47 : } 48 : 49 : // Store local rows/columns. 50 1617892 : dist->local_indicies = malloc(dist->nlocals * sizeof(int)); 51 1617892 : assert(dist->local_indicies != NULL); 52 : int j = 0; 53 20385804 : for (int i = 0; i < length; i++) { 54 18767912 : if (coords[i] == dist->my_rank) { 55 17848369 : dist->local_indicies[j++] = i; 56 : } 57 : } 58 1617892 : assert(j == dist->nlocals); 59 1617892 : } 60 : 61 : /******************************************************************************* 62 : * \brief Private routine for releasing a one dimensional distribution. 63 : * \author Ole Schuett 64 : ******************************************************************************/ 65 1617892 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) { 66 1617892 : free(dist->index2coord); 67 1617892 : free(dist->local_indicies); 68 1617892 : dbm_mpi_comm_free(&dist->comm); 69 1617892 : } 70 : 71 : /******************************************************************************* 72 : * \brief Returns the larger of two given integer (missing from the C standard) 73 : * \author Ole Schuett 74 : ******************************************************************************/ 75 808946 : static inline int imax(int x, int y) { return (x > y ? x : y); } 76 : 77 : /******************************************************************************* 78 : * \brief Private routine for finding the optimal number of shard rows. 79 : * \author Ole Schuett 80 : ******************************************************************************/ 81 808946 : static int find_best_nrow_shards(const int nshards, const int nrows, 82 : const int ncols) { 83 808946 : const double target = (double)imax(nrows, 1) / (double)imax(ncols, 1); 84 808946 : int best_nrow_shards = nshards; 85 808946 : double best_error = fabs(log(target / (double)nshards)); 86 : 87 1617892 : for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) { 88 808946 : const int ncol_shards = nshards / nrow_shards; 89 808946 : if (nrow_shards * ncol_shards != nshards) 90 0 : continue; // Not a factor of nshards. 91 808946 : const double ratio = (double)nrow_shards / (double)ncol_shards; 92 808946 : const double error = fabs(log(target / ratio)); 93 808946 : if (error < best_error) { 94 0 : best_error = error; 95 0 : best_nrow_shards = nrow_shards; 96 : } 97 : } 98 808946 : return best_nrow_shards; 99 : } 100 : 101 : /******************************************************************************* 102 : * \brief Creates a new two dimensional distribution. 103 : * \author Ole Schuett 104 : ******************************************************************************/ 105 808946 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm, 106 : const int nrows, const int ncols, 107 : const int row_dist[nrows], 108 : const int col_dist[ncols]) { 109 808946 : assert(omp_get_num_threads() == 1); 110 808946 : dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t)); 111 808946 : dist->ref_count = 1; 112 : 113 808946 : dist->comm = dbm_mpi_comm_f2c(fortran_comm); 114 808946 : dist->my_rank = dbm_mpi_comm_rank(dist->comm); 115 808946 : dist->nranks = dbm_mpi_comm_size(dist->comm); 116 : 117 808946 : const int row_dim_remains[2] = {1, 0}; 118 808946 : const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains); 119 : 120 808946 : const int col_dim_remains[2] = {0, 1}; 121 808946 : const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains); 122 : 123 808946 : const int nshards = SHARDS_PER_THREAD * omp_get_max_threads(); 124 808946 : const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols); 125 808946 : const int ncol_shards = nshards / nrow_shards; 126 : 127 808946 : dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards); 128 808946 : dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards); 129 : 130 808946 : assert(*dist_out == NULL); 131 808946 : *dist_out = dist; 132 808946 : } 133 : 134 : /******************************************************************************* 135 : * \brief Increases the reference counter of the given distribution. 136 : * \author Ole Schuett 137 : ******************************************************************************/ 138 2319033 : void dbm_distribution_hold(dbm_distribution_t *dist) { 139 2319033 : assert(dist->ref_count > 0); 140 2319033 : dist->ref_count++; 141 2319033 : } 142 : 143 : /******************************************************************************* 144 : * \brief Decreases the reference counter of the given distribution. 145 : * \author Ole Schuett 146 : ******************************************************************************/ 147 3127979 : void dbm_distribution_release(dbm_distribution_t *dist) { 148 3127979 : assert(dist->ref_count > 0); 149 3127979 : dist->ref_count--; 150 3127979 : if (dist->ref_count == 0) { 151 808946 : dbm_dist_1d_free(&dist->rows); 152 808946 : dbm_dist_1d_free(&dist->cols); 153 808946 : free(dist); 154 : } 155 3127979 : } 156 : 157 : /******************************************************************************* 158 : * \brief Returns the rows of the given distribution. 159 : * \author Ole Schuett 160 : ******************************************************************************/ 161 320714 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows, 162 : const int **row_dist) { 163 320714 : assert(dist->ref_count > 0); 164 320714 : *nrows = dist->rows.length; 165 320714 : *row_dist = dist->rows.index2coord; 166 320714 : } 167 : 168 : /******************************************************************************* 169 : * \brief Returns the columns of the given distribution. 170 : * \author Ole Schuett 171 : ******************************************************************************/ 172 320714 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols, 173 : const int **col_dist) { 174 320714 : assert(dist->ref_count > 0); 175 320714 : *ncols = dist->cols.length; 176 320714 : *col_dist = dist->cols.index2coord; 177 320714 : } 178 : 179 : /******************************************************************************* 180 : * \brief Returns the MPI rank on which the given block should be stored. 181 : * \author Ole Schuett 182 : ******************************************************************************/ 183 91788356 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist, 184 : const int row, const int col) { 185 91788356 : assert(dist->ref_count > 0); 186 91788356 : assert(0 <= row && row < dist->rows.length); 187 91788356 : assert(0 <= col && col < dist->cols.length); 188 91788356 : int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]}; 189 91788356 : return dbm_mpi_cart_rank(dist->comm, coords); 190 : } 191 : 192 : // EOF