Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief communication routines to reshape / replicate / merge tall-and-skinny matrices.
10 : !> \author Patrick Seewald
11 : ! **************************************************************************************************
12 : MODULE dbt_tas_reshape_ops
   USE OMP_LIB,                         ONLY: omp_destroy_lock,&
                                              omp_get_num_threads,&
                                              omp_get_thread_num,&
                                              omp_init_lock,&
                                              omp_lock_kind,&
                                              omp_set_lock,&
                                              omp_unset_lock
19 : USE dbm_api, ONLY: &
20 : dbm_clear, dbm_distribution_col_dist, dbm_distribution_obj, dbm_distribution_row_dist, &
21 : dbm_finalize, dbm_get_col_block_sizes, dbm_get_distribution, dbm_get_name, &
22 : dbm_get_row_block_sizes, dbm_get_stored_coordinates, dbm_iterator, &
23 : dbm_iterator_blocks_left, dbm_iterator_next_block, dbm_iterator_start, dbm_iterator_stop, &
24 : dbm_put_block, dbm_reserve_blocks, dbm_type
25 : USE dbt_tas_base, ONLY: &
26 : dbt_repl_get_stored_coordinates, dbt_tas_blk_sizes, dbt_tas_clear, dbt_tas_create, &
27 : dbt_tas_distribution_new, dbt_tas_finalize, dbt_tas_get_stored_coordinates, dbt_tas_info, &
28 : dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
29 : dbt_tas_iterator_stop, dbt_tas_put_block, dbt_tas_reserve_blocks
30 : USE dbt_tas_global, ONLY: dbt_tas_blk_size_arb,&
31 : dbt_tas_blk_size_repl,&
32 : dbt_tas_dist_arb,&
33 : dbt_tas_dist_repl,&
34 : dbt_tas_distribution,&
35 : dbt_tas_rowcol_data
36 : USE dbt_tas_split, ONLY: colsplit,&
37 : dbt_tas_get_split_info,&
38 : rowsplit
39 : USE dbt_tas_types, ONLY: dbt_tas_distribution_type,&
40 : dbt_tas_iterator,&
41 : dbt_tas_split_info,&
42 : dbt_tas_type
43 : USE dbt_tas_util, ONLY: swap
44 : USE kinds, ONLY: dp,&
45 : int_8
46 : USE message_passing, ONLY: mp_cart_type,&
47 : mp_comm_type,&
48 : mp_request_type,&
49 : mp_waitall
50 : #include "../../base/base_uses.f90"
51 :
52 : IMPLICIT NONE
53 : PRIVATE
54 :
55 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_reshape_ops'
56 :
57 : PUBLIC :: &
58 : dbt_tas_merge, &
59 : dbt_tas_replicate, &
60 : dbt_tas_reshape
61 :
62 : TYPE dbt_buffer_type
63 : INTEGER :: nblock = -1
64 : INTEGER(KIND=int_8), DIMENSION(:, :), ALLOCATABLE :: indx
65 : REAL(dp), DIMENSION(:), ALLOCATABLE :: msg
66 : INTEGER :: endpos = -1
67 : END TYPE
68 :
69 : CONTAINS
70 :
! **************************************************************************************************
!> \brief copy data (involves reshape)
!> \details Each block of matrix_in is sent to the process that stores the corresponding block of
!>          matrix_out (the index-swapped block if transposed is .TRUE.). The number of blocks and
!>          entries per destination is counted first and exchanged via alltoall; then the blocks
!>          are packed into per-destination buffers, communicated, reserved in matrix_out and
!>          written into it.
!> \param matrix_in source matrix
!> \param matrix_out target matrix; must be valid. It is cleared first unless summation is .TRUE.
!> \param summation whether matrix_out = matrix_out + matrix_in
!> \param transposed whether blocks are transposed (index and data) while copying (default .FALSE.)
!> \param move_data memory optimization: move data to matrix_out such that matrix_in is empty on return
!> \author Patrick Seewald
! **************************************************************************************************
   RECURSIVE SUBROUTINE dbt_tas_reshape(matrix_in, matrix_out, summation, transposed, move_data)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in, matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: summation, transposed, move_data

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_tas_reshape'

      INTEGER                                            :: a, b, bcount, handle, handle2, iproc, &
                                                            nblk, nblk_per_thread, ndata, numnodes
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: blks_to_allocate, index_recv
      INTEGER(KIND=int_8), DIMENSION(2)                  :: blk_index
      INTEGER(kind=omp_lock_kind), ALLOCATABLE, &
         DIMENSION(:)                                    :: locks
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: num_blocks_recv, num_blocks_send, &
                                                            num_entries_recv, num_entries_send, &
                                                            num_rec, num_send
      INTEGER, DIMENSION(2)                              :: blk_size
      LOGICAL                                            :: move_prv, tr_in
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: block
      TYPE(dbt_buffer_type), ALLOCATABLE, DIMENSION(:)   :: buffer_recv, buffer_send
      TYPE(dbt_tas_iterator)                             :: iter
      TYPE(dbt_tas_split_info)                           :: info
      TYPE(mp_comm_type)                                 :: mp_comm
      TYPE(mp_request_type), ALLOCATABLE, &
         DIMENSION(:, :)                                 :: req_array

      CALL timeset(routineN, handle)

      ! Validate the target before modifying it.
      IF (.NOT. matrix_out%valid) THEN
         CPABORT("can not reshape into invalid matrix")
      END IF

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbm_clear(matrix_out%matrix)
      ELSE
         CALL dbm_clear(matrix_out%matrix)
      END IF

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      END IF

      IF (PRESENT(transposed)) THEN
         tr_in = transposed
      ELSE
         tr_in = .FALSE.
      END IF

      info = dbt_tas_info(matrix_in)
      mp_comm = info%mp_comm
      numnodes = mp_comm%num_pe
      ALLOCATE (buffer_send(0:numnodes - 1))
      ALLOCATE (buffer_recv(0:numnodes - 1))
      ALLOCATE (num_blocks_recv(0:numnodes - 1))
      ALLOCATE (num_blocks_send(0:numnodes - 1))
      ALLOCATE (num_entries_recv(0:numnodes - 1))
      ALLOCATE (num_entries_send(0:numnodes - 1))
      ! num_send / num_rec interleave (number of entries, number of blocks) per process
      ALLOCATE (num_rec(0:2*numnodes - 1))
      ALLOCATE (num_send(0:2*numnodes - 1))
      num_send(:) = 0
      ALLOCATE (req_array(1:numnodes, 4))
      ! one lock per destination buffer, guarding concurrent packing into buffer_send
      ALLOCATE (locks(0:numnodes - 1))
      DO iproc = 0, numnodes - 1
         CALL omp_init_lock(locks(iproc))
      END DO

      ! Phase 1: count entries and blocks per destination process.
      CALL timeset(routineN//"_get_coord", handle2)
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out,tr_in,num_send) &
!$OMP PRIVATE(iter,blk_index,blk_size,iproc)
      CALL dbt_tas_iterator_start(iter, matrix_in)
      DO WHILE (dbt_tas_iterator_blocks_left(iter))
         CALL dbt_tas_iterator_next_block(iter, blk_index(1), blk_index(2), &
                                          row_size=blk_size(1), col_size=blk_size(2))
         IF (tr_in) THEN
            CALL dbt_tas_get_stored_coordinates(matrix_out, blk_index(2), blk_index(1), iproc)
         ELSE
            CALL dbt_tas_get_stored_coordinates(matrix_out, blk_index(1), blk_index(2), iproc)
         END IF
