Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Routines to reshape / redistribute tensors
10 : !> \author Patrick Seewald
11 : ! **************************************************************************************************
12 : MODULE dbt_reshape_ops
13 : #:include "dbt_macros.fypp"
14 : #:set maxdim = maxrank
15 : #:set ndims = range(2,maxdim+1)
16 :
17 : USE dbt_allocate_wrap, ONLY: allocate_any
18 : USE dbt_tas_base, ONLY: dbt_tas_copy, dbt_tas_get_info, dbt_tas_info
19 : USE dbt_block, ONLY: &
20 : block_nd, create_block, destroy_block, dbt_iterator_type, dbt_iterator_next_block, &
21 : dbt_iterator_blocks_left, dbt_iterator_start, dbt_iterator_stop, dbt_get_block, &
22 : dbt_reserve_blocks, dbt_put_block
23 : USE dbt_types, ONLY: dbt_blk_sizes, &
24 : dbt_create, &
25 : dbt_type, &
26 : ndims_tensor, &
27 : dbt_get_stored_coordinates, &
28 : dbt_clear
29 : USE kinds, ONLY: default_string_length
30 : USE kinds, ONLY: dp, dp
31 : USE message_passing, ONLY: &
32 : mp_waitall, mp_comm_type, mp_request_type
33 :
34 : #include "../base/base_uses.f90"
35 :
IMPLICIT NONE
PRIVATE

CHARACTER(LEN=*), PARAMETER, PRIVATE :: moduleN = 'dbt_reshape_ops'

PUBLIC :: dbt_reshape

!> Packed message exchanged with one peer process: a block index table plus the
!> concatenated block data.
TYPE block_buffer_type
   !> One row per block. Column 0 holds the offset of the block inside %data,
   !> columns 1: hold the n-dimensional block index.
   INTEGER, ALLOCATABLE :: blocks(:, :)
   !> Block data of all blocks, stored back to back.
   REAL(dp), ALLOCATABLE :: data(:)
END TYPE block_buffer_type
46 :
47 : CONTAINS
48 :
! **************************************************************************************************
!> \brief Copy the data of tensor_in into tensor_out, redistributing blocks as needed (reshape).
!>
!>        Every block of tensor_in is sent to the process that stores the block with the same
!>        block index in tensor_out.
!>        - Without summation, tensor_out is cleared first, so on return it holds a copy of
!>          tensor_in.
!>        - With summation, the received blocks are added to the existing content of tensor_out
!>          (tensor_out = tensor_out + tensor_in).
!> \param tensor_in  source tensor; emptied on return if move_data is .TRUE.
!> \param tensor_out target tensor; must be valid and have the same rank as tensor_in
!> \param summation  if .TRUE., accumulate into tensor_out instead of overwriting it
!>                   (default .FALSE.)
!> \param move_data  memory optimization: if .TRUE., tensor_in is cleared on return
!>                   (default .FALSE.)
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE dbt_reshape(tensor_in, tensor_out, summation, move_data)

   TYPE(dbt_type), INTENT(INOUT)                      :: tensor_in, tensor_out
   LOGICAL, INTENT(IN), OPTIONAL                      :: summation
   LOGICAL, INTENT(IN), OPTIONAL                      :: move_data

   CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_reshape'

   INTEGER :: iproc, numnodes, &
              handle, iblk, jblk, offset, ndata, &
              nblks_recv_mythread
   INTEGER, ALLOCATABLE, DIMENSION(:, :) :: blks_to_allocate
   TYPE(dbt_iterator_type) :: iter
   TYPE(block_nd) :: blk_data
   TYPE(block_buffer_type), ALLOCATABLE, DIMENSION(:) :: buffer_recv, buffer_send
   INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: blk_size, ind_nd
   LOGICAL :: found, summation_prv, move_prv

   INTEGER, ALLOCATABLE, DIMENSION(:) :: nblks_send_total, ndata_send_total, &
                                         nblks_recv_total, ndata_recv_total, &
                                         nblks_send_mythread, ndata_send_mythread
   TYPE(mp_comm_type) :: mp_comm

   CALL timeset(routineN, handle)

   IF (PRESENT(summation)) THEN
      summation_prv = summation
   ELSE
      summation_prv = .FALSE.
   END IF

   IF (PRESENT(move_data)) THEN
      move_prv = move_data
   ELSE
      move_prv = .FALSE.
   END IF

   CPASSERT(tensor_out%valid)
   ! Block indices of tensor_in are reused unchanged as block indices of tensor_out
   ! (ind_nd, blks_to_allocate), which is only meaningful if both have the same rank.
   CPASSERT(ndims_tensor(tensor_out) == ndims_tensor(tensor_in))

   IF (.NOT. summation_prv) CALL dbt_clear(tensor_out)

   mp_comm = tensor_in%pgrid%mp_comm_2d
   numnodes = mp_comm%num_pe
   ALLOCATE (buffer_send(0:numnodes - 1), buffer_recv(0:numnodes - 1))
   ALLOCATE (nblks_send_total(0:numnodes - 1), ndata_send_total(0:numnodes - 1), source=0)
   ALLOCATE (nblks_recv_total(0:numnodes - 1), ndata_recv_total(0:numnodes - 1), source=0)

