Line data Source code
1 : /*----------------------------------------------------------------------------*/ 2 : /* CP2K: A general program to perform molecular dynamics simulations */ 3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */ 4 : /* */ 5 : /* SPDX-License-Identifier: BSD-3-Clause */ 6 : /*----------------------------------------------------------------------------*/ 7 : 8 : #include <assert.h> 9 : #include <math.h> 10 : #include <omp.h> 11 : #include <stdbool.h> 12 : #include <stddef.h> 13 : #include <stdlib.h> 14 : #include <string.h> 15 : 16 : #include "dbm_distribution.h" 17 : #include "dbm_hyperparams.h" 18 : #include "dbm_internal.h" 19 : 20 : /******************************************************************************* 21 : * \brief Private routine for creating a new one dimensional distribution. 22 : * \author Ole Schuett 23 : ******************************************************************************/ 24 1632484 : static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, 25 : const int coords[length], const dbm_mpi_comm_t comm, 26 : const int nshards) { 27 1632484 : dist->comm = comm; 28 1632484 : dist->nshards = nshards; 29 1632484 : dist->my_rank = dbm_mpi_comm_rank(comm); 30 1632484 : dist->nranks = dbm_mpi_comm_size(comm); 31 1632484 : dist->length = length; 32 1632484 : dist->index2coord = malloc(length * sizeof(int)); 33 1632484 : assert(dist->index2coord != NULL); 34 1632484 : memcpy(dist->index2coord, coords, length * sizeof(int)); 35 : 36 : // Check that cart coordinates and ranks are equivalent. 37 1632484 : int cart_dims[1], cart_periods[1], cart_coords[1]; 38 1632484 : dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords); 39 1632484 : assert(dist->nranks == cart_dims[0]); 40 1632484 : assert(dist->my_rank == cart_coords[0]); 41 : 42 : // Count local rows/columns. 43 20571438 : for (int i = 0; i < length; i++) { 44 18938954 : assert(0 <= coords[i] && coords[i] < dist->nranks); 45 18938954 : if (coords[i] == dist->my_rank) { 46 18045128 : dist->nlocals++; 47 : } 48 : } 49 : 50 : // Store local rows/columns. 51 1632484 : dist->local_indicies = malloc(dist->nlocals * sizeof(int)); 52 1632484 : assert(dist->local_indicies != NULL); 53 : int j = 0; 54 20571438 : for (int i = 0; i < length; i++) { 55 18938954 : if (coords[i] == dist->my_rank) { 56 18045128 : dist->local_indicies[j++] = i; 57 : } 58 : } 59 1632484 : assert(j == dist->nlocals); 60 1632484 : } 61 : 62 : /******************************************************************************* 63 : * \brief Private routine for releasing a one dimensional distribution. 64 : * \author Ole Schuett 65 : ******************************************************************************/ 66 1632484 : static void dbm_dist_1d_free(dbm_dist_1d_t *dist) { 67 1632484 : free(dist->index2coord); 68 1632484 : free(dist->local_indicies); 69 1632484 : dbm_mpi_comm_free(&dist->comm); 70 1632484 : } 71 : 72 : /******************************************************************************* 73 : * \brief Private routine for finding the optimal number of shard rows. 74 : * \author Ole Schuett 75 : ******************************************************************************/ 76 816242 : static int find_best_nrow_shards(const int nshards, const int nrows, 77 : const int ncols) { 78 816242 : const double target = imax(nrows, 1) / (double)imax(ncols, 1); 79 816242 : int best_nrow_shards = nshards; 80 816242 : double best_error = fabs(log(target / (double)nshards)); 81 : 82 1632484 : for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) { 83 816242 : const int ncol_shards = nshards / nrow_shards; 84 816242 : if (nrow_shards * ncol_shards != nshards) { 85 0 : continue; // Not a factor of nshards. 86 : } 87 816242 : const double ratio = (double)nrow_shards / (double)ncol_shards; 88 816242 : const double error = fabs(log(target / ratio)); 89 816242 : if (error < best_error) { 90 0 : best_error = error; 91 0 : best_nrow_shards = nrow_shards; 92 : } 93 : } 94 816242 : return best_nrow_shards; 95 : } 96 : 97 : /******************************************************************************* 98 : * \brief Creates a new two dimensional distribution. 99 : * \author Ole Schuett 100 : ******************************************************************************/ 101 816242 : void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm, 102 : const int nrows, const int ncols, 103 : const int row_dist[nrows], 104 : const int col_dist[ncols]) { 105 816242 : assert(omp_get_num_threads() == 1); 106 816242 : dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t)); 107 816242 : dist->ref_count = 1; 108 : 109 816242 : dist->comm = dbm_mpi_comm_f2c(fortran_comm); 110 816242 : dist->my_rank = dbm_mpi_comm_rank(dist->comm); 111 816242 : dist->nranks = dbm_mpi_comm_size(dist->comm); 112 : 113 816242 : const int row_dim_remains[2] = {1, 0}; 114 816242 : const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains); 115 : 116 816242 : const int col_dim_remains[2] = {0, 1}; 117 816242 : const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains); 118 : 119 816242 : const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads(); 120 816242 : const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols); 121 816242 : const int ncol_shards = nshards / nrow_shards; 122 : 123 816242 : dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards); 124 816242 : dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards); 125 : 126 816242 : assert(*dist_out == NULL); 127 816242 : *dist_out = dist; 128 816242 : } 129 : 130 : /******************************************************************************* 131 : * \brief Increases the reference counter of the given distribution. 132 : * \author Ole Schuett 133 : ******************************************************************************/ 134 2379397 : void dbm_distribution_hold(dbm_distribution_t *dist) { 135 2379397 : assert(dist->ref_count > 0); 136 2379397 : dist->ref_count++; 137 2379397 : } 138 : 139 : /******************************************************************************* 140 : * \brief Decreases the reference counter of the given distribution. 141 : * \author Ole Schuett 142 : ******************************************************************************/ 143 3195639 : void dbm_distribution_release(dbm_distribution_t *dist) { 144 3195639 : assert(dist->ref_count > 0); 145 3195639 : dist->ref_count--; 146 3195639 : if (dist->ref_count == 0) { 147 816242 : dbm_dist_1d_free(&dist->rows); 148 816242 : dbm_dist_1d_free(&dist->cols); 149 816242 : free(dist); 150 : } 151 3195639 : } 152 : 153 : /******************************************************************************* 154 : * \brief Returns the rows of the given distribution. 155 : * \author Ole Schuett 156 : ******************************************************************************/ 157 327686 : void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows, 158 : const int **row_dist) { 159 327686 : assert(dist->ref_count > 0); 160 327686 : *nrows = dist->rows.length; 161 327686 : *row_dist = dist->rows.index2coord; 162 327686 : } 163 : 164 : /******************************************************************************* 165 : * \brief Returns the columns of the given distribution. 166 : * \author Ole Schuett 167 : ******************************************************************************/ 168 327686 : void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols, 169 : const int **col_dist) { 170 327686 : assert(dist->ref_count > 0); 171 327686 : *ncols = dist->cols.length; 172 327686 : *col_dist = dist->cols.index2coord; 173 327686 : } 174 : 175 : /******************************************************************************* 176 : * \brief Returns the MPI rank on which the given block should be stored. 177 : * \author Ole Schuett 178 : ******************************************************************************/ 179 92074402 : int dbm_distribution_stored_coords(const dbm_distribution_t *dist, 180 : const int row, const int col) { 181 92074402 : assert(dist->ref_count > 0); 182 92074402 : assert(0 <= row && row < dist->rows.length); 183 92074402 : assert(0 <= col && col < dist->cols.length); 184 92074402 : int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]}; 185 92074402 : return dbm_mpi_cart_rank(dist->comm, coords); 186 : } 187 : 188 : // EOF