!$OMP ATOMIC
         num_send(2*iproc) = num_send(2*iproc) + PRODUCT(blk_size)
!$OMP ATOMIC
         num_send(2*iproc + 1) = num_send(2*iproc + 1) + 1
      END DO
      CALL dbt_tas_iterator_stop(iter)
!$OMP END PARALLEL
      CALL timestop(handle2)

      CALL timeset(routineN//"_alltoall", handle2)
      CALL mp_comm%alltoall(num_send, num_rec, 2)
      CALL timestop(handle2)

      ! Phase 2: size the buffers and pack the blocks.
      CALL timeset(routineN//"_buffer_fill", handle2)
      DO iproc = 0, numnodes - 1
         num_entries_recv(iproc) = num_rec(2*iproc)
         num_blocks_recv(iproc) = num_rec(2*iproc + 1)
         num_entries_send(iproc) = num_send(2*iproc)
         num_blocks_send(iproc) = num_send(2*iproc + 1)

         CALL dbt_buffer_create(buffer_send(iproc), num_blocks_send(iproc), num_entries_send(iproc))

         CALL dbt_buffer_create(buffer_recv(iproc), num_blocks_recv(iproc), num_entries_recv(iproc))

      END DO

!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out,tr_in,buffer_send,locks) &
!$OMP PRIVATE(iter,blk_index,blk_size,block,iproc)
      CALL dbt_tas_iterator_start(iter, matrix_in)
      DO WHILE (dbt_tas_iterator_blocks_left(iter))
         CALL dbt_tas_iterator_next_block(iter, blk_index(1), blk_index(2), block, &
                                          row_size=blk_size(1), col_size=blk_size(2))
         IF (tr_in) THEN
            CALL dbt_tas_get_stored_coordinates(matrix_out, blk_index(2), blk_index(1), iproc)
         ELSE
            CALL dbt_tas_get_stored_coordinates(matrix_out, blk_index(1), blk_index(2), iproc)
         END IF
         CALL omp_set_lock(locks(iproc))
         CALL dbt_buffer_add_block(buffer_send(iproc), blk_index, block, transposed=tr_in)
         CALL omp_unset_lock(locks(iproc))
      END DO
      CALL dbt_tas_iterator_stop(iter)
!$OMP END PARALLEL

      ! The locks were only needed for packing; release them (omp_init_lock must be paired
      ! with omp_destroy_lock).
      DO iproc = 0, numnodes - 1
         CALL omp_destroy_lock(locks(iproc))
      END DO
      DEALLOCATE (locks)

      IF (move_prv) CALL dbt_tas_clear(matrix_in)

      CALL timestop(handle2)

      ! Phase 3: exchange the buffers.
      CALL timeset(routineN//"_communicate_buffer", handle2)
      CALL dbt_tas_communicate_buffer(mp_comm, buffer_recv, buffer_send, req_array)

      DO iproc = 0, numnodes - 1
         CALL dbt_buffer_destroy(buffer_send(iproc))
      END DO

      CALL timestop(handle2)

      ! Phase 4: reserve the received blocks and write them into matrix_out.
      CALL timeset(routineN//"_buffer_obtain", handle2)

      ! TODO Add OpenMP to the buffer unpacking.
      nblk = SUM(num_blocks_recv)
      ALLOCATE (blks_to_allocate(nblk, 2))

      bcount = 0
      DO iproc = 0, numnodes - 1
         CALL dbt_buffer_get_index(buffer_recv(iproc), index_recv)
         blks_to_allocate(bcount + 1:bcount + SIZE(index_recv, 1), :) = index_recv(:, :)
         bcount = bcount + SIZE(index_recv, 1)
         DEALLOCATE (index_recv)
      END DO

      !TODO: Parallelize creation of block list.
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_out,nblk,blks_to_allocate) PRIVATE(nblk_per_thread,a,b)
      nblk_per_thread = nblk/omp_get_num_threads() + 1
      a = omp_get_thread_num()*nblk_per_thread + 1
      ! Disjoint ranges a:b of nblk_per_thread entries each; together they cover 1:nblk.
      b = MIN(a + nblk_per_thread - 1, nblk)
      CALL dbt_tas_reserve_blocks(matrix_out, blks_to_allocate(a:b, 1), blks_to_allocate(a:b, 2))
!$OMP END PARALLEL
      DEALLOCATE (blks_to_allocate)

      DO iproc = 0, numnodes - 1
         ! First, we need to get the index to create block
         DO WHILE (dbt_buffer_blocks_left(buffer_recv(iproc)))
            CALL dbt_buffer_get_next_block(buffer_recv(iproc), ndata, blk_index)
            CALL dbt_tas_blk_sizes(matrix_out, blk_index(1), blk_index(2), blk_size(1), blk_size(2))
            ALLOCATE (block(blk_size(1), blk_size(2)))
            CALL dbt_buffer_get_next_block(buffer_recv(iproc), ndata, blk_index, block)
            CALL dbt_tas_put_block(matrix_out, blk_index(1), blk_index(2), block, summation=summation)
            DEALLOCATE (block)
         END DO
         CALL dbt_buffer_destroy(buffer_recv(iproc))
      END DO

      CALL timestop(handle2)

      CALL dbt_tas_finalize(matrix_out)

      CALL timestop(handle)
   END SUBROUTINE
258 :
! **************************************************************************************************
!> \brief Replicate matrix_in such that each submatrix of matrix_out is an exact copy of matrix_in
!> \details matrix_out is created with a distribution that replicates the split dimension over
!>          info%ngroup process groups. Each block of matrix_in is sent to every process reported
!>          by dbt_repl_get_stored_coordinates, then reserved and written into matrix_out.
!> \param matrix_in source matrix
!> \param info split info defining the process groups / split direction of matrix_out
!> \param matrix_out replicated matrix (created by this routine)
!> \param nodata Don't copy data but create matrix_out
!> \param move_data memory optimization: move data to matrix_out such that matrix_in is empty on return
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_replicate(matrix_in, info, matrix_out, nodata, move_data)
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_in
      TYPE(dbt_tas_split_info), INTENT(IN)               :: info
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data

      INTEGER                                            :: a, b, nblk_per_thread, nblkcols, nblkrows
      INTEGER, DIMENSION(2)                              :: pdims
      INTEGER, DIMENSION(:), POINTER                     :: col_blk_size, col_dist, row_blk_size, &
                                                            row_dist
      TYPE(dbm_distribution_obj)                         :: dbm_dist
      TYPE(dbt_tas_dist_arb), TARGET                     :: dir_dist
      TYPE(dbt_tas_dist_repl), TARGET                    :: repl_dist

      CLASS(dbt_tas_distribution), ALLOCATABLE           :: col_dist_obj, row_dist_obj
      CLASS(dbt_tas_rowcol_data), ALLOCATABLE            :: row_bsize_obj, col_bsize_obj
      TYPE(dbt_tas_blk_size_repl), TARGET                :: repl_blksize
      TYPE(dbt_tas_blk_size_arb), TARGET                 :: dir_blksize
      TYPE(dbt_tas_distribution_type)                    :: dist
      INTEGER                                            :: numnodes, ngroup
      INTEGER(kind=omp_lock_kind), ALLOCATABLE, DIMENSION(:) :: locks
      TYPE(dbt_buffer_type), ALLOCATABLE, DIMENSION(:)   :: buffer_recv, buffer_send
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: num_blocks_recv, num_blocks_send, &
                                                            num_entries_recv, num_entries_send, &
                                                            num_rec, num_send
      TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:, :) :: req_array
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: blks_to_allocate
      INTEGER, DIMENSION(2)                              :: blk_size
      INTEGER, DIMENSION(2)                              :: blk_index
      INTEGER(KIND=int_8), DIMENSION(2)                  :: blk_index_i8
      TYPE(dbm_iterator)                                 :: iter
      INTEGER                                            :: i, iproc, bcount, nblk
      INTEGER, DIMENSION(:), ALLOCATABLE                 :: iprocs
      LOGICAL                                            :: nodata_prv, move_prv
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: index_recv
      INTEGER                                            :: ndata
      TYPE(mp_cart_type)                                 :: mp_comm

      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: block

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_tas_replicate'

      INTEGER                                            :: handle, handle2

      NULLIFY (col_blk_size, row_blk_size)

      CALL timeset(routineN, handle)

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      END IF

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      END IF

      row_blk_size => dbm_get_row_block_sizes(matrix_in)
      col_blk_size => dbm_get_col_block_sizes(matrix_in)
      nblkrows = SIZE(row_blk_size)
      nblkcols = SIZE(col_blk_size)
      dbm_dist = dbm_get_distribution(matrix_in)
      row_dist => dbm_distribution_row_dist(dbm_dist)
      col_dist => dbm_distribution_col_dist(dbm_dist)

      mp_comm = info%mp_comm
      ngroup = info%ngroup

      numnodes = mp_comm%num_pe
      pdims = mp_comm%num_pe_cart

      ! The split dimension gets a replicated distribution / block sizes, the other one keeps
      ! the distribution of matrix_in.
      SELECT CASE (info%split_rowcol)
      CASE (rowsplit)
         repl_dist = dbt_tas_dist_repl(row_dist, pdims(1), nblkrows, info%ngroup, info%pgrid_split_size)
         dir_dist = dbt_tas_dist_arb(col_dist, pdims(2), INT(nblkcols, KIND=int_8))
         repl_blksize = dbt_tas_blk_size_repl(row_blk_size, info%ngroup)
         dir_blksize = dbt_tas_blk_size_arb(col_blk_size)
         ALLOCATE (row_dist_obj, source=repl_dist)
         ALLOCATE (col_dist_obj, source=dir_dist)
         ALLOCATE (row_bsize_obj, source=repl_blksize)
         ALLOCATE (col_bsize_obj, source=dir_blksize)
      CASE (colsplit)
         dir_dist = dbt_tas_dist_arb(row_dist, pdims(1), INT(nblkrows, KIND=int_8))
         repl_dist = dbt_tas_dist_repl(col_dist, pdims(2), nblkcols, info%ngroup, info%pgrid_split_size)
         dir_blksize = dbt_tas_blk_size_arb(row_blk_size)
         repl_blksize = dbt_tas_blk_size_repl(col_blk_size, info%ngroup)
         ALLOCATE (row_dist_obj, source=dir_dist)
         ALLOCATE (col_dist_obj, source=repl_dist)
         ALLOCATE (row_bsize_obj, source=dir_blksize)
         ALLOCATE (col_bsize_obj, source=repl_blksize)
      CASE DEFAULT
         ! Without this, the distribution objects below would be used unallocated.
         CPABORT("invalid split_rowcol in split info")
      END SELECT

      CALL dbt_tas_distribution_new(dist, mp_comm, row_dist_obj, col_dist_obj, split_info=info)
      CALL dbt_tas_create(matrix_out, TRIM(dbm_get_name(matrix_in))//" replicated", &
                          dist, row_bsize_obj, col_bsize_obj, own_dist=.TRUE.)

      IF (nodata_prv) THEN
         CALL dbt_tas_finalize(matrix_out)
         CALL timestop(handle)
         RETURN
      END IF

      ALLOCATE (buffer_send(0:numnodes - 1))
      ALLOCATE (buffer_recv(0:numnodes - 1))
      ALLOCATE (num_blocks_recv(0:numnodes - 1))
      ALLOCATE (num_blocks_send(0:numnodes - 1))
      ALLOCATE (num_entries_recv(0:numnodes - 1))
      ALLOCATE (num_entries_send(0:numnodes - 1))
      ! num_send / num_rec interleave (number of entries, number of blocks) per process
      ALLOCATE (num_rec(0:2*numnodes - 1))
      ALLOCATE (num_send(0:2*numnodes - 1))
      num_send(:) = 0
      ALLOCATE (req_array(1:numnodes, 4))
      ! one lock per destination buffer, guarding concurrent packing into buffer_send
      ALLOCATE (locks(0:numnodes - 1))
      DO iproc = 0, numnodes - 1
         CALL omp_init_lock(locks(iproc))
      END DO

      ! Count entries and blocks per destination; each block goes to ngroup processes.
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out,num_send,ngroup) &
!$OMP PRIVATE(iter,blk_index,blk_size,iprocs)
      ALLOCATE (iprocs(ngroup))
      CALL dbm_iterator_start(iter, matrix_in)
      DO WHILE (dbm_iterator_blocks_left(iter))
         CALL dbm_iterator_next_block(iter, blk_index(1), blk_index(2), &
                                      row_size=blk_size(1), col_size=blk_size(2))
         CALL dbt_repl_get_stored_coordinates(matrix_out, blk_index(1), blk_index(2), iprocs)
         DO i = 1, SIZE(iprocs)