!$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
!$OMP SHARED(tensor_in,tensor_out,summation) &
!$OMP SHARED(buffer_send,buffer_recv,mp_comm,numnodes) &
!$OMP SHARED(nblks_send_total,ndata_send_total,nblks_recv_total,ndata_recv_total) &
!$OMP PRIVATE(nblks_send_mythread,ndata_send_mythread,nblks_recv_mythread) &
!$OMP PRIVATE(iter,ind_nd,blk_size,blk_data,found,iproc) &
!$OMP PRIVATE(blks_to_allocate,offset,ndata,iblk,jblk)
   ! Pass 1: count the blocks and data elements this thread sends to each process.
   ALLOCATE (nblks_send_mythread(0:numnodes - 1), ndata_send_mythread(0:numnodes - 1), source=0)

   CALL dbt_iterator_start(iter, tensor_in)
   DO WHILE (dbt_iterator_blocks_left(iter))
      CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
      CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
      nblks_send_mythread(iproc) = nblks_send_mythread(iproc) + 1
      ndata_send_mythread(iproc) = ndata_send_mythread(iproc) + PRODUCT(blk_size)
   END DO
   CALL dbt_iterator_stop(iter)

   ! Accumulate the totals. The running totals seen by this thread mark the end of its own
   ! slot in each send buffer; the slot is filled from the end, counting down in pass 2.
!$OMP CRITICAL
   nblks_send_total(:) = nblks_send_total(:) + nblks_send_mythread(:)
   ndata_send_total(:) = ndata_send_total(:) + ndata_send_mythread(:)
   nblks_send_mythread(:) = nblks_send_total(:) ! current totals indicate slot for this thread
   ndata_send_mythread(:) = ndata_send_total(:)
!$OMP END CRITICAL
!$OMP BARRIER

   ! Exchange message sizes so that the receive buffers can be allocated.
!$OMP MASTER
   CALL mp_comm%alltoall(nblks_send_total, nblks_recv_total, 1)
   CALL mp_comm%alltoall(ndata_send_total, ndata_recv_total, 1)
!$OMP END MASTER
!$OMP BARRIER

!$OMP DO
   DO iproc = 0, numnodes - 1
      ALLOCATE (buffer_send(iproc)%data(ndata_send_total(iproc)))
      ALLOCATE (buffer_recv(iproc)%data(ndata_recv_total(iproc)))
      ! going to use buffer%blocks(:,0) to store data offsets
      ALLOCATE (buffer_send(iproc)%blocks(nblks_send_total(iproc), 0:ndims_tensor(tensor_in)))
      ALLOCATE (buffer_recv(iproc)%blocks(nblks_recv_total(iproc), 0:ndims_tensor(tensor_in)))
   END DO
!$OMP END DO
!$OMP BARRIER

   ! Pass 2: pack block data and block indices into this thread's slots of the send buffers.
   CALL dbt_iterator_start(iter, tensor_in)
   DO WHILE (dbt_iterator_blocks_left(iter))
      CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
      CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
      CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
      CPASSERT(found)
      ! insert block data
      ndata = PRODUCT(blk_size)
      ndata_send_mythread(iproc) = ndata_send_mythread(iproc) - ndata
      offset = ndata_send_mythread(iproc)
      buffer_send(iproc)%data(offset + 1:offset + ndata) = blk_data%blk(:)
      ! insert block index
      nblks_send_mythread(iproc) = nblks_send_mythread(iproc) - 1
      iblk = nblks_send_mythread(iproc) + 1
      buffer_send(iproc)%blocks(iblk, 1:) = ind_nd(:)
      buffer_send(iproc)%blocks(iblk, 0) = offset
      CALL destroy_block(blk_data)
   END DO
   CALL dbt_iterator_stop(iter)
   ! Per-thread slot counters are no longer needed.
   DEALLOCATE (nblks_send_mythread, ndata_send_mythread)
!$OMP BARRIER

   CALL dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
!$OMP BARRIER

!$OMP DO
   DO iproc = 0, numnodes - 1
      DEALLOCATE (buffer_send(iproc)%blocks, buffer_send(iproc)%data)
   END DO
!$OMP END DO

   ! Count the received blocks handled by this thread. The loops that fill blks_to_allocate
   ! below must hand each thread exactly the same iterations. OpenMP guarantees this only
   ! for an explicit static schedule with equal iteration counts; the default schedule is
   ! implementation defined. Hence SCHEDULE(STATIC) on both loop nests.
   nblks_recv_mythread = 0
   DO iproc = 0, numnodes - 1
!$OMP DO SCHEDULE(STATIC)
      DO iblk = 1, nblks_recv_total(iproc)
         nblks_recv_mythread = nblks_recv_mythread + 1
      END DO
!$OMP END DO
   END DO
   ALLOCATE (blks_to_allocate(nblks_recv_mythread, ndims_tensor(tensor_in)))

   jblk = 0
   DO iproc = 0, numnodes - 1
