Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include "dbm_multiply_comm.h"
9 :
10 : #include <assert.h>
11 : #include <stdlib.h>
12 : #include <string.h>
13 :
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_mempool.h"
16 : #include "dbm_mpi.h"
17 :
18 : /*******************************************************************************
19 : * \brief Returns the larger of two given integers (missing from the C standard)
20 : * \author Ole Schuett
21 : ******************************************************************************/
22 855772 : static inline int imax(int x, int y) { return (x > y ? x : y); }
23 :
24 : /*******************************************************************************
25 : * \brief Private routine for computing greatest common divisor of two numbers.
26 : * \author Ole Schuett
27 : ******************************************************************************/
28 427534 : static int gcd(const int a, const int b) {
29 427534 : if (a == 0)
30 : return b;
31 224491 : return gcd(b % a, a); // Euclid's algorithm.
32 : }
33 :
34 : /*******************************************************************************
35 : * \brief Private routine for computing least common multiple of two numbers.
36 : * \author Ole Schuett
37 : ******************************************************************************/
38 203043 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
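
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): the number of
 *        communication ticks used further below is lcm(rows, cols) of the
 *        process grid, e.g. a hypothetical 4 x 6 cart needs lcm(4, 6) = 12.
 ******************************************************************************/
static inline void example_lcm_of_cart_dims(void) {
  assert(gcd(4, 6) == 2);  // Euclid's algorithm on the grid dimensions.
  assert(lcm(4, 6) == 12); // One pair of packs is exchanged per tick.
}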
39 :
40 : /*******************************************************************************
41 : * \brief Private routine for computing the sum of the given integers.
42 : * \author Ole Schuett
43 : ******************************************************************************/
44 855772 : static inline int isum(const int n, const int input[n]) {
45 855772 : int output = 0;
46 1842344 : for (int i = 0; i < n; i++) {
47 986572 : output += input[i];
48 : }
49 855772 : return output;
50 : }
51 :
52 : /*******************************************************************************
53 : * \brief Private routine for computing the cumulative sums of given numbers.
54 : * \author Ole Schuett
55 : ******************************************************************************/
56 2139430 : static inline void icumsum(const int n, const int input[n], int output[n]) {
57 2139430 : output[0] = 0;
58 2401030 : for (int i = 1; i < n; i++) {
59 261600 : output[i] = output[i - 1] + input[i - 1];
60 : }
61 2139430 : }
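
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): icumsum yields
 *        exclusive prefix sums, which serve as the displacement arrays of the
 *        MPI alltoall(v) calls below, e.g. counts {3, 1, 2} -> displ {0, 3, 4}.
 ******************************************************************************/
static inline void example_icumsum_as_displacements(void) {
  const int counts[3] = {3, 1, 2};
  int displ[3];
  icumsum(3, counts, displ);
  assert(displ[0] == 0 && displ[1] == 3 && displ[2] == 4);
}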
62 :
63 : /*******************************************************************************
64 : * \brief Private struct used for planning during pack_matrix.
65 : * \author Ole Schuett
66 : ******************************************************************************/
67 : typedef struct {
68 : const dbm_block_t *blk; // source block
69 : int rank; // target mpi rank
70 : int row_size;
71 : int col_size;
72 : } plan_t;
73 :
74 : /*******************************************************************************
75 : * \brief Private routine for planning packs.
76 : * \author Ole Schuett
77 : ******************************************************************************/
78 406086 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
79 : const dbm_matrix_t *matrix,
80 : const dbm_mpi_comm_t comm,
81 : const dbm_dist_1d_t *dist_indices,
82 : const dbm_dist_1d_t *dist_ticks, const int nticks,
83 : const int npacks, plan_t *plans_per_pack[npacks],
84 : int nblks_per_pack[npacks],
85 : int ndata_per_pack[npacks]) {
86 :
87 406086 : memset(nblks_per_pack, 0, npacks * sizeof(int));
88 406086 : memset(ndata_per_pack, 0, npacks * sizeof(int));
89 :
90 406086 : #pragma omp parallel
91 : {
92 : // 1st pass: Compute the number of blocks that will be sent in each pack.
93 : int nblks_mythread[npacks];
94 : memset(nblks_mythread, 0, npacks * sizeof(int));
95 : #pragma omp for schedule(static)
96 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
97 : dbm_shard_t *shard = &matrix->shards[ishard];
98 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
99 : const dbm_block_t *blk = &shard->blocks[iblock];
100 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
101 : const int itick = (1021 * sum_index) % nticks; // 1021 = a random prime
102 : const int ipack = itick / dist_ticks->nranks;
103 : nblks_mythread[ipack]++;
104 : }
105 : }
106 :
107 : // Sum nblocks across threads and allocate arrays for plans.
108 : #pragma omp critical
109 : for (int ipack = 0; ipack < npacks; ipack++) {
110 : nblks_per_pack[ipack] += nblks_mythread[ipack];
111 : nblks_mythread[ipack] = nblks_per_pack[ipack];
112 : }
113 : #pragma omp barrier
114 : #pragma omp for
115 : for (int ipack = 0; ipack < npacks; ipack++) {
116 : plans_per_pack[ipack] = malloc(nblks_per_pack[ipack] * sizeof(plan_t));
117 : assert(plans_per_pack[ipack] != NULL);
118 : }
119 :
120 : // 2nd pass: Plan where to send each block.
121 : int ndata_mythread[npacks];
122 : memset(ndata_mythread, 0, npacks * sizeof(int));
123 : #pragma omp for schedule(static) // Need static to match previous loop.
124 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
125 : dbm_shard_t *shard = &matrix->shards[ishard];
126 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
127 : const dbm_block_t *blk = &shard->blocks[iblock];
128 : const int free_index = (trans_matrix) ? blk->col : blk->row;
129 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
130 : const int itick = (1021 * sum_index) % nticks; // Same mapping as above.
131 : const int ipack = itick / dist_ticks->nranks;
132 : // Compute rank to which this block should be sent.
133 : const int coord_free_idx = dist_indices->index2coord[free_index];
134 : const int coord_sum_idx = itick % dist_ticks->nranks;
135 : const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
136 : (trans_dist) ? coord_free_idx : coord_sum_idx};
137 : const int rank = dbm_mpi_cart_rank(comm, coords);
138 : const int row_size = matrix->row_sizes[blk->row];
139 : const int col_size = matrix->col_sizes[blk->col];
140 : ndata_mythread[ipack] += row_size * col_size;
141 : // Create plan.
142 : const int iplan = --nblks_mythread[ipack];
143 : plans_per_pack[ipack][iplan].blk = blk;
144 : plans_per_pack[ipack][iplan].rank = rank;
145 : plans_per_pack[ipack][iplan].row_size = row_size;
146 : plans_per_pack[ipack][iplan].col_size = col_size;
147 : }
148 : }
149 : #pragma omp critical
150 : for (int ipack = 0; ipack < npacks; ipack++) {
151 : ndata_per_pack[ipack] += ndata_mythread[ipack];
152 : }
153 : } // end of omp parallel region
154 406086 : }
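
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): how the hash
 *        (1021 * sum_index) % nticks from create_pack_plans splits into a pack
 *        index and a cart coordinate. The numbers assume a hypothetical setup
 *        with nticks = 12 and 6 ranks along the tick dimension (2 packs).
 ******************************************************************************/
static inline void example_tick_mapping(void) {
  const int nticks = 12, nranks_ticks = 6, sum_index = 7;
  const int itick = (1021 * sum_index) % nticks; // = 7147 % 12 = 7
  assert(itick / nranks_ticks == 1);             // ipack: which send pack
  assert(itick % nranks_ticks == 1);             // coord along the tick dim
}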
155 :
156 : /*******************************************************************************
157 : * \brief Private routine for filling send buffers.
158 : * \author Ole Schuett
159 : ******************************************************************************/
160 427886 : static void fill_send_buffers(
161 : const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
162 : const int ndata_send, plan_t plans[nblks_send], const int nranks,
163 : int blks_send_count[nranks], int data_send_count[nranks],
164 : int blks_send_displ[nranks], int data_send_displ[nranks],
165 : dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
166 :
167 427886 : memset(blks_send_count, 0, nranks * sizeof(int));
168 427886 : memset(data_send_count, 0, nranks * sizeof(int));
169 :
170 427886 : #pragma omp parallel
171 : {
172 : // 3rd pass: Compute per-rank nblks and ndata.
173 : int nblks_mythread[nranks], ndata_mythread[nranks];
174 : memset(nblks_mythread, 0, nranks * sizeof(int));
175 : memset(ndata_mythread, 0, nranks * sizeof(int));
176 : #pragma omp for schedule(static)
177 : for (int iblock = 0; iblock < nblks_send; iblock++) {
178 : const plan_t *plan = &plans[iblock];
179 : nblks_mythread[plan->rank] += 1;
180 : ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
181 : }
182 :
183 : // Sum nblks and ndata across threads.
184 : #pragma omp critical
185 : for (int irank = 0; irank < nranks; irank++) {
186 : blks_send_count[irank] += nblks_mythread[irank];
187 : data_send_count[irank] += ndata_mythread[irank];
188 : nblks_mythread[irank] = blks_send_count[irank];
189 : ndata_mythread[irank] = data_send_count[irank];
190 : }
191 : #pragma omp barrier
192 :
193 : // Compute send displacements.
194 : #pragma omp master
195 : {
196 : icumsum(nranks, blks_send_count, blks_send_displ);
197 : icumsum(nranks, data_send_count, data_send_displ);
198 : const int m = nranks - 1;
199 : assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
200 : assert(ndata_send == data_send_displ[m] + data_send_count[m]);
201 : }
202 : #pragma omp barrier
203 :
204 : // 4th pass: Fill blks_send and data_send arrays.
205 : #pragma omp for schedule(static) // Need static to match previous loop.
206 : for (int iblock = 0; iblock < nblks_send; iblock++) {
207 : const plan_t *plan = &plans[iblock];
208 : const dbm_block_t *blk = plan->blk;
209 : const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
210 : const dbm_shard_t *shard = &matrix->shards[ishard];
211 : const double *blk_data = &shard->data[blk->offset];
212 : const int row_size = plan->row_size, col_size = plan->col_size;
213 : const int irank = plan->rank;
214 :
215 : // The send buffers are ordered by rank, thread, and block.
216 : // data_send_displ[irank]: Start of the data for irank within data_send.
217 : // ndata_mythread[irank]: Current thread's offset within the data for irank.
218 : nblks_mythread[irank] -= 1;
219 : ndata_mythread[irank] -= row_size * col_size;
220 : const int offset = data_send_displ[irank] + ndata_mythread[irank];
221 : const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
222 :
223 : double norm = 0.0; // Compute norm as double...
224 : if (trans_matrix) {
225 : // Transpose block to allow for outer-product style multiplication.
226 : for (int i = 0; i < row_size; i++) {
227 : for (int j = 0; j < col_size; j++) {
228 : const double element = blk_data[j * row_size + i];
229 : norm += element * element;
230 : data_send[offset + i * col_size + j] = element;
231 : }
232 : }
233 : blks_send[jblock].free_index = plan->blk->col;
234 : blks_send[jblock].sum_index = plan->blk->row;
235 : } else {
236 : for (int i = 0; i < row_size * col_size; i++) {
237 : const double element = blk_data[i];
238 : norm += element * element;
239 : data_send[offset + i] = element;
240 : }
241 : blks_send[jblock].free_index = plan->blk->row;
242 : blks_send[jblock].sum_index = plan->blk->col;
243 : }
244 : blks_send[jblock].norm = (float)norm; // ...store norm as float.
245 :
246 : // After the block exchange data_recv_displ will be added to the offsets.
247 : blks_send[jblock].offset = offset - data_send_displ[irank];
248 : }
249 : } // end of omp parallel region
250 427886 : }
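
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): the transposed
 *        copy in fill_send_buffers moves element (i, j) from index
 *        j * row_size + i to index i * col_size + j. For a 2 x 3 block this
 *        permutes {0, 1, 2, 3, 4, 5} into {0, 2, 4, 1, 3, 5}.
 ******************************************************************************/
static inline void example_block_transpose_copy(void) {
  const int row_size = 2, col_size = 3;
  const double in[6] = {0, 1, 2, 3, 4, 5};
  double out[6];
  for (int i = 0; i < row_size; i++) {
    for (int j = 0; j < col_size; j++) {
      out[i * col_size + j] = in[j * row_size + i];
    }
  }
  assert(out[1] == 2.0 && out[3] == 1.0); // Same data, transposed layout.
}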
251 :
252 : /*******************************************************************************
253 : * \brief Private comparator passed to qsort to compare two blocks by sum_index.
254 : * \author Ole Schuett
255 : ******************************************************************************/
256 78223196 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
257 78223196 : const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
258 78223196 : const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
259 78223196 : return blk_a->sum_index - blk_b->sum_index;
260 : }
261 :
262 : /*******************************************************************************
263 : * \brief Private routine for post-processing received blocks.
264 : * \author Ole Schuett
265 : ******************************************************************************/
266 427886 : static void postprocess_received_blocks(
267 : const int nranks, const int nshards, const int nblocks_recv,
268 : const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
269 : const int data_recv_displ[nranks],
270 427886 : dbm_pack_block_t blks_recv[nblocks_recv]) {
271 :
272 427886 : int nblocks_per_shard[nshards], shard_start[nshards];
273 427886 : memset(nblocks_per_shard, 0, nshards * sizeof(int));
274 427886 : dbm_pack_block_t *blocks_tmp =
275 427886 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
276 427886 : assert(blocks_tmp != NULL);
277 :
278 427886 : #pragma omp parallel
279 : {
280 : // Add data_recv_displ to received block offsets.
281 : for (int irank = 0; irank < nranks; irank++) {
282 : #pragma omp for
283 : for (int i = 0; i < blks_recv_count[irank]; i++) {
284 : blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
285 : }
286 : }
287 :
288 : // First use counting sort to group blocks by their free_index shard.
289 : int nblocks_mythread[nshards];
290 : memset(nblocks_mythread, 0, nshards * sizeof(int));
291 : #pragma omp for schedule(static)
292 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
293 : blocks_tmp[iblock] = blks_recv[iblock];
294 : const int ishard = blks_recv[iblock].free_index % nshards;
295 : nblocks_mythread[ishard]++;
296 : }
297 : #pragma omp critical
298 : for (int ishard = 0; ishard < nshards; ishard++) {
299 : nblocks_per_shard[ishard] += nblocks_mythread[ishard];
300 : nblocks_mythread[ishard] = nblocks_per_shard[ishard];
301 : }
302 : #pragma omp barrier
303 : #pragma omp master
304 : icumsum(nshards, nblocks_per_shard, shard_start);
305 : #pragma omp barrier
306 : #pragma omp for schedule(static) // Need static to match previous loop.
307 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
308 : const int ishard = blocks_tmp[iblock].free_index % nshards;
309 : const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
310 : blks_recv[jblock] = blocks_tmp[iblock];
311 : }
312 :
313 : // Then sort blocks within each shard by their sum_index.
314 : #pragma omp for
315 : for (int ishard = 0; ishard < nshards; ishard++) {
316 : if (nblocks_per_shard[ishard] > 1) {
317 : qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
318 : sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
319 : }
320 : }
321 : } // end of omp parallel region
322 :
323 427886 : free(blocks_tmp);
324 427886 : }
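
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): the counting
 *        sort in postprocess_received_blocks groups blocks by
 *        free_index % nshards before qsort orders each group by sum_index.
 *        With nshards = 2 and free indices {4, 7, 2, 5}, shard 0 receives
 *        {4, 2} and shard 1 receives {7, 5}.
 ******************************************************************************/
static inline void example_group_by_shard(void) {
  const int nshards = 2, free_index[4] = {4, 7, 2, 5};
  int counts[2] = {0, 0};
  for (int i = 0; i < 4; i++) {
    counts[free_index[i] % nshards]++; // The shard is the counting-sort key.
  }
  assert(counts[0] == 2 && counts[1] == 2);
}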
325 :
326 : /*******************************************************************************
327 : * \brief Private routine for redistributing a matrix along selected dimensions.
328 : * \author Ole Schuett
329 : ******************************************************************************/
330 406086 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
331 : const bool trans_dist,
332 : const dbm_matrix_t *matrix,
333 : const dbm_distribution_t *dist,
334 406086 : const int nticks) {
335 :
336 406086 : assert(dbm_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
337 :
338 : // The row/col indices are distributed along one cart dimension and the
339 : // ticks are distributed along the other cart dimension.
340 406086 : const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
341 406086 : const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
342 :
343 : // Allocate packed matrix.
344 406086 : const int nsend_packs = nticks / dist_ticks->nranks;
345 406086 : assert(nsend_packs * dist_ticks->nranks == nticks);
346 406086 : dbm_packed_matrix_t packed;
347 406086 : packed.dist_indices = dist_indices;
348 406086 : packed.dist_ticks = dist_ticks;
349 406086 : packed.nsend_packs = nsend_packs;
350 406086 : packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
351 406086 : assert(packed.send_packs != NULL);
352 :
353 : // Plan all packs.
354 406086 : plan_t *plans_per_pack[nsend_packs];
355 406086 : int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
356 406086 : create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
357 : dist_ticks, nticks, nsend_packs, plans_per_pack,
358 : nblks_send_per_pack, ndata_send_per_pack);
359 :
360 : // Allocate send buffers for maximum number of blocks/data over all packs.
361 406086 : int nblks_send_max = 0, ndata_send_max = 0;
362 833972 : for (int ipack = 0; ipack < nsend_packs; ++ipack) {
363 427886 : nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
364 427886 : ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
365 : }
366 406086 : dbm_pack_block_t *blks_send =
367 406086 : dbm_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
368 406086 : double *data_send = dbm_mempool_host_malloc(ndata_send_max * sizeof(double));
369 :
370 : // Cannot parallelize over packs (there might be too few of them).
371 833972 : for (int ipack = 0; ipack < nsend_packs; ipack++) {
372 : // Fill send buffers according to plans.
373 427886 : const int nranks = dist->nranks;
374 427886 : int blks_send_count[nranks], data_send_count[nranks];
375 427886 : int blks_send_displ[nranks], data_send_displ[nranks];
376 427886 : fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
377 : ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
378 : blks_send_count, data_send_count, blks_send_displ,
379 : data_send_displ, blks_send, data_send);
380 427886 : free(plans_per_pack[ipack]);
381 :
382 : // 1st communication: Exchange block counts.
383 427886 : int blks_recv_count[nranks], blks_recv_displ[nranks];
384 427886 : dbm_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
385 427886 : icumsum(nranks, blks_recv_count, blks_recv_displ);
386 427886 : const int nblocks_recv = isum(nranks, blks_recv_count);
387 :
388 : // 2nd communication: Exchange blocks.
389 427886 : dbm_pack_block_t *blks_recv =
390 427886 : dbm_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
391 427886 : int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
392 427886 : int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
393 921172 : for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
394 493286 : blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
395 493286 : blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
396 493286 : blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
397 493286 : blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
398 : }
399 427886 : dbm_mpi_alltoallv_byte(
400 : blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
401 427886 : blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
402 :
403 : // 3rd communication: Exchange data counts.
404 : // TODO: could be computed from blks_recv.
405 427886 : int data_recv_count[nranks], data_recv_displ[nranks];
406 427886 : dbm_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
407 427886 : icumsum(nranks, data_recv_count, data_recv_displ);
408 427886 : const int ndata_recv = isum(nranks, data_recv_count);
409 :
410 : // 4th communication: Exchange data.
411 427886 : double *data_recv = dbm_mempool_host_malloc(ndata_recv * sizeof(double));
412 427886 : dbm_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
413 : data_recv, data_recv_count, data_recv_displ,
414 427886 : dist->comm);
415 :
416 : // Post-process received blocks and assemble them into a pack.
417 427886 : postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
418 : blks_recv_count, blks_recv_displ,
419 : data_recv_displ, blks_recv);
420 427886 : packed.send_packs[ipack].nblocks = nblocks_recv;
421 427886 : packed.send_packs[ipack].data_size = ndata_recv;
422 427886 : packed.send_packs[ipack].blocks = blks_recv;
423 427886 : packed.send_packs[ipack].data = data_recv;
424 : }
425 :
426 : // Deallocate send buffers.
427 406086 : dbm_mpi_free_mem(blks_send);
428 406086 : dbm_mempool_free(data_send);
429 :
430 : // Allocate recv_pack.
431 406086 : int max_nblocks = 0, max_data_size = 0;
432 833972 : for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
433 427886 : max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
434 427886 : max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
435 : }
436 406086 : dbm_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
437 406086 : dbm_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
438 406086 : packed.max_nblocks = max_nblocks;
439 406086 : packed.max_data_size = max_data_size;
440 812172 : packed.recv_pack.blocks =
441 406086 : dbm_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
442 812172 : packed.recv_pack.data =
443 406086 : dbm_mempool_host_malloc(packed.max_data_size * sizeof(double));
444 :
445 406086 : return packed; // Ownership of packed transfers to caller.
446 : }
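
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): each pack
 *        produced by pack_matrix pairs block metadata with a flat data buffer;
 *        a block's offset points to its row_size x col_size elements within
 *        pack->data, and its norm is stored as a float.
 ******************************************************************************/
static inline double example_sum_pack_norms(const dbm_pack_t *pack) {
  double total = 0.0;
  for (int iblock = 0; iblock < pack->nblocks; iblock++) {
    total += pack->blocks[iblock].norm;
  }
  return total;
}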
447 :
448 : /*******************************************************************************
449 : * \brief Private routine for sending and receiving the pack for the given tick.
450 : * \author Ole Schuett
451 : ******************************************************************************/
452 449686 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
453 : dbm_packed_matrix_t *packed) {
454 449686 : const int nranks = packed->dist_ticks->nranks;
455 449686 : const int my_rank = packed->dist_ticks->my_rank;
456 :
457 : // Compute send rank and pack.
458 449686 : const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
459 449686 : const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
460 449686 : const int send_itick = (itick_of_rank0 + send_rank) % nticks;
461 449686 : const int send_ipack = send_itick / nranks;
462 449686 : assert(send_itick % nranks == my_rank);
463 :
464 : // Compute receive rank and pack.
465 449686 : const int recv_rank = itick % nranks;
466 449686 : const int recv_ipack = itick / nranks;
467 :
468 449686 : if (send_rank == my_rank) {
469 427886 : assert(send_rank == recv_rank && send_ipack == recv_ipack);
470 427886 : return &packed->send_packs[send_ipack]; // Local pack, no mpi needed.
471 : } else {
472 21800 : const dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
473 :
474 : // Exchange blocks.
475 43600 : const int nblocks_in_bytes = dbm_mpi_sendrecv_byte(
476 21800 : /*sendbuf=*/send_pack->blocks,
477 21800 : /*sendcount=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
478 : /*dest=*/send_rank,
479 : /*sendtag=*/send_ipack,
480 21800 : /*recvbuf=*/packed->recv_pack.blocks,
481 21800 : /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
482 : /*source=*/recv_rank,
483 : /*recvtag=*/recv_ipack,
484 21800 : /*comm=*/packed->dist_ticks->comm);
485 :
486 21800 : assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
487 21800 : packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
488 :
489 : // Exchange data.
490 43600 : packed->recv_pack.data_size = dbm_mpi_sendrecv_double(
491 21800 : /*sendbuf=*/send_pack->data,
492 21800 : /*sendcount=*/send_pack->data_size,
493 : /*dest=*/send_rank,
494 : /*sendtag=*/send_ipack,
495 : /*recvbuf=*/packed->recv_pack.data,
496 : /*recvcount=*/packed->max_data_size,
497 : /*source=*/recv_rank,
498 : /*recvtag=*/recv_ipack,
499 21800 : /*comm=*/packed->dist_ticks->comm);
500 :
501 21800 : return &packed->recv_pack;
502 : }
503 : }
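
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): a worked
 *        example of the ring schedule in sendrecv_pack, assuming hypothetical
 *        values nranks = 3, nticks = 6, my_rank = 2, itick = 4.
 ******************************************************************************/
static inline void example_sendrecv_schedule(void) {
  const int nranks = 3, nticks = 6, my_rank = 2, itick = 4;
  const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;     // = 2
  const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks; // = 0
  const int send_itick = (itick_of_rank0 + send_rank) % nticks;       // = 2
  assert(send_itick % nranks == my_rank); // Rank 2 owns the pack it sends...
  assert(send_itick / nranks == 0);       // ...its pack 0 goes to rank 0...
  assert(itick % nranks == 1);            // ...while it receives from rank 1
  assert(itick / nranks == 1);            // ...that rank's pack 1.
}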
504 :
505 : /*******************************************************************************
506 : * \brief Private routine for releasing a packed matrix.
507 : * \author Ole Schuett
508 : ******************************************************************************/
509 406086 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
510 406086 : dbm_mpi_free_mem(packed->recv_pack.blocks);
511 406086 : dbm_mempool_free(packed->recv_pack.data);
512 833972 : for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
513 427886 : dbm_mpi_free_mem(packed->send_packs[ipack].blocks);
514 427886 : dbm_mempool_free(packed->send_packs[ipack].data);
515 : }
516 406086 : free(packed->send_packs);
517 406086 : }
518 :
519 : /*******************************************************************************
520 : * \brief Internal routine for creating a communication iterator.
521 : * \author Ole Schuett
522 : ******************************************************************************/
523 203043 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
524 : const bool transb,
525 : const dbm_matrix_t *matrix_a,
526 : const dbm_matrix_t *matrix_b,
527 : const dbm_matrix_t *matrix_c) {
528 :
529 203043 : dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
530 203043 : assert(iter != NULL);
531 203043 : iter->dist = matrix_c->dist;
532 :
533 : // During each communication tick we'll fetch a pack_a and pack_b.
534 : // Since the cart might be non-square, the number of communication ticks is
535 : // chosen as the least common multiple of the cart's dimensions.
536 203043 : iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
537 203043 : iter->itick = 0;
538 :
539 : // 1st arg: source dimension, 2nd arg: target dimension; false=rows, true=columns.
540 203043 : iter->packed_a =
541 203043 : pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
542 203043 : iter->packed_b =
543 203043 : pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
544 :
545 203043 : return iter;
546 : }
547 :
548 : /*******************************************************************************
549 : * \brief Internal routine for retrieving the next pair of packs from the given iterator.
550 : * \author Ole Schuett
551 : ******************************************************************************/
552 427886 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
553 : dbm_pack_t **pack_b) {
554 427886 : if (iter->itick >= iter->nticks) {
555 : return false; // end of iterator reached
556 : }
557 :
558 : // Start each rank at a different tick to spread the load on the sources.
559 224843 : const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
560 224843 : const int shifted_itick = (iter->itick + shift) % iter->nticks;
561 224843 : *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
562 224843 : *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);
563 :
564 224843 : iter->itick++;
565 224843 : return true;
566 : }
567 :
568 : /*******************************************************************************
569 : * \brief Internal routine for releasing the given communication iterator.
570 : * \author Ole Schuett
571 : ******************************************************************************/
572 203043 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
573 203043 : free_packed_matrix(&iter->packed_a);
574 203043 : free_packed_matrix(&iter->packed_b);
575 203043 : free(iter);
576 203043 : }
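
/*******************************************************************************
 * \brief Illustrative sketch (not part of the original source): how a caller,
 *        e.g. the multiplication driver, would typically drive the iterator;
 *        process_pack_pair is a hypothetical placeholder for the local
 *        multiplication of one pack pair.
 ******************************************************************************/
static inline void example_drive_iterator(
    const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
    const dbm_matrix_t *matrix_c,
    void (*process_pack_pair)(const dbm_pack_t *, const dbm_pack_t *)) {
  dbm_comm_iterator_t *iter =
      dbm_comm_iterator_start(false, false, matrix_a, matrix_b, matrix_c);
  dbm_pack_t *pack_a = NULL, *pack_b = NULL;
  while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
    process_pack_pair(pack_a, pack_b); // One local multiplication per tick.
  }
  dbm_comm_iterator_stop(iter);
}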
577 :
578 : // EOF