!$OMP ATOMIC
            num_send(2*iprocs(i)) = num_send(2*iprocs(i)) + PRODUCT(blk_size)
!$OMP ATOMIC
            num_send(2*iprocs(i) + 1) = num_send(2*iprocs(i) + 1) + 1
         END DO
      END DO
      CALL dbm_iterator_stop(iter)
      DEALLOCATE (iprocs)
!$OMP END PARALLEL

      CALL timeset(routineN//"_alltoall", handle2)
      CALL mp_comm%alltoall(num_send, num_rec, 2)
      CALL timestop(handle2)

      DO iproc = 0, numnodes - 1
         num_entries_recv(iproc) = num_rec(2*iproc)
         num_blocks_recv(iproc) = num_rec(2*iproc + 1)
         num_entries_send(iproc) = num_send(2*iproc)
         num_blocks_send(iproc) = num_send(2*iproc + 1)

         CALL dbt_buffer_create(buffer_send(iproc), num_blocks_send(iproc), num_entries_send(iproc))

         CALL dbt_buffer_create(buffer_recv(iproc), num_blocks_recv(iproc), num_entries_recv(iproc))

      END DO

!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out,buffer_send,locks,ngroup) &
!$OMP PRIVATE(iter,blk_index,blk_size,block,iprocs)
      ALLOCATE (iprocs(ngroup))
      CALL dbm_iterator_start(iter, matrix_in)
      DO WHILE (dbm_iterator_blocks_left(iter))
         CALL dbm_iterator_next_block(iter, blk_index(1), blk_index(2), block, &
                                      row_size=blk_size(1), col_size=blk_size(2))
         CALL dbt_repl_get_stored_coordinates(matrix_out, blk_index(1), blk_index(2), iprocs)
         DO i = 1, SIZE(iprocs)
            CALL omp_set_lock(locks(iprocs(i)))
            CALL dbt_buffer_add_block(buffer_send(iprocs(i)), INT(blk_index, KIND=int_8), block)
            CALL omp_unset_lock(locks(iprocs(i)))
         END DO
      END DO
      CALL dbm_iterator_stop(iter)
      DEALLOCATE (iprocs)