!$OMP DO SCHEDULE(STATIC)
      DO iblk = 1, nblks_recv_total(iproc)
         jblk = jblk + 1
         blks_to_allocate(jblk, :) = buffer_recv(iproc)%blocks(iblk, 1:)
      END DO
!$OMP END DO
   END DO
   CPASSERT(jblk == nblks_recv_mythread)
   CALL dbt_reserve_blocks(tensor_out, blks_to_allocate)
   DEALLOCATE (blks_to_allocate)

   ! Unpack received blocks into tensor_out (block sizes are taken from tensor_out).
   DO iproc = 0, numnodes - 1
!$OMP DO
      DO iblk = 1, nblks_recv_total(iproc)
         ind_nd(:) = buffer_recv(iproc)%blocks(iblk, 1:)
         CALL dbt_blk_sizes(tensor_out, ind_nd, blk_size)
         offset = buffer_recv(iproc)%blocks(iblk, 0)
         ndata = PRODUCT(blk_size)
         CALL create_block(blk_data, blk_size, &
                           array=buffer_recv(iproc)%data(offset + 1:offset + ndata))
         CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
         CALL destroy_block(blk_data)
      END DO
!$OMP END DO
   END DO

!$OMP DO
   DO iproc = 0, numnodes - 1
      DEALLOCATE (buffer_recv(iproc)%blocks, buffer_recv(iproc)%data)
   END DO
!$OMP END DO
!$OMP END PARALLEL

   DEALLOCATE (nblks_recv_total, ndata_recv_total)
   DEALLOCATE (nblks_send_total, ndata_send_total)
   DEALLOCATE (buffer_send, buffer_recv)

   IF (move_prv) CALL dbt_clear(tensor_in)

   CALL timestop(handle)
END SUBROUTINE dbt_reshape
227 :
! **************************************************************************************************
!> \brief Exchange packed block buffers between all processes of mp_comm.
!>
!>        buffer_send(p) is sent to rank p, and buffer_recv(p) receives the message from rank p.
!>        All buffers must already be allocated by the caller; peers with an empty block table
!>        are skipped.
!>        - With more than one process, the exchange is performed by the master thread only.
!>          That branch ends without a barrier, so callers inside a parallel region must
!>          synchronize afterwards.
!>        - With a single process, the send buffer is copied into the receive buffer using
!>          worksharing loops.
!> \param mp_comm      communicator used for the exchange
!> \param buffer_recv  receive buffers, indexed by source rank
!> \param buffer_send  send buffers, indexed by destination rank
!> \author Patrick Seewald
! **************************************************************************************************
SUBROUTINE dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
   TYPE(mp_comm_type), INTENT(IN)                        :: mp_comm
   TYPE(block_buffer_type), DIMENSION(0:), INTENT(INOUT) :: buffer_recv, buffer_send

   CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_communicate_buffer'

   INTEGER :: handle, nproc, peer, nrecv, nsend, irow, ielem
   ! columns 1/2: send requests (blocks/data), columns 3/4: receive requests (blocks/data)
   TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:, :) :: requests

   CALL timeset(routineN, handle)
   nproc = mp_comm%num_pe

   IF (nproc <= 1) THEN
      ! Single process: local copy, shared among the threads of the enclosing team.
!$OMP DO SCHEDULE(static, 512)
      DO irow = 1, SIZE(buffer_send(0)%blocks, 1)
         buffer_recv(0)%blocks(irow, :) = buffer_send(0)%blocks(irow, :)
      END DO
!$OMP END DO
!$OMP DO SCHEDULE(static, 512)
      DO ielem = 1, SIZE(buffer_send(0)%data)
         buffer_recv(0)%data(ielem) = buffer_send(0)%data(ielem)
      END DO
!$OMP END DO
   ELSE
!$OMP MASTER
      nsend = 0
      nrecv = 0
      ALLOCATE (requests(1:nproc, 4))

      ! Post all receives first ...
      DO peer = 0, nproc - 1
         IF (SIZE(buffer_recv(peer)%blocks) == 0) CYCLE
         nrecv = nrecv + 1
         CALL mp_comm%irecv(buffer_recv(peer)%blocks, peer, requests(nrecv, 3), tag=4)
         CALL mp_comm%irecv(buffer_recv(peer)%data, peer, requests(nrecv, 4), tag=7)
      END DO

      ! ... then all sends.
      DO peer = 0, nproc - 1
         IF (SIZE(buffer_send(peer)%blocks) == 0) CYCLE
         nsend = nsend + 1
         CALL mp_comm%isend(buffer_send(peer)%blocks, peer, requests(nsend, 1), tag=4)
         CALL mp_comm%isend(buffer_send(peer)%data, peer, requests(nsend, 2), tag=7)
      END DO

      IF (nsend > 0) CALL mp_waitall(requests(1:nsend, 1:2))
      IF (nrecv > 0) CALL mp_waitall(requests(1:nrecv, 3:4))
!$OMP END MASTER
   END IF

   CALL timestop(handle)

END SUBROUTINE dbt_communicate_buffer
292 :
293 0 : END MODULE dbt_reshape_ops
|