Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include "dbm_multiply_comm.h"
9 :
10 : #include <assert.h>
11 : #include <stdlib.h>
12 : #include <string.h>
13 :
14 : #include "dbm_hyperparams.h"
15 : #include "dbm_mpi.h"
16 :
17 : /*******************************************************************************
18 : * \brief Private routine for computing the greatest common divisor of two numbers.
19 : * \author Ole Schuett
20 : ******************************************************************************/
21 440674 : static int gcd(const int a, const int b) {
22 440674 : if (a == 0)
23 : return b;
24 228879 : return gcd(b % a, a); // Euclid's algorithm.
25 : }
26 :
27 : /*******************************************************************************
28 : * \brief Private routine for computing the least common multiple of two numbers.
29 : * \author Ole Schuett
30 : ******************************************************************************/
31 211795 : static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
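/* A minimal illustrative sketch (not part of the original source): the tick
   count used by the communication iterator further below is lcm() of the two
   cart dimensions, so it is divisible by both. The 6x4 grid and the helper
   name lcm_example are hypothetical. */
static inline void lcm_example(void) {
  const int nrows = 6, ncols = 4;       // hypothetical process grid
  const int nticks = lcm(nrows, ncols); // = 12
  assert(nticks % nrows == 0 && nticks % ncols == 0);
  (void)nticks;
}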
32 :
33 : /*******************************************************************************
34 : * \brief Private routine for computing the sum of the given integers.
35 : * \author Ole Schuett
36 : ******************************************************************************/
37 882052 : static inline int isum(const int n, const int input[n]) {
38 882052 : int output = 0;
39 1868720 : for (int i = 0; i < n; i++) {
40 986668 : output += input[i];
41 : }
42 882052 : return output;
43 : }
44 :
45 : /*******************************************************************************
46 : * \brief Private routine for computing the cumulative sums of given numbers.
47 : * \author Ole Schuett
48 : ******************************************************************************/
49 2205130 : static inline void icumsum(const int n, const int input[n], int output[n]) {
50 2205130 : output[0] = 0;
51 2414362 : for (int i = 1; i < n; i++) {
52 209232 : output[i] = output[i - 1] + input[i - 1];
53 : }
54 2205130 : }
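/* A minimal illustrative sketch (not part of the original source): icumsum()
   yields exclusive prefix sums, i.e. exactly the count-to-displacement
   conversion needed by the alltoallv exchanges further below. The counts and
   the helper name icumsum_example are hypothetical. */
static inline void icumsum_example(void) {
  const int counts[4] = {3, 0, 2, 5};
  int displs[4];
  icumsum(4, counts, displs);
  assert(displs[0] == 0 && displs[1] == 3 && displs[2] == 3 && displs[3] == 5);
  assert(displs[3] + counts[3] == isum(4, counts)); // total element count
}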
55 :
56 : /*******************************************************************************
57 : * \brief Private struct used for planning during pack_matrix.
58 : * \author Ole Schuett
59 : ******************************************************************************/
60 : typedef struct {
61 : const dbm_block_t *blk; // source block
62 : int rank; // target mpi rank
63 : int row_size;
64 : int col_size;
65 : } plan_t;
66 :
67 : /*******************************************************************************
68 : * \brief Private routine for planing packs.
69 : * \author Ole Schuett
70 : ******************************************************************************/
71 423590 : static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
72 : const dbm_matrix_t *matrix,
73 : const dbm_mpi_comm_t comm,
74 : const dbm_dist_1d_t *dist_indices,
75 : const dbm_dist_1d_t *dist_ticks, const int nticks,
76 : const int npacks, plan_t *plans_per_pack[npacks],
77 : int nblks_per_pack[npacks],
78 : int ndata_per_pack[npacks]) {
79 :
80 423590 : memset(nblks_per_pack, 0, npacks * sizeof(int));
81 423590 : memset(ndata_per_pack, 0, npacks * sizeof(int));
82 :
83 423590 : #pragma omp parallel
84 : {
85 : // 1st pass: Compute number of blocks that will be sent in each pack.
86 : int nblks_mythread[npacks];
87 : memset(nblks_mythread, 0, npacks * sizeof(int));
88 : #pragma omp for schedule(static)
89 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
90 : dbm_shard_t *shard = &matrix->shards[ishard];
91 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
92 : const dbm_block_t *blk = &shard->blocks[iblock];
93 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
94 : const int itick = (1021 * sum_index) % nticks; // 1021 = a random prime
95 : const int ipack = itick / dist_ticks->nranks;
96 : nblks_mythread[ipack]++;
97 : }
98 : }
99 :
100 : // Sum nblocks across threads and allocate arrays for plans.
101 : #pragma omp critical
102 : for (int ipack = 0; ipack < npacks; ipack++) {
103 : nblks_per_pack[ipack] += nblks_mythread[ipack];
104 : nblks_mythread[ipack] = nblks_per_pack[ipack];
105 : }
106 : #pragma omp barrier
107 : #pragma omp for
108 : for (int ipack = 0; ipack < npacks; ipack++) {
109 : plans_per_pack[ipack] = malloc(nblks_per_pack[ipack] * sizeof(plan_t));
110 : assert(plans_per_pack[ipack] != NULL);
111 : }
112 :
113 : // 2nd pass: Plan where to send each block.
114 : int ndata_mythread[npacks];
115 : memset(ndata_mythread, 0, npacks * sizeof(int));
116 : #pragma omp for schedule(static) // Need static to match previous loop.
117 : for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
118 : dbm_shard_t *shard = &matrix->shards[ishard];
119 : for (int iblock = 0; iblock < shard->nblocks; iblock++) {
120 : const dbm_block_t *blk = &shard->blocks[iblock];
121 : const int free_index = (trans_matrix) ? blk->col : blk->row;
122 : const int sum_index = (trans_matrix) ? blk->row : blk->col;
123 : const int itick = (1021 * sum_index) % nticks; // Same mapping as above.
124 : const int ipack = itick / dist_ticks->nranks;
125 : // Compute rank to which this block should be sent.
126 : const int coord_free_idx = dist_indices->index2coord[free_index];
127 : const int coord_sum_idx = itick % dist_ticks->nranks;
128 : const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
129 : (trans_dist) ? coord_free_idx : coord_sum_idx};
130 : const int rank = dbm_mpi_cart_rank(comm, coords);
131 : const int row_size = matrix->row_sizes[blk->row];
132 : const int col_size = matrix->col_sizes[blk->col];
133 : ndata_mythread[ipack] += row_size * col_size;
134 : // Create plan.
135 : const int iplan = --nblks_mythread[ipack];
136 : plans_per_pack[ipack][iplan].blk = blk;
137 : plans_per_pack[ipack][iplan].rank = rank;
138 : plans_per_pack[ipack][iplan].row_size = row_size;
139 : plans_per_pack[ipack][iplan].col_size = col_size;
140 : }
141 : }
142 : #pragma omp critical
143 : for (int ipack = 0; ipack < npacks; ipack++) {
144 : ndata_per_pack[ipack] += ndata_mythread[ipack];
145 : }
146 : } // end of omp parallel region
147 423590 : }
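/* A minimal illustrative sketch (not part of the original source) of the
   mapping used above: a summation index is hashed onto a tick, and each tick
   decomposes uniquely into a send pack (itick / nranks) and a rank coordinate
   along the ticks dimension (itick % nranks). The sizes and the helper name
   tick_mapping_example are hypothetical. */
static inline void tick_mapping_example(void) {
  const int nranks = 4, npacks = 3, nticks = nranks * npacks;
  for (int sum_index = 0; sum_index < 100; sum_index++) {
    const int itick = (1021 * sum_index) % nticks; // same hash as above
    const int ipack = itick / nranks;              // which send pack
    const int coord = itick % nranks;              // which rank coordinate
    assert(0 <= ipack && ipack < npacks);
    assert(0 <= coord && coord < nranks);
    assert(ipack * nranks + coord == itick); // unique decomposition
    (void)ipack;
    (void)coord;
  }
}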
148 :
149 : /*******************************************************************************
150 : * \brief Private routine for filling send buffers.
151 : * \author Ole Schuett
152 : ******************************************************************************/
153 441026 : static void fill_send_buffers(
154 : const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
155 : const int ndata_send, plan_t plans[nblks_send], const int nranks,
156 : int blks_send_count[nranks], int data_send_count[nranks],
157 : int blks_send_displ[nranks], int data_send_displ[nranks],
158 : dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {
159 :
160 441026 : memset(blks_send_count, 0, nranks * sizeof(int));
161 441026 : memset(data_send_count, 0, nranks * sizeof(int));
162 :
163 441026 : #pragma omp parallel
164 : {
165 : // 3rd pass: Compute per-rank nblks and ndata.
166 : int nblks_mythread[nranks], ndata_mythread[nranks];
167 : memset(nblks_mythread, 0, nranks * sizeof(int));
168 : memset(ndata_mythread, 0, nranks * sizeof(int));
169 : #pragma omp for schedule(static)
170 : for (int iblock = 0; iblock < nblks_send; iblock++) {
171 : const plan_t *plan = &plans[iblock];
172 : nblks_mythread[plan->rank] += 1;
173 : ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
174 : }
175 :
176 : // Sum nblks and ndata across threads.
177 : #pragma omp critical
178 : for (int irank = 0; irank < nranks; irank++) {
179 : blks_send_count[irank] += nblks_mythread[irank];
180 : data_send_count[irank] += ndata_mythread[irank];
181 : nblks_mythread[irank] = blks_send_count[irank];
182 : ndata_mythread[irank] = data_send_count[irank];
183 : }
184 : #pragma omp barrier
185 :
186 : // Compute send displacements.
187 : #pragma omp master
188 : {
189 : icumsum(nranks, blks_send_count, blks_send_displ);
190 : icumsum(nranks, data_send_count, data_send_displ);
191 : const int m = nranks - 1;
192 : assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
193 : assert(ndata_send == data_send_displ[m] + data_send_count[m]);
194 : }
195 : #pragma omp barrier
196 :
197 : // 4th pass: Fill blks_send and data_send arrays.
198 : #pragma omp for schedule(static) // Need static to match previous loop.
199 : for (int iblock = 0; iblock < nblks_send; iblock++) {
200 : const plan_t *plan = &plans[iblock];
201 : const dbm_block_t *blk = plan->blk;
202 : const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
203 : const dbm_shard_t *shard = &matrix->shards[ishard];
204 : const double *blk_data = &shard->data[blk->offset];
205 : const int row_size = plan->row_size, col_size = plan->col_size;
206 : const int plan_size = row_size * col_size;
207 : const int irank = plan->rank;
208 :
209 : // The send buffers are ordered by rank, thread, and block.
210 : // data_send_displ[irank]: Start of data for irank within data_send.
211 : // ndata_mythread[irank]: Current thread's offset within the data for irank.
212 : nblks_mythread[irank] -= 1;
213 : ndata_mythread[irank] -= plan_size;
214 : const int offset = data_send_displ[irank] + ndata_mythread[irank];
215 : const int jblock = blks_send_displ[irank] + nblks_mythread[irank];
216 :
217 : double norm = 0.0; // Compute norm as double...
218 : if (trans_matrix) {
219 : // Transpose block to allow for outer-product style multiplication.
220 : for (int i = 0; i < row_size; i++) {
221 : for (int j = 0; j < col_size; j++) {
222 : const double element = blk_data[j * row_size + i];
223 : data_send[offset + i * col_size + j] = element;
224 : norm += element * element;
225 : }
226 : }
227 : blks_send[jblock].free_index = plan->blk->col;
228 : blks_send[jblock].sum_index = plan->blk->row;
229 : } else {
230 : for (int i = 0; i < plan_size; i++) {
231 : const double element = blk_data[i];
232 : data_send[offset + i] = element;
233 : norm += element * element;
234 : }
235 : blks_send[jblock].free_index = plan->blk->row;
236 : blks_send[jblock].sum_index = plan->blk->col;
237 : }
238 : blks_send[jblock].norm = (float)norm; // ...store norm as float.
239 :
240 : // After the block exchange data_recv_displ will be added to the offsets.
241 : blks_send[jblock].offset = offset - data_send_displ[irank];
242 : }
243 : } // end of omp parallel region
244 441026 : }
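/* A minimal serial sketch (not part of the original source) of the send
   buffer layout built above: a counting pass per destination rank, an
   icumsum() for the displacements, and a decrementing per-rank cursor that
   hands out unique slots (the OpenMP code uses per-thread cursors for the
   same purpose). The destinations and the helper name group_by_rank_example
   are hypothetical. */
static inline void group_by_rank_example(void) {
  enum { nranks = 3, n = 6 };
  const int dest[n] = {2, 0, 1, 0, 2, 0}; // destination rank of each element
  int count[nranks] = {0}, displ[nranks], cursor[nranks], buffer[n];
  for (int i = 0; i < n; i++) {
    count[dest[i]]++;
  }
  icumsum(nranks, count, displ);
  for (int irank = 0; irank < nranks; irank++) {
    cursor[irank] = count[irank];
  }
  for (int i = 0; i < n; i++) {
    const int slot = displ[dest[i]] + --cursor[dest[i]];
    buffer[slot] = i; // element i ends up in the segment of its rank
  }
  assert(displ[0] == 0 && displ[1] == 3 && displ[2] == 4);
  (void)buffer;
}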
245 :
246 : /*******************************************************************************
247 : * \brief Private comparator passed to qsort to compare two blocks by sum_index.
248 : * \author Ole Schuett
249 : ******************************************************************************/
250 78505688 : static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
251 78505688 : const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
252 78505688 : const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
253 78505688 : return blk_a->sum_index - blk_b->sum_index;
254 : }
255 :
256 : /*******************************************************************************
257 : * \brief Private routine for post-processing received blocks.
258 : * \author Ole Schuett
259 : ******************************************************************************/
260 441026 : static void postprocess_received_blocks(
261 : const int nranks, const int nshards, const int nblocks_recv,
262 : const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
263 : const int data_recv_displ[nranks],
264 441026 : dbm_pack_block_t blks_recv[nblocks_recv]) {
265 :
266 441026 : int nblocks_per_shard[nshards], shard_start[nshards];
267 441026 : memset(nblocks_per_shard, 0, nshards * sizeof(int));
268 441026 : dbm_pack_block_t *blocks_tmp =
269 441026 : malloc(nblocks_recv * sizeof(dbm_pack_block_t));
270 441026 : assert(blocks_tmp != NULL);
271 :
272 441026 : #pragma omp parallel
273 : {
274 : // Add data_recv_displ to received block offsets.
275 : for (int irank = 0; irank < nranks; irank++) {
276 : #pragma omp for
277 : for (int i = 0; i < blks_recv_count[irank]; i++) {
278 : blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
279 : }
280 : }
281 :
282 : // First use counting sort to group blocks by their free_index shard.
283 : int nblocks_mythread[nshards];
284 : memset(nblocks_mythread, 0, nshards * sizeof(int));
285 : #pragma omp for schedule(static)
286 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
287 : blocks_tmp[iblock] = blks_recv[iblock];
288 : const int ishard = blks_recv[iblock].free_index % nshards;
289 : nblocks_mythread[ishard]++;
290 : }
291 : #pragma omp critical
292 : for (int ishard = 0; ishard < nshards; ishard++) {
293 : nblocks_per_shard[ishard] += nblocks_mythread[ishard];
294 : nblocks_mythread[ishard] = nblocks_per_shard[ishard];
295 : }
296 : #pragma omp barrier
297 : #pragma omp master
298 : icumsum(nshards, nblocks_per_shard, shard_start);
299 : #pragma omp barrier
300 : #pragma omp for schedule(static) // Need static to match previous loop.
301 : for (int iblock = 0; iblock < nblocks_recv; iblock++) {
302 : const int ishard = blocks_tmp[iblock].free_index % nshards;
303 : const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
304 : blks_recv[jblock] = blocks_tmp[iblock];
305 : }
306 :
307 : // Then sort blocks within each shard by their sum_index.
308 : #pragma omp for
309 : for (int ishard = 0; ishard < nshards; ishard++) {
310 : if (nblocks_per_shard[ishard] > 1) {
311 : qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
312 : sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
313 : }
314 : }
315 : } // end of omp parallel region
316 :
317 441026 : free(blocks_tmp);
318 441026 : }
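/* A minimal illustrative sketch (not part of the original source): within one
   shard the received blocks end up sorted by sum_index via the comparator
   above, presumably so that matching summation indices of pack_a and pack_b
   can be located efficiently during the multiplication. The values and the
   helper name sort_by_sum_index_example are hypothetical. */
static inline void sort_by_sum_index_example(void) {
  dbm_pack_block_t blocks[3];
  memset(blocks, 0, sizeof(blocks));
  blocks[0].sum_index = 7;
  blocks[1].sum_index = 2;
  blocks[2].sum_index = 5;
  qsort(blocks, 3, sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
  assert(blocks[0].sum_index == 2 && blocks[1].sum_index == 5 &&
         blocks[2].sum_index == 7);
}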
319 :
320 : /*******************************************************************************
321 : * \brief Private routine for redistributing a matrix along selected dimensions.
322 : * \author Ole Schuett
323 : ******************************************************************************/
324 423590 : static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
325 : const bool trans_dist,
326 : const dbm_matrix_t *matrix,
327 : const dbm_distribution_t *dist,
328 423590 : const int nticks) {
329 :
330 423590 : assert(dbm_mpi_comms_are_similar(matrix->dist->comm, dist->comm));
331 :
332 : // The row/col indices are distributed along one cart dimension and the
333 : // ticks are distributed along the other cart dimension.
334 423590 : const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
335 423590 : const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;
336 :
337 : // Allocate packed matrix.
338 423590 : const int nsend_packs = nticks / dist_ticks->nranks;
339 423590 : assert(nsend_packs * dist_ticks->nranks == nticks);
340 423590 : dbm_packed_matrix_t packed;
341 423590 : packed.dist_indices = dist_indices;
342 423590 : packed.dist_ticks = dist_ticks;
343 423590 : packed.nsend_packs = nsend_packs;
344 423590 : packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
345 423590 : assert(packed.send_packs != NULL);
346 :
347 : // Plan all packs.
348 423590 : plan_t *plans_per_pack[nsend_packs];
349 423590 : int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
350 423590 : create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
351 : dist_ticks, nticks, nsend_packs, plans_per_pack,
352 : nblks_send_per_pack, ndata_send_per_pack);
353 :
354 : // Allocate send buffers for maximum number of blocks/data over all packs.
355 423590 : int nblks_send_max = 0, ndata_send_max = 0;
356 864616 : for (int ipack = 0; ipack < nsend_packs; ++ipack) {
357 441026 : nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
358 441026 : ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
359 : }
360 423590 : dbm_pack_block_t *blks_send =
361 423590 : dbm_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
362 423590 : double *data_send = dbm_mpi_alloc_mem(ndata_send_max * sizeof(double));
363 :
364 : // Cannot parallelize over packs (there might be too few of them).
365 864616 : for (int ipack = 0; ipack < nsend_packs; ipack++) {
366 : // Fill send buffers according to plans.
367 441026 : const int nranks = dist->nranks;
368 441026 : int blks_send_count[nranks], data_send_count[nranks];
369 441026 : int blks_send_displ[nranks], data_send_displ[nranks];
370 441026 : fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
371 : ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
372 : blks_send_count, data_send_count, blks_send_displ,
373 : data_send_displ, blks_send, data_send);
374 441026 : free(plans_per_pack[ipack]);
375 :
376 : // 1st communication: Exchange block counts.
377 441026 : int blks_recv_count[nranks], blks_recv_displ[nranks];
378 441026 : dbm_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
379 441026 : icumsum(nranks, blks_recv_count, blks_recv_displ);
380 441026 : const int nblocks_recv = isum(nranks, blks_recv_count);
381 :
382 : // 2nd communication: Exchange blocks.
383 441026 : dbm_pack_block_t *blks_recv =
384 441026 : dbm_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
385 441026 : int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
386 441026 : int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
387 934360 : for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
388 493334 : blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
389 493334 : blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
390 493334 : blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
391 493334 : blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
392 : }
393 441026 : dbm_mpi_alltoallv_byte(
394 : blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
395 441026 : blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
396 :
397 : // 3rd communication: Exchange data counts.
398 : // TODO: could be computed from blks_recv.
399 441026 : int data_recv_count[nranks], data_recv_displ[nranks];
400 441026 : dbm_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
401 441026 : icumsum(nranks, data_recv_count, data_recv_displ);
402 441026 : const int ndata_recv = isum(nranks, data_recv_count);
403 :
404 : // 4th communication: Exchange data.
405 441026 : double *data_recv = dbm_mpi_alloc_mem(ndata_recv * sizeof(double));
406 441026 : dbm_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
407 : data_recv, data_recv_count, data_recv_displ,
408 441026 : dist->comm);
409 :
410 : // Post-process received blocks and assemble them into a pack.
411 441026 : postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
412 : blks_recv_count, blks_recv_displ,
413 : data_recv_displ, blks_recv);
414 441026 : packed.send_packs[ipack].nblocks = nblocks_recv;
415 441026 : packed.send_packs[ipack].data_size = ndata_recv;
416 441026 : packed.send_packs[ipack].blocks = blks_recv;
417 441026 : packed.send_packs[ipack].data = data_recv;
418 : }
419 :
420 : // Deallocate send buffers.
421 423590 : dbm_mpi_free_mem(blks_send);
422 423590 : dbm_mpi_free_mem(data_send);
423 :
424 : // Allocate pack_recv.
425 423590 : int max_nblocks = 0, max_data_size = 0;
426 864616 : for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
427 441026 : max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
428 441026 : max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
429 : }
430 423590 : dbm_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
431 423590 : dbm_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
432 423590 : packed.max_nblocks = max_nblocks;
433 423590 : packed.max_data_size = max_data_size;
434 847180 : packed.recv_pack.blocks =
435 423590 : dbm_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
436 847180 : packed.recv_pack.data =
437 423590 : dbm_mpi_alloc_mem(packed.max_data_size * sizeof(double));
438 :
439 423590 : return packed; // Ownership of packed transfers to caller.
440 : }
441 :
442 : /*******************************************************************************
443 : * \brief Private routine for sending and receiving the pack for the given tick.
444 : * \author Ole Schuett
445 : ******************************************************************************/
446 458462 : static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
447 : dbm_packed_matrix_t *packed) {
448 458462 : const int nranks = packed->dist_ticks->nranks;
449 458462 : const int my_rank = packed->dist_ticks->my_rank;
450 :
451 : // Compute send rank and pack.
452 458462 : const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
453 458462 : const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
454 458462 : const int send_itick = (itick_of_rank0 + send_rank) % nticks;
455 458462 : const int send_ipack = send_itick / nranks;
456 458462 : assert(send_itick % nranks == my_rank);
457 :
458 : // Compute receive rank and pack.
459 458462 : const int recv_rank = itick % nranks;
460 458462 : const int recv_ipack = itick / nranks;
461 :
462 458462 : if (send_rank == my_rank) {
463 441026 : assert(send_rank == recv_rank && send_ipack == recv_ipack);
464 441026 : return &packed->send_packs[send_ipack]; // Local pack, no mpi needed.
465 : } else {
466 17436 : const dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
467 :
468 : // Exchange blocks.
469 34872 : const int nblocks_in_bytes = dbm_mpi_sendrecv_byte(
470 17436 : /*sendbuf=*/send_pack->blocks,
471 17436 : /*sendcount=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
472 : /*dest=*/send_rank,
473 : /*sendtag=*/send_ipack,
474 17436 : /*recvbuf=*/packed->recv_pack.blocks,
475 17436 : /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
476 : /*source=*/recv_rank,
477 : /*recvtag=*/recv_ipack,
478 17436 : /*comm=*/packed->dist_ticks->comm);
479 :
480 17436 : assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
481 17436 : packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);
482 :
483 : // Exchange data.
484 34872 : packed->recv_pack.data_size = dbm_mpi_sendrecv_double(
485 17436 : /*sendbuf=*/send_pack->data,
486 17436 : /*sendcount=*/send_pack->data_size,
487 : /*dest=*/send_rank,
488 : /*sendtag=*/send_ipack,
489 : /*recvbuf=*/packed->recv_pack.data,
490 : /*recvcount=*/packed->max_data_size,
491 : /*source=*/recv_rank,
492 : /*recvtag=*/recv_ipack,
493 17436 : /*comm=*/packed->dist_ticks->comm);
494 :
495 17436 : return &packed->recv_pack;
496 : }
497 : }
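/* A minimal standalone sketch (not part of the original source) of the tick
   arithmetic above, checked for one hypothetical row of the process grid:
   "base" stands for the part of the shifted tick that is common to all ranks
   of that row. For every tick, the send_rank/send_itick formulas pick exactly
   the pack that the partner rank requests as its receive. The helper name
   sendrecv_pairing_example is hypothetical. */
static inline void sendrecv_pairing_example(void) {
  const int nranks = 3, nticks = 6; // nticks is a multiple of nranks
  for (int base = 0; base < nticks; base++) {
    for (int my_rank = 0; my_rank < nranks; my_rank++) {
      const int itick = (base + my_rank) % nticks;
      const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
      const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
      const int send_itick = (itick_of_rank0 + send_rank) % nticks;
      const int partner_itick = (base + send_rank) % nticks;
      assert(send_itick % nranks == my_rank); // partner receives from me ...
      assert(send_itick == partner_itick);    // ... at its own shifted tick,
      // hence send_itick / nranks is the pack index the partner asks for.
      (void)send_itick;
      (void)partner_itick;
    }
  }
}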
498 :
499 : /*******************************************************************************
500 : * \brief Private routine for releasing a packed matrix.
501 : * \author Ole Schuett
502 : ******************************************************************************/
503 423590 : static void free_packed_matrix(dbm_packed_matrix_t *packed) {
504 423590 : dbm_mpi_free_mem(packed->recv_pack.blocks);
505 423590 : dbm_mpi_free_mem(packed->recv_pack.data);
506 864616 : for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
507 441026 : dbm_mpi_free_mem(packed->send_packs[ipack].blocks);
508 441026 : dbm_mpi_free_mem(packed->send_packs[ipack].data);
509 : }
510 423590 : free(packed->send_packs);
511 423590 : }
512 :
513 : /*******************************************************************************
514 : * \brief Internal routine for creating a communication iterator.
515 : * \author Ole Schuett
516 : ******************************************************************************/
517 211795 : dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
518 : const bool transb,
519 : const dbm_matrix_t *matrix_a,
520 : const dbm_matrix_t *matrix_b,
521 : const dbm_matrix_t *matrix_c) {
522 :
523 211795 : dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
524 211795 : assert(iter != NULL);
525 211795 : iter->dist = matrix_c->dist;
526 :
527 : // During each communication tick we'll fetch a pack_a and pack_b.
528 : // Since the cart might be non-square, the number of communication ticks is
529 : // chosen as the least common multiple of the cart's dimensions.
530 211795 : iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
531 211795 : iter->itick = 0;
532 :
533 : // 1st arg = source dimension, 2nd arg = target dimension; false = rows, true = columns.
534 211795 : iter->packed_a =
535 211795 : pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
536 211795 : iter->packed_b =
537 211795 : pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);
538 :
539 211795 : return iter;
540 : }
541 :
542 : /*******************************************************************************
543 : * \brief Internal routine for retrieving the next pair of packs from the given iterator.
544 : * \author Ole Schuett
545 : ******************************************************************************/
546 441026 : bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
547 : dbm_pack_t **pack_b) {
548 441026 : if (iter->itick >= iter->nticks) {
549 : return false; // end of iterator reached
550 : }
551 :
552 : // Start each rank at a different tick to spread the load on the sources.
553 229231 : const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
554 229231 : const int shifted_itick = (iter->itick + shift) % iter->nticks;
555 229231 : *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
556 229231 : *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);
557 :
558 229231 : iter->itick++;
559 229231 : return true;
560 : }
561 :
562 : /*******************************************************************************
563 : * \brief Internal routine for releasing the given communication iterator.
564 : * \author Ole Schuett
565 : ******************************************************************************/
566 211795 : void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
567 211795 : free_packed_matrix(&iter->packed_a);
568 211795 : free_packed_matrix(&iter->packed_b);
569 211795 : free(iter);
570 211795 : }
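/* A minimal usage sketch (not part of the original source), assuming the
   matrices a, b, and c were already created with compatible distributions:
   this mirrors how a multiplication driver is expected to consume the
   iterator, one pair of packs per communication tick. The helper name
   comm_iterator_usage_example is hypothetical. */
static inline void comm_iterator_usage_example(const bool transa,
                                               const bool transb,
                                               const dbm_matrix_t *a,
                                               const dbm_matrix_t *b,
                                               const dbm_matrix_t *c) {
  dbm_comm_iterator_t *iter = dbm_comm_iterator_start(transa, transb, a, b, c);
  dbm_pack_t *pack_a = NULL, *pack_b = NULL;
  while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
    // Multiply the local contributions of pack_a and pack_b here; each pack
    // exposes its block metadata and block data via pack->blocks and pack->data.
    (void)pack_a;
    (void)pack_b;
  }
  dbm_comm_iterator_stop(iter);
}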
571 :
572 : // EOF