!$OMP END PARALLEL

      ! The locks were only needed for packing; release them (omp_init_lock must be paired
      ! with omp_destroy_lock).
      DO iproc = 0, numnodes - 1
         CALL omp_destroy_lock(locks(iproc))
      END DO
      DEALLOCATE (locks)

      IF (move_prv) CALL dbm_clear(matrix_in)

      CALL timeset(routineN//"_communicate_buffer", handle2)
      CALL dbt_tas_communicate_buffer(mp_comm, buffer_recv, buffer_send, req_array)

      DO iproc = 0, numnodes - 1
         CALL dbt_buffer_destroy(buffer_send(iproc))
      END DO

      CALL timestop(handle2)

      ! TODO Add OpenMP to the buffer unpacking.
      nblk = SUM(num_blocks_recv)
      ALLOCATE (blks_to_allocate(nblk, 2))

      bcount = 0
      DO iproc = 0, numnodes - 1
         CALL dbt_buffer_get_index(buffer_recv(iproc), index_recv)
         blks_to_allocate(bcount + 1:bcount + SIZE(index_recv, 1), :) = INT(index_recv(:, :))
         bcount = bcount + SIZE(index_recv, 1)
         DEALLOCATE (index_recv)
      END DO

      !TODO: Parallelize creation of block list.
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_out,nblk,blks_to_allocate) PRIVATE(nblk_per_thread,a,b)
      nblk_per_thread = nblk/omp_get_num_threads() + 1
      a = omp_get_thread_num()*nblk_per_thread + 1
      ! Disjoint ranges a:b of nblk_per_thread entries each; together they cover 1:nblk.
      b = MIN(a + nblk_per_thread - 1, nblk)
      CALL dbm_reserve_blocks(matrix_out%matrix, blks_to_allocate(a:b, 1), blks_to_allocate(a:b, 2))
