Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: BSD-3-Clause */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #include <assert.h>
9 : #include <stdio.h>
10 : #include <stdlib.h>
11 : #include <string.h>
12 :
13 : #include "dbm_mpi.h"
14 :
15 : #if defined(__parallel)
16 : /*******************************************************************************
17 : * \brief Check given MPI status and upon failure abort with a nice message.
18 : * \author Ole Schuett
19 : ******************************************************************************/
20 : #define CHECK(status) \
21 : if (status != MPI_SUCCESS) { \
22 : fprintf(stderr, "MPI error in %s:%i\n", __FILE__, __LINE__); \
23 : MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); \
24 : }
25 : #endif
26 :
27 : /*******************************************************************************
28 : * \brief Wrapper around MPI_Init.
29 : * \author Ole Schuett
30 : ******************************************************************************/
31 0 : void dbm_mpi_init(int *argc, char ***argv) {
32 : #if defined(__parallel)
33 0 : CHECK(MPI_Init(argc, argv));
34 : #else
35 : (void)argc; // mark used
36 : (void)argv;
37 : #endif
38 0 : }
39 :
40 : /*******************************************************************************
41 : * \brief Wrapper around MPI_Finalize.
42 : * \author Ole Schuett
43 : ******************************************************************************/
44 0 : void dbm_mpi_finalize() {
45 : #if defined(__parallel)
46 0 : CHECK(MPI_Finalize());
47 : #endif
48 0 : }
49 :
50 : /*******************************************************************************
51 : * \brief Returns MPI_COMM_WORLD.
52 : * \author Ole Schuett
53 : ******************************************************************************/
54 0 : dbm_mpi_comm_t dbm_mpi_get_comm_world() {
55 : #if defined(__parallel)
56 0 : return MPI_COMM_WORLD;
57 : #else
58 : return -1;
59 : #endif
60 : }
61 :
62 : /*******************************************************************************
63 : * \brief Wrapper around MPI_Comm_f2c.
64 : * \author Ole Schuett
65 : ******************************************************************************/
66 825510 : dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm) {
67 : #if defined(__parallel)
68 825510 : return MPI_Comm_f2c(fortran_comm);
69 : #else
70 : (void)fortran_comm; // mark used
71 : return -1;
72 : #endif
73 : }
74 :
75 : /*******************************************************************************
76 : * \brief Wrapper around MPI_Comm_c2f.
77 : * \author Ole Schuett
78 : ******************************************************************************/
79 0 : int dbm_mpi_comm_c2f(const dbm_mpi_comm_t comm) {
80 : #if defined(__parallel)
81 0 : return MPI_Comm_c2f(comm);
82 : #else
83 : (void)comm; // mark used
84 : return -1;
85 : #endif
86 : }
87 :
88 : /*******************************************************************************
89 : * \brief Wrapper around MPI_Comm_rank.
90 : * \author Ole Schuett
91 : ******************************************************************************/
92 2448726 : int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm) {
93 : #if defined(__parallel)
94 2448726 : int rank;
95 2448726 : CHECK(MPI_Comm_rank(comm, &rank));
96 2448726 : return rank;
97 : #else
98 : (void)comm; // mark used
99 : return 0;
100 : #endif
101 : }
102 :
103 : /*******************************************************************************
104 : * \brief Wrapper around MPI_Comm_size.
105 : * \author Ole Schuett
106 : ******************************************************************************/
107 2448870 : int dbm_mpi_comm_size(const dbm_mpi_comm_t comm) {
108 : #if defined(__parallel)
109 2448870 : int nranks;
110 2448870 : CHECK(MPI_Comm_size(comm, &nranks));
111 2448870 : return nranks;
112 : #else
113 : (void)comm; // mark used
114 : return 1;
115 : #endif
116 : }
117 :
118 : /*******************************************************************************
119 : * \brief Wrapper around MPI_Dims_create.
120 : * \author Ole Schuett
121 : ******************************************************************************/
122 0 : void dbm_mpi_dims_create(const int nnodes, const int ndims, int dims[]) {
123 : #if defined(__parallel)
124 0 : CHECK(MPI_Dims_create(nnodes, ndims, dims));
125 : #else
126 : dims[0] = nnodes;
127 : for (int i = 1; i < ndims; i++) {
128 : dims[i] = 1;
129 : }
130 : #endif
131 0 : }
132 :
133 : /*******************************************************************************
134 : * \brief Wrapper around MPI_Cart_create.
135 : * \author Ole Schuett
136 : ******************************************************************************/
137 0 : dbm_mpi_comm_t dbm_mpi_cart_create(const dbm_mpi_comm_t comm_old,
138 : const int ndims, const int dims[],
139 : const int periods[], const int reorder) {
140 : #if defined(__parallel)
141 0 : dbm_mpi_comm_t comm_cart;
142 0 : CHECK(MPI_Cart_create(comm_old, ndims, dims, periods, reorder, &comm_cart));
143 0 : return comm_cart;
144 : #else
145 : (void)comm_old; // mark used
146 : (void)ndims;
147 : (void)dims;
148 : (void)periods;
149 : (void)reorder;
150 : return -1;
151 : #endif
152 : }
153 :
154 : /*******************************************************************************
155 : * \brief Wrapper around MPI_Cart_get.
156 : * \author Ole Schuett
157 : ******************************************************************************/
158 1632484 : void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[],
159 : int periods[], int coords[]) {
160 : #if defined(__parallel)
161 1632484 : CHECK(MPI_Cart_get(comm, maxdims, dims, periods, coords));
162 : #else
163 : (void)comm; // mark used
164 : for (int i = 0; i < maxdims; i++) {
165 : dims[i] = 1;
166 : periods[i] = 1;
167 : coords[i] = 0;
168 : }
169 : #endif
170 1632484 : }
171 :
172 : /*******************************************************************************
173 : * \brief Wrapper around MPI_Cart_rank.
174 : * \author Ole Schuett
175 : ******************************************************************************/
176 105514749 : int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[]) {
177 : #if defined(__parallel)
178 105514749 : int rank;
179 105514749 : CHECK(MPI_Cart_rank(comm, coords, &rank));
180 105514749 : return rank;
181 : #else
182 : (void)comm; // mark used
183 : (void)coords;
184 : return 0;
185 : #endif
186 : }
187 :
188 : /*******************************************************************************
189 : * \brief Wrapper around MPI_Cart_sub.
190 : * \author Ole Schuett
191 : ******************************************************************************/
192 1632484 : dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm,
193 : const int remain_dims[]) {
194 : #if defined(__parallel)
195 1632484 : dbm_mpi_comm_t newcomm;
196 1632484 : CHECK(MPI_Cart_sub(comm, remain_dims, &newcomm));
197 1632484 : return newcomm;
198 : #else
199 : (void)comm; // mark used
200 : (void)remain_dims;
201 : return -1;
202 : #endif
203 : }
204 :
205 : /*******************************************************************************
206 : * \brief Wrapper around MPI_Comm_free.
207 : * \author Ole Schuett
208 : ******************************************************************************/
209 1632484 : void dbm_mpi_comm_free(dbm_mpi_comm_t *comm) {
210 : #if defined(__parallel)
211 1632484 : CHECK(MPI_Comm_free(comm));
212 : #else
213 : (void)comm; // mark used
214 : #endif
215 1632484 : }
216 :
217 : /*******************************************************************************
218 : * \brief Wrapper around MPI_Comm_compare.
219 : * \author Ole Schuett
220 : ******************************************************************************/
221 423734 : bool dbm_mpi_comms_are_similar(const dbm_mpi_comm_t comm1,
222 : const dbm_mpi_comm_t comm2) {
223 : #if defined(__parallel)
224 423734 : int res;
225 423734 : CHECK(MPI_Comm_compare(comm1, comm2, &res));
226 423734 : return res == MPI_IDENT || res == MPI_CONGRUENT || res == MPI_SIMILAR;
227 : #else
228 : (void)comm1; // mark used
229 : (void)comm2;
230 : return true;
231 : #endif
232 : }
233 :
234 : /*******************************************************************************
235 : * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_INT.
236 : * \author Ole Schuett
237 : ******************************************************************************/
238 847180 : void dbm_mpi_max_int(int *values, const int count, const dbm_mpi_comm_t comm) {
239 : #if defined(__parallel)
240 847180 : int value = 0;
241 847180 : void *recvbuf = (1 < count ? dbm_mpi_alloc_mem(count * sizeof(int)) : &value);
242 847180 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_MAX, comm));
243 847180 : memcpy(values, recvbuf, count * sizeof(int));
244 847180 : if (1 < count) {
245 0 : dbm_mpi_free_mem(recvbuf);
246 : }
247 : #else
248 : (void)comm; // mark used
249 : (void)values;
250 : (void)count;
251 : #endif
252 847180 : }
253 :
254 : /*******************************************************************************
255 : * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_UINT64_T.
256 : * \author Ole Schuett
257 : ******************************************************************************/
258 18536 : void dbm_mpi_max_uint64(uint64_t *values, const int count,
259 : const dbm_mpi_comm_t comm) {
260 : #if defined(__parallel)
261 18536 : uint64_t value = 0;
262 37072 : void *recvbuf =
263 18536 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(uint64_t)) : &value);
264 18536 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_UINT64_T, MPI_MAX, comm));
265 18536 : memcpy(values, recvbuf, count * sizeof(uint64_t));
266 18536 : if (1 < count) {
267 0 : dbm_mpi_free_mem(recvbuf);
268 : }
269 : #else
270 : (void)comm; // mark used
271 : (void)values;
272 : (void)count;
273 : #endif
274 18536 : }
275 :
276 : /*******************************************************************************
277 : * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_DOUBLE.
278 : * \author Ole Schuett
279 : ******************************************************************************/
280 48 : void dbm_mpi_max_double(double *values, const int count,
281 : const dbm_mpi_comm_t comm) {
282 : #if defined(__parallel)
283 48 : double value = 0;
284 96 : void *recvbuf =
285 48 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(double)) : &value);
286 48 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_MAX, comm));
287 48 : memcpy(values, recvbuf, count * sizeof(double));
288 48 : if (1 < count) {
289 0 : dbm_mpi_free_mem(recvbuf);
290 : }
291 : #else
292 : (void)comm; // mark used
293 : (void)values;
294 : (void)count;
295 : #endif
296 48 : }
297 :
298 : /*******************************************************************************
299 : * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT.
300 : * \author Ole Schuett
301 : ******************************************************************************/
302 211795 : void dbm_mpi_sum_int(int *values, const int count, const dbm_mpi_comm_t comm) {
303 : #if defined(__parallel)
304 211795 : int value = 0;
305 211795 : void *recvbuf = (1 < count ? dbm_mpi_alloc_mem(count * sizeof(int)) : &value);
306 211795 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_SUM, comm));
307 211795 : memcpy(values, recvbuf, count * sizeof(int));
308 211795 : if (1 < count) {
309 209856 : dbm_mpi_free_mem(recvbuf);
310 : }
311 : #else
312 : (void)comm; // mark used
313 : (void)values;
314 : (void)count;
315 : #endif
316 211795 : }
317 :
318 : /*******************************************************************************
319 : * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT64_T.
320 : * \author Ole Schuett
321 : ******************************************************************************/
322 804947 : void dbm_mpi_sum_int64(int64_t *values, const int count,
323 : const dbm_mpi_comm_t comm) {
324 : #if defined(__parallel)
325 804947 : int64_t value = 0;
326 1609894 : void *recvbuf =
327 804947 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(int64_t)) : &value);
328 804947 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT64_T, MPI_SUM, comm));
329 804947 : memcpy(values, recvbuf, count * sizeof(int64_t));
330 804947 : if (1 < count) {
331 0 : dbm_mpi_free_mem(recvbuf);
332 : }
333 : #else
334 : (void)comm; // mark used
335 : (void)values;
336 : (void)count;
337 : #endif
338 804947 : }
339 :
340 : /*******************************************************************************
341 : * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_DOUBLE.
342 : * \author Ole Schuett
343 : ******************************************************************************/
344 190 : void dbm_mpi_sum_double(double *values, const int count,
345 : const dbm_mpi_comm_t comm) {
346 : #if defined(__parallel)
347 190 : double value = 0;
348 380 : void *recvbuf =
349 190 : (1 < count ? dbm_mpi_alloc_mem(count * sizeof(double)) : &value);
350 190 : CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_SUM, comm));
351 190 : memcpy(values, recvbuf, count * sizeof(double));
352 190 : if (1 < count) {
353 0 : dbm_mpi_free_mem(recvbuf);
354 : }
355 : #else
356 : (void)comm; // mark used
357 : (void)values;
358 : (void)count;
359 : #endif
360 190 : }
361 :
362 : /*******************************************************************************
363 : * \brief Wrapper around MPI_Sendrecv for datatype MPI_BYTE.
364 : * \author Ole Schuett
365 : ******************************************************************************/
366 17436 : int dbm_mpi_sendrecv_byte(const void *sendbuf, const int sendcount,
367 : const int dest, const int sendtag, void *recvbuf,
368 : const int recvcount, const int source,
369 : const int recvtag, const dbm_mpi_comm_t comm) {
370 : #if defined(__parallel)
371 17436 : MPI_Status status;
372 17436 : CHECK(MPI_Sendrecv(sendbuf, sendcount, MPI_BYTE, dest, sendtag, recvbuf,
373 : recvcount, MPI_BYTE, source, recvtag, comm, &status))
374 17436 : int count_received;
375 17436 : CHECK(MPI_Get_count(&status, MPI_BYTE, &count_received));
376 17436 : return count_received;
377 : #else
378 : (void)sendbuf; // mark used
379 : (void)sendcount;
380 : (void)dest;
381 : (void)sendtag;
382 : (void)recvbuf;
383 : (void)recvcount;
384 : (void)source;
385 : (void)recvtag;
386 : (void)comm;
387 : fprintf(stderr, "Error: dbm_mpi_sendrecv_byte not available without MPI\n");
388 : abort();
389 : #endif
390 : }
391 :
392 : /*******************************************************************************
393 : * \brief Wrapper around MPI_Sendrecv for datatype MPI_DOUBLE.
394 : * \author Ole Schuett
395 : ******************************************************************************/
396 17436 : int dbm_mpi_sendrecv_double(const double *sendbuf, const int sendcount,
397 : const int dest, const int sendtag, double *recvbuf,
398 : const int recvcount, const int source,
399 : const int recvtag, const dbm_mpi_comm_t comm) {
400 : #if defined(__parallel)
401 17436 : MPI_Status status;
402 17436 : CHECK(MPI_Sendrecv(sendbuf, sendcount, MPI_DOUBLE, dest, sendtag, recvbuf,
403 : recvcount, MPI_DOUBLE, source, recvtag, comm, &status))
404 17436 : int count_received;
405 17436 : CHECK(MPI_Get_count(&status, MPI_DOUBLE, &count_received));
406 17436 : return count_received;
407 : #else
408 : (void)sendbuf; // mark used
409 : (void)sendcount;
410 : (void)dest;
411 : (void)sendtag;
412 : (void)recvbuf;
413 : (void)recvcount;
414 : (void)source;
415 : (void)recvtag;
416 : (void)comm;
417 : fprintf(stderr, "Error: dbm_mpi_sendrecv_double not available without MPI\n");
418 : abort();
419 : #endif
420 : }
421 :
422 : /*******************************************************************************
423 : * \brief Wrapper around MPI_Alltoall for datatype MPI_INT.
424 : * \author Ole Schuett
425 : ******************************************************************************/
426 882196 : void dbm_mpi_alltoall_int(const int *sendbuf, const int sendcount, int *recvbuf,
427 : const int recvcount, const dbm_mpi_comm_t comm) {
428 : #if defined(__parallel)
429 882196 : CHECK(MPI_Alltoall(sendbuf, sendcount, MPI_INT, recvbuf, recvcount, MPI_INT,
430 882196 : comm));
431 : #else
432 : (void)comm; // mark used
433 : assert(sendcount == recvcount);
434 : memcpy(recvbuf, sendbuf, sendcount * sizeof(int));
435 : #endif
436 882196 : }
437 :
438 : /*******************************************************************************
439 : * \brief Wrapper around MPI_Alltoallv for datatype MPI_BYTE.
440 : * \author Ole Schuett
441 : ******************************************************************************/
442 441026 : void dbm_mpi_alltoallv_byte(const void *sendbuf, const int *sendcounts,
443 : const int *sdispls, void *recvbuf,
444 : const int *recvcounts, const int *rdispls,
445 : const dbm_mpi_comm_t comm) {
446 : #if defined(__parallel)
447 441026 : CHECK(MPI_Alltoallv(sendbuf, sendcounts, sdispls, MPI_BYTE, recvbuf,
448 441026 : recvcounts, rdispls, MPI_BYTE, comm));
449 : #else
450 : (void)comm; // mark used
451 : assert(sendcounts[0] == recvcounts[0]);
452 : assert(sdispls[0] == 0 && rdispls[0] == 0);
453 : memcpy(recvbuf, sendbuf, sendcounts[0]);
454 : #endif
455 441026 : }
456 :
457 : /*******************************************************************************
458 : * \brief Wrapper around MPI_Alltoallv for datatype MPI_DOUBLE.
459 : * \author Ole Schuett
460 : ******************************************************************************/
461 441170 : void dbm_mpi_alltoallv_double(const double *sendbuf, const int *sendcounts,
462 : const int *sdispls, double *recvbuf,
463 : const int *recvcounts, const int *rdispls,
464 : const dbm_mpi_comm_t comm) {
465 : #if defined(__parallel)
466 441170 : CHECK(MPI_Alltoallv(sendbuf, sendcounts, sdispls, MPI_DOUBLE, recvbuf,
467 441170 : recvcounts, rdispls, MPI_DOUBLE, comm));
468 : #else
469 : (void)comm; // mark used
470 : assert(sendcounts[0] == recvcounts[0]);
471 : assert(sdispls[0] == 0 && rdispls[0] == 0);
472 : memcpy(recvbuf, sendbuf, sendcounts[0] * sizeof(double));
473 : #endif
474 441170 : }
475 :
476 : /*******************************************************************************
477 : * \brief Wrapper around MPI_Alloc_mem.
478 : * \author Hans Pabst
479 : ******************************************************************************/
480 2786556 : void *dbm_mpi_alloc_mem(size_t size) {
481 2786556 : void *result = NULL;
482 : #if defined(__parallel)
483 2786556 : CHECK(MPI_Alloc_mem((MPI_Aint)size, MPI_INFO_NULL, &result));
484 : #else
485 : result = malloc(size);
486 : #endif
487 2786556 : return result;
488 : }
489 :
490 : /*******************************************************************************
491 : * \brief Wrapper around MPI_Free_mem.
492 : * \author Hans Pabst
493 : ******************************************************************************/
494 2786556 : void dbm_mpi_free_mem(void *mem) {
495 : #if defined(__parallel)
496 2786556 : CHECK(MPI_Free_mem(mem));
497 : #else
498 : free(mem);
499 : #endif
500 2786556 : }
501 :
502 : // EOF
|