LCOV - code coverage report
Current view: top level - src/dbm - dbm_multiply_comm.c (source / functions)
Test: CP2K Regtests (git:2fce0f8)
Date: 2024-12-21 06:28:57
Coverage:            Hit    Total    Coverage
  Lines:             169      169     100.0 %
  Functions:          13       13     100.0 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2024 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include "dbm_multiply_comm.h"
       9             : 
      10             : #include <assert.h>
      11             : #include <stdlib.h>
      12             : #include <string.h>
      13             : 
      14             : #include "dbm_hyperparams.h"
      15             : #include "dbm_mempool.h"
      16             : #include "dbm_mpi.h"
      17             : 
      18             : /*******************************************************************************
       19             :  * \brief Returns the larger of two given integers (missing from the C standard)
      20             :  * \author Ole Schuett
      21             :  ******************************************************************************/
      22      851452 : static inline int imax(int x, int y) { return (x > y ? x : y); }
      23             : 
      24             : /*******************************************************************************
      25             :  * \brief Private routine for computing greatest common divisor of two numbers.
      26             :  * \author Ole Schuett
      27             :  ******************************************************************************/
      28      425374 : static int gcd(const int a, const int b) {
      29      425374 :   if (a == 0)
      30             :     return b;
      31      223411 :   return gcd(b % a, a); // Euclid's algorithm.
      32             : }
      33             : 
      34             : /*******************************************************************************
      35             :  * \brief Private routine for computing least common multiple of two numbers.
      36             :  * \author Ole Schuett
      37             :  ******************************************************************************/
      38      201963 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
      39             : 
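A brief worked trace (editor's addition, not part of the covered source) of how these two helpers combine for a pair of small cart dimensions:

  // gcd(4, 6) -> gcd(6 % 4, 4) = gcd(2, 4) -> gcd(4 % 2, 2) = gcd(0, 2) = 2
  // lcm(4, 6) = (4 * 6) / gcd(4, 6) = 24 / 2 = 12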
      40             : /*******************************************************************************
      41             :  * \brief Private routine for computing the sum of the given integers.
      42             :  * \author Ole Schuett
      43             :  ******************************************************************************/
      44      851452 : static inline int isum(const int n, const int input[n]) {
      45      851452 :   int output = 0;
      46     1833704 :   for (int i = 0; i < n; i++) {
      47      982252 :     output += input[i];
      48             :   }
      49      851452 :   return output;
      50             : }
      51             : 
      52             : /*******************************************************************************
      53             :  * \brief Private routine for computing the cumulative sums of given numbers.
      54             :  * \author Ole Schuett
      55             :  ******************************************************************************/
      56     2128630 : static inline void icumsum(const int n, const int input[n], int output[n]) {
      57     2128630 :   output[0] = 0;
      58     2390230 :   for (int i = 1; i < n; i++) {
      59      261600 :     output[i] = output[i - 1] + input[i - 1];
      60             :   }
      61     2128630 : }
      62             : 
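The exclusive prefix sum computed here is what turns per-rank counts into the displacement arrays handed to the all-to-all calls in pack_matrix below. A minimal sketch (editor's addition, not part of the covered source) exercising the helper in isolation, together with isum from above:

  static void icumsum_demo(void) {
    // counts = {3, 0, 2, 5}  ->  displ = {0, 3, 3, 5}
    const int counts[4] = {3, 0, 2, 5};
    int displ[4];
    icumsum(4, counts, displ);
    assert(displ[3] + counts[3] == isum(4, counts)); // 10 elements in total
  }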
      63             : /*******************************************************************************
       64             :  * \brief Private struct used for planning during pack_matrix.
      65             :  * \author Ole Schuett
      66             :  ******************************************************************************/
      67             : typedef struct {
      68             :   const dbm_block_t *blk; // source block
      69             :   int rank;               // target mpi rank
      70             :   int row_size;
      71             :   int col_size;
      72             : } plan_t;
      73             : 
      74             : /*******************************************************************************
       75             :  * \brief Private routine for planning packs.
      76             :  * \author Ole Schuett
      77             :  ******************************************************************************/
      78      403926 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
      79             :                               const dbm_matrix_t *matrix,
      80             :                               const dbm_mpi_comm_t comm,
      81             :                               const dbm_dist_1d_t *dist_indices,
      82             :                               const dbm_dist_1d_t *dist_ticks, const int nticks,
      83             :                               const int npacks, plan_t *plans_per_pack[npacks],
      84             :                               int nblks_per_pack[npacks],
      85             :                               int ndata_per_pack[npacks]) {
      86             : 
      87      403926 :   memset(nblks_per_pack, 0, npacks * sizeof(int));
      88      403926 :   memset(ndata_per_pack, 0, npacks * sizeof(int));
      89             : 
      90      403926 : #pragma omp parallel
      91             :   {
       92             :     // 1st pass: Compute the number of blocks that will be sent in each pack.
      93             :     int nblks_mythread[npacks];
      94             :     memset(nblks_mythread, 0, npacks * sizeof(int));
      95             : #pragma omp for schedule(static)
      96             :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      97             :       dbm_shard_t *shard = &matrix->shards[ishard];
      98             :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
      99             :         const dbm_block_t *blk = &shard->blocks[iblock];
     100             :         const int sum_index = (trans_matrix) ? blk->row : blk->col;
     101             :         const int itick = (1021 * sum_index) % nticks; // 1021 = a random prime
     102             :         const int ipack = itick / dist_ticks->nranks;
     103             :         nblks_mythread[ipack]++;
     104             :       }
     105             :     }
     106             : 
     107             :     // Sum nblocks across threads and allocate arrays for plans.
     108             : #pragma omp critical
     109             :     for (int ipack = 0; ipack < npacks; ipack++) {
     110             :       nblks_per_pack[ipack] += nblks_mythread[ipack];
     111             :       nblks_mythread[ipack] = nblks_per_pack[ipack];
     112             :     }
     113             : #pragma omp barrier
     114             : #pragma omp for
     115             :     for (int ipack = 0; ipack < npacks; ipack++) {
     116             :       plans_per_pack[ipack] = malloc(nblks_per_pack[ipack] * sizeof(plan_t));
     117             :       assert(plans_per_pack[ipack] != NULL);
     118             :     }
     119             : 
     120             :     // 2nd pass: Plan where to send each block.
     121             :     int ndata_mythread[npacks];
     122             :     memset(ndata_mythread, 0, npacks * sizeof(int));
     123             : #pragma omp for schedule(static) // Need static to match previous loop.
     124             :     for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
     125             :       dbm_shard_t *shard = &matrix->shards[ishard];
     126             :       for (int iblock = 0; iblock < shard->nblocks; iblock++) {
     127             :         const dbm_block_t *blk = &shard->blocks[iblock];
     128             :         const int free_index = (trans_matrix) ? blk->col : blk->row;
     129             :         const int sum_index = (trans_matrix) ? blk->row : blk->col;
     130             :         const int itick = (1021 * sum_index) % nticks; // Same mapping as above.
     131             :         const int ipack = itick / dist_ticks->nranks;
     132             :         // Compute rank to which this block should be sent.
     133             :         const int coord_free_idx = dist_indices->index2coord[free_index];
     134             :         const int coord_sum_idx = itick % dist_ticks->nranks;
     135             :         const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
     136             :                                (trans_dist) ? coord_free_idx : coord_sum_idx};
     137             :         const int rank = dbm_mpi_cart_rank(comm, coords);
     138             :         const int row_size = matrix->row_sizes[blk->row];
     139             :         const int col_size = matrix->col_sizes[blk->col];
     140             :         ndata_mythread[ipack] += row_size * col_size;
     141             :         // Create plan.
     142             :         const int iplan = --nblks_mythread[ipack];
     143             :         plans_per_pack[ipack][iplan].blk = blk;
     144             :         plans_per_pack[ipack][iplan].rank = rank;
     145             :         plans_per_pack[ipack][iplan].row_size = row_size;
     146             :         plans_per_pack[ipack][iplan].col_size = col_size;
     147             :       }
     148             :     }
     149             : #pragma omp critical
     150             :     for (int ipack = 0; ipack < npacks; ipack++) {
     151             :       ndata_per_pack[ipack] += ndata_mythread[ipack];
     152             :     }
     153             :   } // end of omp parallel region
     154      403926 : }
     155             : 
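To make the index arithmetic above concrete, here is an editor's sketch (not part of the covered source) of how a block's summation index is mapped to a tick, a send pack, and a cart coordinate, assuming nticks = 6 and dist_ticks->nranks = 3 (hence npacks = 2):

  static void tick_mapping_demo(const int sum_index) {
    const int nticks = 6, nranks_ticks = 3;
    const int itick = (1021 * sum_index) % nticks;  // e.g. sum_index = 5 -> itick = 5
    const int ipack = itick / nranks_ticks;         // -> ipack = 1 (second send pack)
    const int coord_sum_idx = itick % nranks_ticks; // -> coordinate 2 along the tick dimension
    (void)ipack;
    (void)coord_sum_idx;
  }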
     156             : /*******************************************************************************
     157             :  * \brief Private routine for filling send buffers.
     158             :  * \author Ole Schuett
     159             :  ******************************************************************************/
     160      425726 : static void fill_send_buffers(
     161             :     const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
     162             :     const int ndata_send, plan_t plans[nblks_send], const int nranks,
     163             :     int blks_send_count[nranks], int data_send_count[nranks],
     164             :     int blks_send_displ[nranks], int data_send_displ[nranks],
     165             :     dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
     166             : 
     167      425726 :   memset(blks_send_count, 0, nranks * sizeof(int));
     168      425726 :   memset(data_send_count, 0, nranks * sizeof(int));
     169             : 
     170      425726 : #pragma omp parallel
     171             :   {
      172             :     // 3rd pass: Compute per-rank nblks and ndata.
     173             :     int nblks_mythread[nranks], ndata_mythread[nranks];
     174             :     memset(nblks_mythread, 0, nranks * sizeof(int));
     175             :     memset(ndata_mythread, 0, nranks * sizeof(int));
     176             : #pragma omp for schedule(static)
     177             :     for (int iblock = 0; iblock < nblks_send; iblock++) {
     178             :       const plan_t *plan = &plans[iblock];
     179             :       nblks_mythread[plan->rank] += 1;
     180             :       ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
     181             :     }
     182             : 
     183             :     // Sum nblks and ndata across threads.
     184             : #pragma omp critical
     185             :     for (int irank = 0; irank < nranks; irank++) {
     186             :       blks_send_count[irank] += nblks_mythread[irank];
     187             :       data_send_count[irank] += ndata_mythread[irank];
     188             :       nblks_mythread[irank] = blks_send_count[irank];
     189             :       ndata_mythread[irank] = data_send_count[irank];
     190             :     }
     191             : #pragma omp barrier
     192             : 
     193             :     // Compute send displacements.
     194             : #pragma omp master
     195             :     {
     196             :       icumsum(nranks, blks_send_count, blks_send_displ);
     197             :       icumsum(nranks, data_send_count, data_send_displ);
     198             :       const int m = nranks - 1;
     199             :       assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
     200             :       assert(ndata_send == data_send_displ[m] + data_send_count[m]);
     201             :     }
     202             : #pragma omp barrier
     203             : 
     204             :     // 4th pass: Fill blks_send and data_send arrays.
     205             : #pragma omp for schedule(static) // Need static to match previous loop.
     206             :     for (int iblock = 0; iblock < nblks_send; iblock++) {
     207             :       const plan_t *plan = &plans[iblock];
     208             :       const dbm_block_t *blk = plan->blk;
     209             :       const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
     210             :       const dbm_shard_t *shard = &matrix->shards[ishard];
     211             :       const double *blk_data = &shard->data[blk->offset];
     212             :       const int row_size = plan->row_size, col_size = plan->col_size;
     213             :       const int irank = plan->rank;
     214             : 
      215             :       // The blks_send and data_send buffers are ordered by rank, thread, and block.
      216             :       //   data_send_displ[irank]: Start of the data for irank within data_send.
      217             :       //   ndata_mythread[irank]: Current thread's offset within the data for irank.
     218             :       nblks_mythread[irank] -= 1;
     219             :       ndata_mythread[irank] -= row_size * col_size;
     220             :       const int offset = data_send_displ[irank] + ndata_mythread[irank];
     221             :       const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
     222             : 
     223             :       double norm = 0.0; // Compute norm as double...
     224             :       if (trans_matrix) {
     225             :         // Transpose block to allow for outer-product style multiplication.
     226             :         for (int i = 0; i < row_size; i++) {
     227             :           for (int j = 0; j < col_size; j++) {
     228             :             const double element = blk_data[j * row_size + i];
     229             :             norm += element * element;
     230             :             data_send[offset + i * col_size + j] = element;
     231             :           }
     232             :         }
     233             :         blks_send[jblock].free_index = plan->blk->col;
     234             :         blks_send[jblock].sum_index = plan->blk->row;
     235             :       } else {
     236             :         for (int i = 0; i < row_size * col_size; i++) {
     237             :           const double element = blk_data[i];
     238             :           norm += element * element;
     239             :           data_send[offset + i] = element;
     240             :         }
     241             :         blks_send[jblock].free_index = plan->blk->row;
     242             :         blks_send[jblock].sum_index = plan->blk->col;
     243             :       }
     244             :       blks_send[jblock].norm = (float)norm; // ...store norm as float.
     245             : 
     246             :       // After the block exchange data_recv_displ will be added to the offsets.
     247             :       blks_send[jblock].offset = offset - data_send_displ[irank];
     248             :     }
     249             :   } // end of omp parallel region
     250      425726 : }
     251             : 
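The slot assignment above avoids atomics by combining two identically scheduled passes with thread-local counters: each thread first counts its blocks per rank, the critical section converts those counts into per-thread end offsets, and the second static pass fills each thread's segment backwards. An editor's sketch of that pattern in isolation (not part of the covered source; assign_slots and its arguments are hypothetical):

  static void assign_slots(const int n, const int dest[n], const int nranks,
                           int count[nranks], int displ[nranks], int slot[n]) {
    memset(count, 0, nranks * sizeof(int));
  #pragma omp parallel
    {
      int mine[nranks]; // per-thread counts, later per-thread end offsets
      memset(mine, 0, nranks * sizeof(int));
  #pragma omp for schedule(static)
      for (int i = 0; i < n; i++) {
        mine[dest[i]]++; // 1st pass: count my items per destination rank
      }
  #pragma omp critical
      for (int r = 0; r < nranks; r++) {
        count[r] += mine[r];
        mine[r] = count[r]; // end of my window within the per-rank segment
      }
  #pragma omp barrier
  #pragma omp master
      icumsum(nranks, count, displ);
  #pragma omp barrier
  #pragma omp for schedule(static) // must match the 1st pass to reuse mine[]
      for (int i = 0; i < n; i++) {
        slot[i] = displ[dest[i]] + --mine[dest[i]]; // fill my window backwards
      }
    }
  }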
     252             : /*******************************************************************************
      253             :  * \brief Private comparator passed to qsort to compare two blocks by sum_index.
     254             :  * \author Ole Schuett
     255             :  ******************************************************************************/
     256    78196695 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
     257    78196695 :   const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
     258    78196695 :   const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
     259    78196695 :   return blk_a->sum_index - blk_b->sum_index;
     260             : }
     261             : 
     262             : /*******************************************************************************
     263             :  * \brief Private routine for post-processing received blocks.
     264             :  * \author Ole Schuett
     265             :  ******************************************************************************/
     266      425726 : static void postprocess_received_blocks(
     267             :     const int nranks, const int nshards, const int nblocks_recv,
     268             :     const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
     269             :     const int data_recv_displ[nranks],
     270      425726 :     dbm_pack_block_t blks_recv[nblocks_recv]) {
     271             : 
     272      425726 :   int nblocks_per_shard[nshards], shard_start[nshards];
     273      425726 :   memset(nblocks_per_shard, 0, nshards * sizeof(int));
     274      425726 :   dbm_pack_block_t *blocks_tmp =
     275      425726 :       malloc(nblocks_recv * sizeof(dbm_pack_block_t));
     276      425726 :   assert(blocks_tmp != NULL);
     277             : 
     278      425726 : #pragma omp parallel
     279             :   {
      280             :     // Add data_recv_displ to the received block offsets.
     281             :     for (int irank = 0; irank < nranks; irank++) {
     282             : #pragma omp for
     283             :       for (int i = 0; i < blks_recv_count[irank]; i++) {
     284             :         blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
     285             :       }
     286             :     }
     287             : 
     288             :     // First use counting sort to group blocks by their free_index shard.
     289             :     int nblocks_mythread[nshards];
     290             :     memset(nblocks_mythread, 0, nshards * sizeof(int));
     291             : #pragma omp for schedule(static)
     292             :     for (int iblock = 0; iblock < nblocks_recv; iblock++) {
     293             :       blocks_tmp[iblock] = blks_recv[iblock];
     294             :       const int ishard = blks_recv[iblock].free_index % nshards;
     295             :       nblocks_mythread[ishard]++;
     296             :     }
     297             : #pragma omp critical
     298             :     for (int ishard = 0; ishard < nshards; ishard++) {
     299             :       nblocks_per_shard[ishard] += nblocks_mythread[ishard];
     300             :       nblocks_mythread[ishard] = nblocks_per_shard[ishard];
     301             :     }
     302             : #pragma omp barrier
     303             : #pragma omp master
     304             :     icumsum(nshards, nblocks_per_shard, shard_start);
     305             : #pragma omp barrier
     306             : #pragma omp for schedule(static) // Need static to match previous loop.
     307             :     for (int iblock = 0; iblock < nblocks_recv; iblock++) {
     308             :       const int ishard = blocks_tmp[iblock].free_index % nshards;
     309             :       const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
     310             :       blks_recv[jblock] = blocks_tmp[iblock];
     311             :     }
     312             : 
     313             :     // Then sort blocks within each shard by their sum_index.
     314             : #pragma omp for
     315             :     for (int ishard = 0; ishard < nshards; ishard++) {
     316             :       if (nblocks_per_shard[ishard] > 1) {
     317             :         qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
     318             :               sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
     319             :       }
     320             :     }
     321             :   } // end of omp parallel region
     322             : 
     323      425726 :   free(blocks_tmp);
     324      425726 : }
     325             : 
     326             : /*******************************************************************************
     327             :  * \brief Private routine for redistributing a matrix along selected dimensions.
     328             :  * \author Ole Schuett
     329             :  ******************************************************************************/
     330      403926 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
     331             :                                        const bool trans_dist,
     332             :                                        const dbm_matrix_t *matrix,
     333             :                                        const dbm_distribution_t *dist,
     334      403926 :                                        const int nticks) {
     335             : 
     336      403926 :   assert(dbm_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
     337             : 
      338             :   // The row/col indices are distributed along one cart dimension and the
     339             :   // ticks are distributed along the other cart dimension.
     340      403926 :   const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
     341      403926 :   const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
     342             : 
     343             :   // Allocate packed matrix.
     344      403926 :   const int nsend_packs = nticks / dist_ticks->nranks;
     345      403926 :   assert(nsend_packs * dist_ticks->nranks == nticks);
     346      403926 :   dbm_packed_matrix_t packed;
     347      403926 :   packed.dist_indices = dist_indices;
     348      403926 :   packed.dist_ticks = dist_ticks;
     349      403926 :   packed.nsend_packs = nsend_packs;
     350      403926 :   packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
     351      403926 :   assert(packed.send_packs != NULL);
     352             : 
     353             :   // Plan all packs.
     354      403926 :   plan_t *plans_per_pack[nsend_packs];
     355      403926 :   int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
     356      403926 :   create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
     357             :                     dist_ticks, nticks, nsend_packs, plans_per_pack,
     358             :                     nblks_send_per_pack, ndata_send_per_pack);
     359             : 
     360             :   // Allocate send buffers for maximum number of blocks/data over all packs.
     361      403926 :   int nblks_send_max = 0, ndata_send_max = 0;
     362      829652 :   for (int ipack = 0; ipack < nsend_packs; ++ipack) {
     363      425726 :     nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
     364      425726 :     ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
     365             :   }
     366      403926 :   dbm_pack_block_t *blks_send =
     367      403926 :       dbm_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
     368      403926 :   double *data_send = dbm_mempool_host_malloc(ndata_send_max * sizeof(double));
     369             : 
     370             :   // Cannot parallelize over packs (there might be too few of them).
     371      829652 :   for (int ipack = 0; ipack < nsend_packs; ipack++) {
     372             :     // Fill send buffers according to plans.
     373      425726 :     const int nranks = dist->nranks;
     374      425726 :     int blks_send_count[nranks], data_send_count[nranks];
     375      425726 :     int blks_send_displ[nranks], data_send_displ[nranks];
     376      425726 :     fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
     377             :                       ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
     378             :                       blks_send_count, data_send_count, blks_send_displ,
     379             :                       data_send_displ, blks_send, data_send);
     380      425726 :     free(plans_per_pack[ipack]);
     381             : 
     382             :     // 1st communication: Exchange block counts.
     383      425726 :     int blks_recv_count[nranks], blks_recv_displ[nranks];
     384      425726 :     dbm_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
     385      425726 :     icumsum(nranks, blks_recv_count, blks_recv_displ);
     386      425726 :     const int nblocks_recv = isum(nranks, blks_recv_count);
     387             : 
     388             :     // 2nd communication: Exchange blocks.
     389      425726 :     dbm_pack_block_t *blks_recv =
     390      425726 :         dbm_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
     391      425726 :     int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
     392      425726 :     int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
     393      916852 :     for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
     394      491126 :       blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
     395      491126 :       blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
     396      491126 :       blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
     397      491126 :       blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
     398             :     }
     399      425726 :     dbm_mpi_alltoallv_byte(
     400             :         blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
     401      425726 :         blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
     402             : 
     403             :     // 3rd communication: Exchange data counts.
     404             :     // TODO: could be computed from blks_recv.
     405      425726 :     int data_recv_count[nranks], data_recv_displ[nranks];
     406      425726 :     dbm_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
     407      425726 :     icumsum(nranks, data_recv_count, data_recv_displ);
     408      425726 :     const int ndata_recv = isum(nranks, data_recv_count);
     409             : 
     410             :     // 4th communication: Exchange data.
     411      425726 :     double *data_recv = dbm_mempool_host_malloc(ndata_recv * sizeof(double));
     412      425726 :     dbm_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
     413             :                              data_recv, data_recv_count, data_recv_displ,
     414      425726 :                              dist->comm);
     415             : 
     416             :     // Post-process received blocks and assemble them into a pack.
     417      425726 :     postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
     418             :                                 blks_recv_count, blks_recv_displ,
     419             :                                 data_recv_displ, blks_recv);
     420      425726 :     packed.send_packs[ipack].nblocks = nblocks_recv;
     421      425726 :     packed.send_packs[ipack].data_size = ndata_recv;
     422      425726 :     packed.send_packs[ipack].blocks = blks_recv;
     423      425726 :     packed.send_packs[ipack].data = data_recv;
     424             :   }
     425             : 
     426             :   // Deallocate send buffers.
     427      403926 :   dbm_mpi_free_mem(blks_send);
     428      403926 :   dbm_mempool_free(data_send);
     429             : 
     430             :   // Allocate pack_recv.
     431      403926 :   int max_nblocks = 0, max_data_size = 0;
     432      829652 :   for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
     433      425726 :     max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
     434      425726 :     max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
     435             :   }
     436      403926 :   dbm_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
     437      403926 :   dbm_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
     438      403926 :   packed.max_nblocks = max_nblocks;
     439      403926 :   packed.max_data_size = max_data_size;
     440      807852 :   packed.recv_pack.blocks =
     441      403926 :       dbm_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
     442      807852 :   packed.recv_pack.data =
     443      403926 :       dbm_mempool_host_malloc(packed.max_data_size * sizeof(double));
     444             : 
     445      403926 :   return packed; // Ownership of packed transfers to caller.
     446             : }
     447             : 
     448             : /*******************************************************************************
     449             :  * \brief Private routine for sending and receiving the pack for the given tick.
     450             :  * \author Ole Schuett
     451             :  ******************************************************************************/
     452      447526 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
     453             :                                  dbm_packed_matrix_t *packed) {
     454      447526 :   const int nranks = packed->dist_ticks->nranks;
     455      447526 :   const int my_rank = packed->dist_ticks->my_rank;
     456             : 
     457             :   // Compute send rank and pack.
     458      447526 :   const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
     459      447526 :   const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
     460      447526 :   const int send_itick = (itick_of_rank0 + send_rank) % nticks;
     461      447526 :   const int send_ipack = send_itick / nranks;
     462      447526 :   assert(send_itick % nranks == my_rank);
     463             : 
     464             :   // Compute receive rank and pack.
     465      447526 :   const int recv_rank = itick % nranks;
     466      447526 :   const int recv_ipack = itick / nranks;
     467             : 
     468      447526 :   if (send_rank == my_rank) {
     469      425726 :     assert(send_rank == recv_rank && send_ipack == recv_ipack);
     470      425726 :     return &packed->send_packs[send_ipack]; // Local pack, no mpi needed.
     471             :   } else {
     472       21800 :     const dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
     473             : 
     474             :     // Exchange blocks.
     475       43600 :     const int nblocks_in_bytes = dbm_mpi_sendrecv_byte(
     476       21800 :         /*sendbuf=*/send_pack->blocks,
      477       21800 :         /*sendcount=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
     478             :         /*dest=*/send_rank,
     479             :         /*sendtag=*/send_ipack,
     480       21800 :         /*recvbuf=*/packed->recv_pack.blocks,
     481       21800 :         /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
     482             :         /*source=*/recv_rank,
     483             :         /*recvtag=*/recv_ipack,
     484       21800 :         /*comm=*/packed->dist_ticks->comm);
     485             : 
     486       21800 :     assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
     487       21800 :     packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
     488             : 
     489             :     // Exchange data.
     490       43600 :     packed->recv_pack.data_size = dbm_mpi_sendrecv_double(
     491       21800 :         /*sendbuf=*/send_pack->data,
      492       21800 :         /*sendcount=*/send_pack->data_size,
     493             :         /*dest=*/send_rank,
     494             :         /*sendtag=*/send_ipack,
     495             :         /*recvbuf=*/packed->recv_pack.data,
     496             :         /*recvcount=*/packed->max_data_size,
     497             :         /*source=*/recv_rank,
     498             :         /*recvtag=*/recv_ipack,
     499       21800 :         /*comm=*/packed->dist_ticks->comm);
     500             : 
     501       21800 :     return &packed->recv_pack;
     502             :   }
     503             : }
     504             : 
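To see why the sends and receives above actually pair up, the following editor's sketch (not part of the covered source) replays the index arithmetic for every rank of a ring of size nranks over nticks = k * nranks ticks, applying the same per-rank tick shift that dbm_comm_iterator_next uses below (the coordinate shared by all ranks of the sub-communicator is absorbed into the loop over g):

  static void check_sendrecv_pairing(const int nranks, const int nticks) {
    assert(nticks % nranks == 0);
    for (int g = 0; g < nticks; g++) {     // unshifted global tick
      for (int r = 0; r < nranks; r++) {   // my_rank within dist_ticks
        const int t_r = (g + r) % nticks;  // shifted tick seen by rank r
        const int recv_rank = t_r % nranks;
        const int recv_ipack = t_r / nranks;
        // Replay the sender-side formulas from the partner's point of view.
        const int t_q = (g + recv_rank) % nticks;
        const int itick_of_rank0 = (t_q + nticks - recv_rank) % nticks;
        const int send_rank = (recv_rank + nticks - itick_of_rank0) % nranks;
        const int send_itick = (itick_of_rank0 + send_rank) % nticks;
        assert(send_rank == r);                    // the partner sends back to us...
        assert(send_itick / nranks == recv_ipack); // ...with the matching pack and tag
      }
    }
  }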
     505             : /*******************************************************************************
     506             :  * \brief Private routine for releasing a packed matrix.
     507             :  * \author Ole Schuett
     508             :  ******************************************************************************/
     509      403926 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
     510      403926 :   dbm_mpi_free_mem(packed->recv_pack.blocks);
     511      403926 :   dbm_mempool_free(packed->recv_pack.data);
     512      829652 :   for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
     513      425726 :     dbm_mpi_free_mem(packed->send_packs[ipack].blocks);
     514      425726 :     dbm_mempool_free(packed->send_packs[ipack].data);
     515             :   }
     516      403926 :   free(packed->send_packs);
     517      403926 : }
     518             : 
     519             : /*******************************************************************************
     520             :  * \brief Internal routine for creating a communication iterator.
     521             :  * \author Ole Schuett
     522             :  ******************************************************************************/
     523      201963 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
     524             :                                              const bool transb,
     525             :                                              const dbm_matrix_t *matrix_a,
     526             :                                              const dbm_matrix_t *matrix_b,
     527             :                                              const dbm_matrix_t *matrix_c) {
     528             : 
     529      201963 :   dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
     530      201963 :   assert(iter != NULL);
     531      201963 :   iter->dist = matrix_c->dist;
     532             : 
     533             :   // During each communication tick we'll fetch a pack_a and pack_b.
     534             :   // Since the cart might be non-squared, the number of communication ticks is
      535             :   // Since the cart might not be square, the number of communication ticks is
     536      201963 :   iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
     537      201963 :   iter->itick = 0;
     538             : 
      539             :   // 1st arg = source dimension, 2nd arg = target dimension; false = rows, true = columns.
     540      201963 :   iter->packed_a =
     541      201963 :       pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
     542      201963 :   iter->packed_b =
     543      201963 :       pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
     544             : 
     545      201963 :   return iter;
     546             : }
     547             : 
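As a worked example (editor's addition, not part of the covered source): on a 2 x 3 cart, nticks = lcm(2, 3) = 6, so matrix_a is packed into 6 / cols.nranks = 2 send packs, matrix_b into 6 / rows.nranks = 3 send packs, and every rank walks through all 6 tick combinations during the multiplication.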
     548             : /*******************************************************************************
      549             :  * \brief Internal routine for retrieving the next pair of packs from the given iterator.
     550             :  * \author Ole Schuett
     551             :  ******************************************************************************/
     552      425726 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
     553             :                             dbm_pack_t **pack_b) {
     554      425726 :   if (iter->itick >= iter->nticks) {
     555             :     return false; // end of iterator reached
     556             :   }
     557             : 
     558             :   // Start each rank at a different tick to spread the load on the sources.
     559      223763 :   const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
     560      223763 :   const int shifted_itick = (iter->itick + shift) % iter->nticks;
     561      223763 :   *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
     562      223763 :   *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);
     563             : 
     564      223763 :   iter->itick++;
     565      223763 :   return true;
     566             : }
     567             : 
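For context, an editor's sketch (not part of the covered source) of how these three routines are meant to be driven; multiply_packs is a hypothetical placeholder for the local multiplication kernel:

  static void multiply_driver(const bool transa, const bool transb,
                              const dbm_matrix_t *matrix_a,
                              const dbm_matrix_t *matrix_b,
                              dbm_matrix_t *matrix_c) {
    dbm_pack_t *pack_a = NULL, *pack_b = NULL;
    dbm_comm_iterator_t *iter =
        dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
    while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
      // multiply_packs(pack_a, pack_b, matrix_c); // hypothetical local kernel
    }
    dbm_comm_iterator_stop(iter);
  }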
     568             : /*******************************************************************************
     569             :  * \brief Internal routine for releasing the given communication iterator.
     570             :  * \author Ole Schuett
     571             :  ******************************************************************************/
     572      201963 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
     573      201963 :   free_packed_matrix(&iter->packed_a);
     574      201963 :   free_packed_matrix(&iter->packed_b);
     575      201963 :   free(iter);
     576      201963 : }
     577             : 
     578             : // EOF

Generated by: LCOV version 1.15