!$OMP END PARALLEL
      DEALLOCATE (blks_to_allocate)

      DO iproc = 0, numnodes - 1
         ! First, we need to get the index to create block
         DO WHILE (dbt_buffer_blocks_left(buffer_recv(iproc)))
            CALL dbt_buffer_get_next_block(buffer_recv(iproc), ndata, blk_index_i8)
            CALL dbt_tas_blk_sizes(matrix_out, blk_index_i8(1), blk_index_i8(2), blk_size(1), blk_size(2))
            ALLOCATE (block(blk_size(1), blk_size(2)))
            CALL dbt_buffer_get_next_block(buffer_recv(iproc), ndata, blk_index_i8, block)
            CALL dbm_put_block(matrix_out%matrix, INT(blk_index_i8(1)), INT(blk_index_i8(2)), block)
            DEALLOCATE (block)
         END DO

         CALL dbt_buffer_destroy(buffer_recv(iproc))
      END DO

      CALL dbt_tas_finalize(matrix_out)

      CALL timestop(handle)

   END SUBROUTINE
492 :
! **************************************************************************************************
!> \brief Merge submatrices of matrix_in to matrix_out by sum
!> \details Each block of matrix_in%matrix is sent to the process storing it in matrix_out and
!>          added there (dbm_put_block with summation=.TRUE.), so that contributions of all
!>          submatrices holding the same block are summed up.
!> \param matrix_out target matrix; cleared first unless summation is .TRUE.
!> \param matrix_in split matrix whose submatrices are merged
!> \param summation whether matrix_out = matrix_out + merged(matrix_in)
!> \param move_data memory optimization: move data to matrix_out such that matrix_in is empty on return
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_merge(matrix_out, matrix_in, summation, move_data)
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_out
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      LOGICAL, INTENT(IN), OPTIONAL                      :: summation, move_data

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_tas_merge'

      INTEGER                                            :: a, b, bcount, handle, handle2, iproc, &
                                                            nblk, nblk_per_thread, ndata, numnodes
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: index_recv
      INTEGER(KIND=int_8), DIMENSION(2)                  :: blk_index_i8
      INTEGER(kind=omp_lock_kind), ALLOCATABLE, &
         DIMENSION(:)                                    :: locks
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: num_blocks_recv, num_blocks_send, &
                                                            num_entries_recv, num_entries_send, &
                                                            num_rec, num_send
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: blks_to_allocate
      INTEGER, DIMENSION(2)                              :: blk_index, blk_size
      INTEGER, DIMENSION(:), POINTER                     :: col_block_sizes, row_block_sizes
      LOGICAL                                            :: move_prv
      REAL(dp), DIMENSION(:, :), POINTER                 :: block
      TYPE(dbm_iterator)                                 :: iter
      TYPE(dbt_buffer_type), ALLOCATABLE, DIMENSION(:)   :: buffer_recv, buffer_send
      TYPE(dbt_tas_split_info)                           :: info
      TYPE(mp_cart_type)                                 :: mp_comm
      TYPE(mp_request_type), ALLOCATABLE, &
         DIMENSION(:, :)                                 :: req_array

      CALL timeset(routineN, handle)

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbm_clear(matrix_out)
      ELSE
         CALL dbm_clear(matrix_out)
      END IF

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      END IF

      info = dbt_tas_info(matrix_in)
      CALL dbt_tas_get_split_info(info, mp_comm=mp_comm)
      numnodes = mp_comm%num_pe

      ALLOCATE (buffer_send(0:numnodes - 1))
      ALLOCATE (buffer_recv(0:numnodes - 1))
      ALLOCATE (num_blocks_recv(0:numnodes - 1))
      ALLOCATE (num_blocks_send(0:numnodes - 1))
      ALLOCATE (num_entries_recv(0:numnodes - 1))
      ALLOCATE (num_entries_send(0:numnodes - 1))
      ! num_send / num_rec interleave (number of entries, number of blocks) per process
      ALLOCATE (num_rec(0:2*numnodes - 1))
      ALLOCATE (num_send(0:2*numnodes - 1))
      num_send(:) = 0
      ALLOCATE (req_array(1:numnodes, 4))
      ! one lock per destination buffer, guarding concurrent packing into buffer_send
      ALLOCATE (locks(0:numnodes - 1))
      DO iproc = 0, numnodes - 1
         CALL omp_init_lock(locks(iproc))
      END DO

      ! Count entries and blocks per destination process.
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out,num_send) &
!$OMP PRIVATE(iter,blk_index,blk_size,iproc)
      CALL dbm_iterator_start(iter, matrix_in%matrix)
      DO WHILE (dbm_iterator_blocks_left(iter))
         CALL dbm_iterator_next_block(iter, blk_index(1), blk_index(2), &
                                      row_size=blk_size(1), col_size=blk_size(2))
         CALL dbm_get_stored_coordinates(matrix_out, blk_index(1), blk_index(2), iproc)
!$OMP ATOMIC
         num_send(2*iproc) = num_send(2*iproc) + PRODUCT(blk_size)
!$OMP ATOMIC
         num_send(2*iproc + 1) = num_send(2*iproc + 1) + 1
      END DO
      CALL dbm_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timeset(routineN//"_alltoall", handle2)
      CALL mp_comm%alltoall(num_send, num_rec, 2)
      CALL timestop(handle2)

      DO iproc = 0, numnodes - 1
         num_entries_recv(iproc) = num_rec(2*iproc)
         num_blocks_recv(iproc) = num_rec(2*iproc + 1)
         num_entries_send(iproc) = num_send(2*iproc)
         num_blocks_send(iproc) = num_send(2*iproc + 1)

         CALL dbt_buffer_create(buffer_send(iproc), num_blocks_send(iproc), num_entries_send(iproc))

         CALL dbt_buffer_create(buffer_recv(iproc), num_blocks_recv(iproc), num_entries_recv(iproc))

      END DO

!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out,buffer_send,locks) &
!$OMP PRIVATE(iter,blk_index,blk_size,block,iproc)
      CALL dbm_iterator_start(iter, matrix_in%matrix)
      DO WHILE (dbm_iterator_blocks_left(iter))
         CALL dbm_iterator_next_block(iter, blk_index(1), blk_index(2), block, &
                                      row_size=blk_size(1), col_size=blk_size(2))
         CALL dbm_get_stored_coordinates(matrix_out, blk_index(1), blk_index(2), iproc)
         CALL omp_set_lock(locks(iproc))
         CALL dbt_buffer_add_block(buffer_send(iproc), INT(blk_index, KIND=int_8), block)
         CALL omp_unset_lock(locks(iproc))
      END DO
      CALL dbm_iterator_stop(iter)
!$OMP END PARALLEL

      ! The locks were only needed for packing; release them (omp_init_lock must be paired
      ! with omp_destroy_lock).
      DO iproc = 0, numnodes - 1
         CALL omp_destroy_lock(locks(iproc))
      END DO
      DEALLOCATE (locks)

      IF (move_prv) CALL dbt_tas_clear(matrix_in)

      CALL timeset(routineN//"_communicate_buffer", handle2)
      CALL dbt_tas_communicate_buffer(mp_comm, buffer_recv, buffer_send, req_array)

      DO iproc = 0, numnodes - 1
         CALL dbt_buffer_destroy(buffer_send(iproc))
      END DO

      CALL timestop(handle2)

      ! TODO Add OpenMP to the buffer unpacking.
      nblk = SUM(num_blocks_recv)
      ALLOCATE (blks_to_allocate(nblk, 2))

      bcount = 0
      DO iproc = 0, numnodes - 1
         CALL dbt_buffer_get_index(buffer_recv(iproc), index_recv)
         blks_to_allocate(bcount + 1:bcount + SIZE(index_recv, 1), :) = INT(index_recv(:, :))
         bcount = bcount + SIZE(index_recv, 1)
         DEALLOCATE (index_recv)
      END DO

      !TODO: Parallelize creation of block list.
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_out,nblk,blks_to_allocate) PRIVATE(nblk_per_thread,a,b)
      nblk_per_thread = nblk/omp_get_num_threads() + 1
      a = omp_get_thread_num()*nblk_per_thread + 1
      ! Disjoint ranges a:b of nblk_per_thread entries each; together they cover 1:nblk.
      b = MIN(a + nblk_per_thread - 1, nblk)
      CALL dbm_reserve_blocks(matrix_out, blks_to_allocate(a:b, 1), blks_to_allocate(a:b, 2))
!$OMP END PARALLEL
      DEALLOCATE (blks_to_allocate)

      ! Block sizes of matrix_out are loop-invariant; look them up once.
      row_block_sizes => dbm_get_row_block_sizes(matrix_out)
      col_block_sizes => dbm_get_col_block_sizes(matrix_out)

      DO iproc = 0, numnodes - 1
         ! First, we need to get the index to create block
         DO WHILE (dbt_buffer_blocks_left(buffer_recv(iproc)))
            CALL dbt_buffer_get_next_block(buffer_recv(iproc), ndata, blk_index_i8)
            blk_size(1) = row_block_sizes(INT(blk_index_i8(1)))
            blk_size(2) = col_block_sizes(INT(blk_index_i8(2)))
            ALLOCATE (block(blk_size(1), blk_size(2)))
            CALL dbt_buffer_get_next_block(buffer_recv(iproc), ndata, blk_index_i8, block)
            ! Always add: several submatrices may contribute to the same block.
            CALL dbm_put_block(matrix_out, INT(blk_index_i8(1)), INT(blk_index_i8(2)), block, summation=.TRUE.)
            DEALLOCATE (block)
         END DO
         CALL dbt_buffer_destroy(buffer_recv(iproc))
      END DO

      CALL dbm_finalize(matrix_out)

      CALL timestop(handle)
   END SUBROUTINE
661 :
! **************************************************************************************************
!> \brief Copy the block indices (row, column) of all blocks in a buffer.
!> \param buffer block buffer
!> \param index on return: array of shape (nblock, 2) with the row / column index of each block,
!>              i.e. buffer%indx without its offset column
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_buffer_get_index(buffer, index)
      TYPE(dbt_buffer_type), INTENT(IN)                  :: buffer
      INTEGER(KIND=int_8), ALLOCATABLE, &
         DIMENSION(:, :), INTENT(OUT)                    :: index

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_buffer_get_index'

      INTEGER                                            :: handle, ncols, nrows

      CALL timeset(routineN, handle)

      ! The last column of buffer%indx holds data offsets, not block indices: drop it.
      nrows = SIZE(buffer%indx, 1)
      ncols = SIZE(buffer%indx, 2) - 1

      ALLOCATE (index(nrows, ncols))
      index(:, :) = buffer%indx(1:nrows, 1:ncols)

      CALL timestop(handle)
   END SUBROUTINE
685 :
! **************************************************************************************************
!> \brief Whether the buffer iterator has not yet reached the last block.
!> \param buffer block buffer
!> \return .TRUE. if the iterator position is below the number of blocks in the buffer
!> \author Patrick Seewald
! **************************************************************************************************
   PURE FUNCTION dbt_buffer_blocks_left(buffer) RESULT(blocks_left)
      TYPE(dbt_buffer_type), INTENT(IN)                  :: buffer
      LOGICAL                                            :: blocks_left

      blocks_left = (buffer%nblock > buffer%endpos)
   END FUNCTION
698 :
! **************************************************************************************************
!> \brief Create block buffer for MPI communication.
!> \param buffer block buffer; allocated with room for nblock index rows (3 columns each) and
!>               ndata block entries; the iterator position is reset to 0
!> \param nblock number of blocks
!> \param ndata total number of block entries
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_buffer_create(buffer, nblock, ndata)
      TYPE(dbt_buffer_type), INTENT(OUT)                 :: buffer
      INTEGER, INTENT(IN)                                :: nblock, ndata

      ALLOCATE (buffer%msg(ndata), buffer%indx(nblock, 3))
      buffer%endpos = 0
      buffer%nblock = nblock
   END SUBROUTINE
715 :
! **************************************************************************************************
!> \brief Release the memory of a block buffer and reset its counters to -1.
!> \param buffer block buffer (must have been created with dbt_buffer_create)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_buffer_destroy(buffer)
      TYPE(dbt_buffer_type), INTENT(INOUT)               :: buffer

      DEALLOCATE (buffer%msg, buffer%indx)
      buffer%endpos = -1
      buffer%nblock = -1
   END SUBROUTINE dbt_buffer_destroy
729 :
! **************************************************************************************************
!> \brief insert a block into block buffer (at current iterator position)
!> \details The block entries are appended to buffer%msg right after the entries of the previously
!>          added block. The block index and the cumulative end offset of its entries are stored in
!>          the next row of buffer%indx, and the iterator position is advanced by one.
!> \param buffer block buffer
!> \param index index of block
!> \param block block data
!> \param transposed if .TRUE., store the transposed block under the swapped index (default .FALSE.)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_buffer_add_block(buffer, index, block, transposed)
      TYPE(dbt_buffer_type), INTENT(INOUT)               :: buffer
      INTEGER(KIND=int_8), DIMENSION(2), INTENT(IN)      :: index
      REAL(dp), DIMENSION(:, :), INTENT(IN)              :: block
      LOGICAL, INTENT(IN), OPTIONAL                      :: transposed

      INTEGER                                            :: first, last, nentries, slot
      INTEGER(KIND=int_8)                                :: prev_end
      INTEGER(KIND=int_8), DIMENSION(2)                  :: stored_index
      LOGICAL                                            :: do_transpose

      do_transpose = .FALSE.
      IF (PRESENT(transposed)) do_transpose = transposed

      stored_index(:) = index(:)
      IF (do_transpose) CALL swap(stored_index)

      nentries = SIZE(block)
      slot = buffer%endpos + 1

      ! End offset of the previous block's entries (0 for the first block).
      prev_end = 0
      IF (slot > 1) prev_end = buffer%indx(slot - 1, 3)

      first = INT(prev_end) + 1
      last = INT(prev_end) + nentries

      IF (.NOT. do_transpose) THEN
         buffer%msg(first:last) = RESHAPE(block, [nentries])
      ELSE
         buffer%msg(first:last) = RESHAPE(TRANSPOSE(block), [nentries])
      END IF

      buffer%indx(slot, 1:2) = stored_index(:)
      buffer%indx(slot, 3) = prev_end + INT(nentries, KIND=int_8)
      buffer%endpos = slot
   END SUBROUTINE
781 :
! **************************************************************************************************
!> \brief Read the block at the current buffer position.
!>        The position moves on if advance_iter is .TRUE.; without advance_iter it moves on
!>        exactly when block is present (i.e. when the block data is retrieved).
!>        Callers should check dbt_buffer_blocks_left first: position endpos + 1 must not
!>        exceed nblock (not checked here).
!> \param buffer block buffer
!> \param ndata number of entries of the current block
!> \param index index stored with the current block
!> \param block optional; filled by reshaping the stored entries to SHAPE(block)
!>        (expected to hold ndata entries)
!> \param advance_iter optional; overrides whether the position is advanced
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_buffer_get_next_block(buffer, ndata, index, block, advance_iter)
      TYPE(dbt_buffer_type), INTENT(INOUT)               :: buffer
      INTEGER, INTENT(OUT)                               :: ndata
      INTEGER(KIND=int_8), DIMENSION(2), INTENT(OUT)     :: index
      REAL(dp), DIMENSION(:, :), INTENT(OUT), OPTIONAL   :: block
      LOGICAL, INTENT(IN), OPTIONAL                      :: advance_iter

      INTEGER                                            :: slot
      INTEGER(KIND=int_8)                                :: offset
      LOGICAL                                            :: move_on

      IF (PRESENT(advance_iter)) THEN
         move_on = advance_iter
      ELSE
         move_on = PRESENT(block)
      END IF

      slot = buffer%endpos + 1

      ! start offset of this block = cumulative end offset of the previous one (0 for the first)
      offset = 0_int_8
      IF (slot > 1) offset = buffer%indx(slot - 1, 3)

      ndata = INT(buffer%indx(slot, 3) - offset)
      index(:) = buffer%indx(slot, 1:2)

      IF (PRESENT(block)) THEN
         block(:, :) = RESHAPE(buffer%msg(INT(offset) + 1:INT(offset) + ndata), SHAPE(block))
      END IF

      IF (move_on) buffer%endpos = slot
   END SUBROUTINE dbt_buffer_get_next_block
828 :
! **************************************************************************************************
!> \brief Exchange block buffers between all ranks of a communicator.
!>        For every peer with nblock > 0, indx is transferred with tag 4 and msg with tag 7.
!>        On a single rank, the send buffer is copied directly into the receive buffer.
!> \param mp_comm communicator; the number of ranks is taken from mp_comm%num_pe
!> \param buffer_recv receive buffers, one per rank (0-based); presumably already created with
!>        sizes matching the incoming data (not checked here)
!> \param buffer_send send buffers, one per rank (0-based)
!> \param req_array request handles: columns 1:2 for sends (indx, msg), columns 3:4 for
!>        receives (indx, msg); needs at least one row per communicating peer
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_communicate_buffer(mp_comm, buffer_recv, buffer_send, req_array)
      CLASS(mp_comm_type), INTENT(IN)                    :: mp_comm
      TYPE(dbt_buffer_type), DIMENSION(0:), &
         INTENT(INOUT)                                   :: buffer_recv, buffer_send
      TYPE(mp_request_type), DIMENSION(:, :), &
         INTENT(OUT)                                     :: req_array

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_communicate_buffer'

      INTEGER                                            :: handle, nproc, nrecv, nsend, peer

      CALL timeset(routineN, handle)
      nproc = mp_comm%num_pe

      IF (nproc <= 1) THEN
         ! no MPI traffic needed: copy local send data into the local receive buffer
         IF (buffer_recv(0)%nblock > 0) THEN
            buffer_recv(0)%indx(:, :) = buffer_send(0)%indx(:, :)
            buffer_recv(0)%msg(:) = buffer_send(0)%msg(:)
         END IF
      ELSE
         ! post all receives first ...
         nrecv = 0
         DO peer = 0, nproc - 1
            IF (buffer_recv(peer)%nblock <= 0) CYCLE
            nrecv = nrecv + 1
            CALL mp_comm%irecv(buffer_recv(peer)%indx, peer, req_array(nrecv, 3), tag=4)
            CALL mp_comm%irecv(buffer_recv(peer)%msg, peer, req_array(nrecv, 4), tag=7)
         END DO

         ! ... then start all sends
         nsend = 0
         DO peer = 0, nproc - 1
            IF (buffer_send(peer)%nblock <= 0) CYCLE
            nsend = nsend + 1
            CALL mp_comm%isend(buffer_send(peer)%indx, peer, req_array(nsend, 1), tag=4)
            CALL mp_comm%isend(buffer_send(peer)%msg, peer, req_array(nsend, 2), tag=7)
         END DO

         ! complete sends, then receives
         IF (nsend > 0) CALL mp_waitall(req_array(1:nsend, 1:2))
         IF (nrecv > 0) CALL mp_waitall(req_array(1:nrecv, 3:4))
      END IF

      CALL timestop(handle)
   END SUBROUTINE dbt_tas_communicate_buffer
888 :
889 260779 : END MODULE
|