Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Matrix multiplication for tall-and-skinny matrices.
10 : !> This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
11 : !> as long as the two smaller dimensions have the same size.
12 : !> Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
13 : !> submatrices uses DBM Cannon algorithm. Due to unknown sparsity pattern of result matrix,
14 : !> parameters (group sizes and process grid dimensions) can not be derived from matrix
15 : !> dimensions and need to be set manually.
16 : !> \author Patrick Seewald
17 : ! **************************************************************************************************
18 : MODULE dbt_tas_mm
19 : USE dbm_api, ONLY: &
20 : dbm_add, dbm_clear, dbm_copy, dbm_create, dbm_create_from_template, dbm_distribution_new, &
21 : dbm_distribution_obj, dbm_distribution_release, dbm_get_col_block_sizes, &
22 : dbm_get_distribution, dbm_get_name, dbm_get_nze, dbm_get_row_block_sizes, dbm_multiply, &
23 : dbm_redistribute, dbm_release, dbm_scale, dbm_type, dbm_zero
24 : USE dbt_tas_base, ONLY: &
25 : dbt_tas_clear, dbt_tas_copy, dbt_tas_create, dbt_tas_destroy, dbt_tas_distribution_new, &
26 : dbt_tas_filter, dbt_tas_get_info, dbt_tas_get_nze_total, dbt_tas_info, &
27 : dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
28 : dbt_tas_iterator_stop, dbt_tas_nblkcols_total, dbt_tas_nblkrows_total, dbt_tas_put_block, &
29 : dbt_tas_reserve_blocks
30 : USE dbt_tas_global, ONLY: dbt_tas_blk_size_one,&
31 : dbt_tas_default_distvec,&
32 : dbt_tas_dist_arb,&
33 : dbt_tas_dist_arb_default,&
34 : dbt_tas_dist_cyclic,&
35 : dbt_tas_distribution,&
36 : dbt_tas_rowcol_data
37 : USE dbt_tas_io, ONLY: dbt_tas_write_dist,&
38 : dbt_tas_write_matrix_info,&
39 : dbt_tas_write_split_info,&
40 : prep_output_unit
41 : USE dbt_tas_reshape_ops, ONLY: dbt_tas_merge,&
42 : dbt_tas_replicate,&
43 : dbt_tas_reshape
44 : USE dbt_tas_split, ONLY: &
45 : accept_pgrid_dims, colsplit, dbt_tas_create_split, dbt_tas_get_split_info, &
46 : dbt_tas_info_hold, dbt_tas_mp_comm, dbt_tas_release_info, default_nsplit_accept_ratio, &
47 : rowsplit
48 : USE dbt_tas_types, ONLY: dbt_tas_distribution_type,&
49 : dbt_tas_iterator,&
50 : dbt_tas_split_info,&
51 : dbt_tas_type
52 : USE dbt_tas_util, ONLY: array_eq,&
53 : swap
54 : USE kinds, ONLY: default_string_length,&
55 : dp,&
56 : int_8
57 : USE message_passing, ONLY: mp_cart_type
58 : #include "../../base/base_uses.f90"
59 :
60 : IMPLICIT NONE
61 : PRIVATE
62 :
63 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
64 :
65 : PUBLIC :: &
66 : dbt_tas_multiply, &
67 : dbt_tas_batched_mm_init, &
68 : dbt_tas_batched_mm_finalize, &
69 : dbt_tas_set_batched_state, &
70 : dbt_tas_batched_mm_complete
71 :
72 : CONTAINS
73 :
! **************************************************************************************************
!> \brief Tall-and-skinny matrix-matrix multiplication, C = alpha*op(A)*op(B) + beta*C.
!>        The largest of the three multiplication dimensions (rows of op(A), contracted dimension,
!>        columns of op(B)) selects one of three strategies:
!>        (1) A and C are split along their rows, B is replicated on every process group;
!>        (2) A and B are split along the contracted dimension, C is replicated and the partial
!>            results of all groups are merged;
!>        (3) B and C are split along their columns, A is replicated on every process group.
!>        The submatrices owned by a process group are multiplied with dbm_multiply.
!>        Dummy arguments not described in detail are forwarded to dbm_multiply
!>        (see dbm_mm, dbm_multiply_generic).
!> \param transa whether matrix_a enters the product transposed
!> \param transb whether matrix_b enters the product transposed
!> \param transc whether the product is stored transposed in matrix_c
!> \param alpha scaling factor of the product, forwarded to dbm_multiply
!> \param matrix_a left operand
!> \param matrix_b right operand
!> \param beta scaling factor applied to the existing content of matrix_c
!> \param matrix_c result matrix
!> \param optimize_dist Whether distribution should be optimized internally. In the current
!>                      implementation this guarantees optimal parameters only for dense matrices.
!> \param split_opt optionally return split info containing optimal grid and split parameters.
!>                  This can be used to choose optimal process grids for subsequent matrix
!>                  multiplications with matrices of similar shape and sparsity.
!> \param filter_eps filter threshold, forwarded to dbm_multiply and used to filter the result
!>                   (0 is passed to dbm_multiply if absent, and no filtering of the result is done)
!> \param flop on return, number of flops reported by dbm_multiply, summed over all processes and
!>             divided (rounding up) by the number of processes
!> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
!>                    (for internal use only)
!> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
!>                    (for internal use only)
!> \param retain_sparsity forwarded to dbm_multiply; if .TRUE. the intermediate result matrices are
!>                        created with the block structure of matrix_c instead of empty
!> \param simple_split if .TRUE. the split factor is not estimated from the expected occupancy of the
!>                     result matrix. If absent, this is enforced whenever one of the matrices
!>                     has a strict split.
!> \param unit_nr unit number for logging output
!> \param log_verbose only for testing: verbose output
!> \author Patrick Seewald
! **************************************************************************************************
   RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
                                         optimize_dist, split_opt, filter_eps, flop, move_data_a, &
                                         move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)

      LOGICAL, INTENT(IN)                                :: transa, transb, transc
      REAL(dp), INTENT(IN)                               :: alpha
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b
      REAL(dp), INTENT(IN)                               :: beta
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_c
      LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
      TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL    :: split_opt
      REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL         :: flop
      LOGICAL, INTENT(IN), OPTIONAL                      :: move_data_a, move_data_b, &
                                                            retain_sparsity, simple_split
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL                      :: log_verbose

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_multiply'

      INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
                 nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
                 unit_nr_prv
      INTEGER(KIND=int_8)                                :: nze_a, nze_b, nze_c, nze_c_sum
      INTEGER(KIND=int_8), DIMENSION(2)                  :: dims_a, dims_b, dims_c
      INTEGER(KIND=int_8), DIMENSION(3)                  :: dims
      INTEGER, DIMENSION(2)                              :: pdims, pdims_sub
      LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
                 simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
      REAL(KIND=dp)                                      :: filter_eps_prv
      TYPE(dbm_type)                                     :: matrix_a_mm, matrix_b_mm, matrix_c_mm
      TYPE(dbt_tas_split_info)                           :: info, info_a, info_b, info_c
      TYPE(dbt_tas_type), POINTER                        :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
                                                            matrix_b_rs, matrix_c_rep, matrix_c_rs
      TYPE(mp_cart_type)                                 :: comm_tmp, mp_comm, mp_comm_group, &
                                                            mp_comm_mm, mp_comm_opt

      CALL timeset(routineN, handle)
      ! synchronize before starting the inner timer so that it is not distorted by load imbalance
      ! of preceding work
      CALL matrix_a%dist%info%mp_comm%sync()
      CALL timeset("dbt_tas_total", handle2)

      NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(simple_split)) THEN
         simple_split_prv = simple_split
      ELSE
         simple_split_prv = .FALSE.

         ! a strict split of any of the matrices must not be changed by the split factor estimate
         info_a = dbt_tas_info(matrix_a)
         info_b = dbt_tas_info(matrix_b)
         info_c = dbt_tas_info(matrix_c)
         IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .TRUE.
      END IF

      ! result matrices are created without data unless the sparsity pattern of C must be retained
      nodata_3 = .TRUE.
      IF (PRESENT(retain_sparsity)) THEN
         IF (retain_sparsity) nodata_3 = .FALSE.
      END IF

      ! get prestored info for multiplication strategy in case of batched mm
      ! batched_repl: index (1 = A, 2 = B, 3 = C) of the matrix whose replicated copy is reused,
      ! 0 if there is none. At most one matrix may be in batched state 3.
      batched_repl = 0
      do_batched = .FALSE.
      IF (matrix_a%do_batched > 0) THEN
         do_batched = .TRUE.
         IF (matrix_a%do_batched == 3) THEN
            CPASSERT(batched_repl == 0)
            batched_repl = 1
            CALL dbt_tas_get_split_info( &
               dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
               nsplit=nsplit_batched)
            CPASSERT(nsplit_batched > 0)
            max_mm_dim_batched = 3
         END IF
      END IF

      IF (matrix_b%do_batched > 0) THEN
         do_batched = .TRUE.
         IF (matrix_b%do_batched == 3) THEN
            CPASSERT(batched_repl == 0)
            batched_repl = 2
            CALL dbt_tas_get_split_info( &
               dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
               nsplit=nsplit_batched)
            CPASSERT(nsplit_batched > 0)
            max_mm_dim_batched = 1
         END IF
      END IF

      IF (matrix_c%do_batched > 0) THEN
         do_batched = .TRUE.
         IF (matrix_c%do_batched == 3) THEN
            CPASSERT(batched_repl == 0)
            batched_repl = 3
            CALL dbt_tas_get_split_info( &
               dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
               nsplit=nsplit_batched)
            CPASSERT(nsplit_batched > 0)
            max_mm_dim_batched = 2
         END IF
      END IF

      move_a = .FALSE.
      move_b = .FALSE.

      IF (PRESENT(move_data_a)) move_a = move_data_a
      IF (PRESENT(move_data_b)) move_b = move_data_b

      ! the *_prv flags track the transposition state of the reshaped / replicated matrices
      transa_prv = transa; transb_prv = transb; transc_prv = transc

      dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
      dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
      dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
         WRITE (unit_nr_prv, "(A)") &
            "DBT TAS MATRIX MULTIPLICATION: "// &
            TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
            TRIM(dbm_get_name(matrix_b%matrix))//" = "// &
            TRIM(dbm_get_name(matrix_c%matrix))
         WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
      END IF
      IF (do_batched) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "BATCHED PROCESSING OF MATMUL"
            IF (batched_repl > 0) THEN
               WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
            END IF
         END IF
      END IF

      ! block dimensions of op(A) and op(B)
      IF (transa_prv) THEN
         CALL swap(dims_a)
      END IF

      IF (transb_prv) THEN
         CALL swap(dims_b)
      END IF

      dims_c = [dims_a(1), dims_b(2)]

      IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
         CPABORT("inconsistent matrix dimensions")
      END IF

      ! dims(1): rows of op(A), dims(2): contracted dimension, dims(3): columns of op(B)
      dims(:) = [dims_a(1), dims_a(2), dims_b(2)]

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
      END IF

      CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
      numproc = mp_comm%num_pe

      ! derive optimal matrix layout and split factor from occupancies
      nze_a = dbt_tas_get_nze_total(matrix_a)
      nze_b = dbt_tas_get_nze_total(matrix_b)

      IF (.NOT. simple_split_prv) THEN
         CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
                                          estimated_nze=nze_c, filter_eps=filter_eps, &
                                          retain_sparsity=retain_sparsity)

         max_mm_dim = MAXLOC(dims, 1)
         nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
         nsplit_opt = nsplit

         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "MM PARAMETERS"
            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
               (nze_c + numproc - 1)/numproc

            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
         END IF

      ELSEIF (batched_repl > 0) THEN
         ! reuse split factor and strategy of the stored replicated matrix
         nsplit = nsplit_batched
         nsplit_opt = nsplit
         max_mm_dim = max_mm_dim_batched
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "MM PARAMETERS"
            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
         END IF

      ELSE
         nsplit = 0
         max_mm_dim = MAXLOC(dims, 1)
      END IF

      ! reshape matrices to the optimal layout and split factor
      split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
      SELECT CASE (max_mm_dim)
      CASE (1)

         ! A and C share the large (row) dimension, B is the small matrix
         split_a = rowsplit; split_c = rowsplit
         CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
                                    new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
                                    nsplit=nsplit, &
                                    opt_nsplit=batched_repl == 0, &
                                    split_rc_1=split_a, split_rc_2=split_c, &
                                    nodata2=nodata_3, comm_new=comm_tmp, &
                                    move_data_1=move_a, unit_nr=unit_nr_prv)

         info = dbt_tas_info(matrix_a_rs)
         CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)

         new_b = .FALSE.
         ! in batched state 3 the stored replicated copy of B is reused below, no reshape needed
         IF (matrix_b%do_batched <= 2) THEN
            ALLOCATE (matrix_b_rs)
            CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
            transb_prv = .FALSE.
            new_b = .TRUE.
         END IF

         tr_case = transa_prv

         IF (unit_nr_prv > 0) THEN
            IF (.NOT. tr_case) THEN
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
            ELSE
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
            END IF
         END IF

      CASE (2)

         ! A and B share the large (contracted) dimension, C is the small matrix
         split_a = colsplit; split_b = rowsplit
         CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
                                    optimize_dist=optimize_dist, &
                                    nsplit=nsplit, &
                                    opt_nsplit=batched_repl == 0, &
                                    split_rc_1=split_a, split_rc_2=split_b, &
                                    comm_new=comm_tmp, &
                                    move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)

         info = dbt_tas_info(matrix_a_rs)
         CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)

         ! accumulate beta over the batch: first multiplication stores it, later ones multiply it in
         IF (matrix_c%do_batched == 1) THEN
            matrix_c%mm_storage%batched_beta = beta
         ELSEIF (matrix_c%do_batched > 1) THEN
            matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
         END IF

         IF (matrix_c%do_batched <= 2) THEN
            ALLOCATE (matrix_c_rs)
            CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
            transc_prv = .FALSE.

            ! just leave sparsity structure for retain sparsity but no values
            IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)

            IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
         ELSEIF (matrix_c%do_batched == 3) THEN
            matrix_c_rs => matrix_c%mm_storage%store_batched
         END IF

         new_c = matrix_c%do_batched == 0
         tr_case = transa_prv

         IF (unit_nr_prv > 0) THEN
            IF (.NOT. tr_case) THEN
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
            ELSE
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
            END IF
         END IF

      CASE (3)

         ! B and C share the large (column) dimension, A is the small matrix
         split_b = colsplit; split_c = colsplit
         CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
                                    transc_prv, optimize_dist=optimize_dist, &
                                    nsplit=nsplit, &
                                    opt_nsplit=batched_repl == 0, &
                                    split_rc_1=split_b, split_rc_2=split_c, &
                                    nodata2=nodata_3, comm_new=comm_tmp, &
                                    move_data_1=move_b, unit_nr=unit_nr_prv)
         info = dbt_tas_info(matrix_b_rs)
         CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)

         new_a = .FALSE.
         ! in batched state 3 the stored replicated copy of A is reused below, no reshape needed
         IF (matrix_a%do_batched <= 2) THEN
            ALLOCATE (matrix_a_rs)
            CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
            transa_prv = .FALSE.
            new_a = .TRUE.
         END IF

         tr_case = transb_prv

         IF (unit_nr_prv > 0) THEN
            IF (.NOT. tr_case) THEN
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
            ELSE
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
            END IF
         END IF

      END SELECT

      CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)

      numproc = mp_comm%num_pe
      pdims_sub = mp_comm_group%num_pe_cart

      ! switch to a different process grid for the DBM multiplication if the group grid is not accepted
      opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)

      IF (PRESENT(filter_eps)) THEN
         filter_eps_prv = filter_eps
      ELSE
         filter_eps_prv = 0.0_dp
      END IF

      IF (unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
         END IF
         CALL dbt_tas_write_split_info(info, unit_nr_prv)
         IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
         IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
         IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
         IF (unit_nr_prv > 0) THEN
            IF (opt_pgrid) THEN
               WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
            ELSE
               WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
            END IF
         END IF
      END IF

      ! 2d process grid on the group communicator, dimensions are chosen by create (pdims = 0)
      pdims = 0
      CALL mp_comm_mm%create(mp_comm_group, 2, pdims)

      ! Convert DBM submatrices to optimized process grids and multiply
      SELECT CASE (max_mm_dim)
      CASE (1)
         IF (matrix_b%do_batched <= 2) THEN
            ALLOCATE (matrix_b_rep)
            CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
            IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
               ! keep the replicated matrix for subsequent batched multiplications
               matrix_b%mm_storage%store_batched_repl => matrix_b_rep
               CALL dbt_tas_set_batched_state(matrix_b, state=3)
            END IF
         ELSEIF (matrix_b%do_batched == 3) THEN
            matrix_b_rep => matrix_b%mm_storage%store_batched_repl
         END IF

         IF (new_b) THEN
            CALL dbt_tas_destroy(matrix_b_rs)
            DEALLOCATE (matrix_b_rs)
         END IF
         IF (unit_nr_prv /= 0) THEN
            CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
            CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)

         ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
         info_a = dbt_tas_info(matrix_a_rs)
         CALL dbt_tas_info_hold(info_a)

         IF (new_a) THEN
            CALL dbt_tas_destroy(matrix_a_rs)
            DEALLOCATE (matrix_a_rs)
         END IF
         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
                                   move_data=matrix_b%do_batched == 0)

         info_b = dbt_tas_info(matrix_b_rep)
         CALL dbt_tas_info_hold(info_b)

         IF (matrix_b%do_batched == 0) THEN
            CALL dbt_tas_destroy(matrix_b_rep)
            DEALLOCATE (matrix_b_rep)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)

         info_c = dbt_tas_info(matrix_c_rs)
         CALL dbt_tas_info_hold(info_c)

         CALL matrix_a%dist%info%mp_comm%sync()
         CALL timeset("dbt_tas_dbm", handle4)
         IF (.NOT. tr_case) THEN
            CALL timeset("dbt_tas_mm_1N", handle3)

            CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
                              matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
                              filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
            CALL timestop(handle3)
         ELSE
            ! transposed case: operands are exchanged, (A x B)^T = B^T x A^T
            CALL timeset("dbt_tas_mm_1T", handle3)
            CALL dbm_multiply(transa=.TRUE., transb=.FALSE., alpha=alpha, &
                              matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
                              filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)

            CALL timestop(handle3)
         END IF
         CALL matrix_a%dist%info%mp_comm%sync()
         CALL timestop(handle4)

         CALL dbm_release(matrix_a_mm)
         CALL dbm_release(matrix_b_mm)

         nze_c = dbm_get_nze(matrix_c_mm)

         IF (.NOT. new_c) THEN
            CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
         ELSE
            ! beta is applied to matrix_c itself before the final summation further below
            CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
         END IF

         CALL dbm_release(matrix_c_mm)

         IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)

         IF (unit_nr_prv /= 0) THEN
            CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
         END IF

      CASE (2)
         IF (matrix_c%do_batched <= 1) THEN
            ALLOCATE (matrix_c_rep)
            CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
            IF (matrix_c%do_batched == 1) THEN
               matrix_c%mm_storage%store_batched_repl => matrix_c_rep
               CALL dbt_tas_set_batched_state(matrix_c, state=3)
            END IF
         ELSEIF (matrix_c%do_batched == 2) THEN
            ALLOCATE (matrix_c_rep)
            CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
            ! just leave sparsity structure for retain sparsity but no values
            IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
            matrix_c%mm_storage%store_batched_repl => matrix_c_rep
            CALL dbt_tas_set_batched_state(matrix_c, state=3)
         ELSEIF (matrix_c%do_batched == 3) THEN
            matrix_c_rep => matrix_c%mm_storage%store_batched_repl
         END IF

         IF (unit_nr_prv /= 0) THEN
            CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
            CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)

         ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
         info_a = dbt_tas_info(matrix_a_rs)
         CALL dbt_tas_info_hold(info_a)

         IF (new_a) THEN
            CALL dbt_tas_destroy(matrix_a_rs)
            DEALLOCATE (matrix_a_rs)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)

         info_b = dbt_tas_info(matrix_b_rs)
         CALL dbt_tas_info_hold(info_b)

         IF (new_b) THEN
            CALL dbt_tas_destroy(matrix_b_rs)
            DEALLOCATE (matrix_b_rs)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)

         info_c = dbt_tas_info(matrix_c_rep)
         CALL dbt_tas_info_hold(info_c)

         CALL matrix_a%dist%info%mp_comm%sync()
         CALL timeset("dbt_tas_dbm", handle4)
         CALL timeset("dbt_tas_mm_2", handle3)
         ! the result is a sum over nsplit partial products, hence the threshold is divided by nsplit
         CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
                           matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
                           filter_eps=filter_eps_prv/REAL(nsplit, KIND=dp), retain_sparsity=retain_sparsity, flop=flop)
         CALL matrix_a%dist%info%mp_comm%sync()
         CALL timestop(handle3)
         CALL timestop(handle4)

         CALL dbm_release(matrix_a_mm)
         CALL dbm_release(matrix_b_mm)

         nze_c = dbm_get_nze(matrix_c_mm)

         CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
         nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)

         CALL dbm_release(matrix_c_mm)

         IF (unit_nr_prv /= 0) THEN
            CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
         END IF

         IF (matrix_c%do_batched == 0) THEN
            CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
            CALL dbt_tas_destroy(matrix_c_rep)
            DEALLOCATE (matrix_c_rep)
         ELSE
            matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
         END IF

         IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)

         ! set upper limit on memory consumption for replicated matrix and complete batched mm
         ! if limit is exceeded
         IF (nze_c_sum > default_nsplit_accept_ratio*MAX(nze_a, nze_b)) THEN
            CALL dbt_tas_batched_mm_complete(matrix_c)
         END IF

      CASE (3)
         IF (matrix_a%do_batched <= 2) THEN
            ALLOCATE (matrix_a_rep)
            CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
            IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
               ! keep the replicated matrix for subsequent batched multiplications
               matrix_a%mm_storage%store_batched_repl => matrix_a_rep
               CALL dbt_tas_set_batched_state(matrix_a, state=3)
            END IF
         ELSEIF (matrix_a%do_batched == 3) THEN
            matrix_a_rep => matrix_a%mm_storage%store_batched_repl
         END IF

         IF (new_a) THEN
            CALL dbt_tas_destroy(matrix_a_rs)
            DEALLOCATE (matrix_a_rs)
         END IF
         IF (unit_nr_prv /= 0) THEN
            CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
            CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
                                   move_data=matrix_a%do_batched == 0)

         ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
         info_a = dbt_tas_info(matrix_a_rep)
         CALL dbt_tas_info_hold(info_a)

         IF (matrix_a%do_batched == 0) THEN
            CALL dbt_tas_destroy(matrix_a_rep)
            DEALLOCATE (matrix_a_rep)
         END IF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)

         info_b = dbt_tas_info(matrix_b_rs)
         CALL dbt_tas_info_hold(info_b)

         IF (new_b) THEN
            CALL dbt_tas_destroy(matrix_b_rs)
            DEALLOCATE (matrix_b_rs)
         END IF
         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)

         info_c = dbt_tas_info(matrix_c_rs)
         CALL dbt_tas_info_hold(info_c)

         CALL matrix_a%dist%info%mp_comm%sync()
         CALL timeset("dbt_tas_dbm", handle4)
         IF (.NOT. tr_case) THEN
            CALL timeset("dbt_tas_mm_3N", handle3)
            CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
                              matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
                              filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
            CALL timestop(handle3)
         ELSE
            ! transposed case: operands are exchanged, (A x B)^T = B^T x A^T
            CALL timeset("dbt_tas_mm_3T", handle3)
            CALL dbm_multiply(transa=.FALSE., transb=.TRUE., alpha=alpha, &
                              matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
                              filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
            CALL timestop(handle3)
         END IF
         CALL matrix_a%dist%info%mp_comm%sync()
         CALL timestop(handle4)

         CALL dbm_release(matrix_a_mm)
         CALL dbm_release(matrix_b_mm)

         nze_c = dbm_get_nze(matrix_c_mm)

         IF (.NOT. new_c) THEN
            CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
         ELSE
            ! beta is applied to matrix_c itself before the final summation further below
            CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
         END IF

         CALL dbm_release(matrix_c_mm)

         IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)

         IF (unit_nr_prv /= 0) THEN
            CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
         END IF
      END SELECT

      CALL mp_comm_mm%free()

      CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)

      IF (PRESENT(split_opt)) THEN
         ! nze_c holds the process-local number of elements of the DBM result at this point
         SELECT CASE (max_mm_dim)
         CASE (1, 3)
            ! C is distributed over all processes: total occupancy is the global sum
            CALL mp_comm%sum(nze_c)
         CASE (2)
            ! C is replicated once per process group: first sum within the group to get the
            ! occupancy of one replica, then take the largest replica over all groups.
            ! (Summing over the full communicator would count every replica and make the
            ! subsequent max a no-op.)
            CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
            CALL mp_comm_group%sum(nze_c)
            CALL mp_comm%max(nze_c)

         END SELECT
         nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
         ! ideally we should rederive the split factor from the actual sparsity of C, but
         ! due to parameter beta, we can not get the sparsity of AxB from DBM if not new_c
         mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
         CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "MM PARAMETERS"
            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
               (nze_c + numproc - 1)/numproc

            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
         END IF

      END IF

      IF (new_c) THEN
         ! matrix_c_rs only holds the new product: scale the old content of C and add the product
         CALL dbm_scale(matrix_c%matrix, beta)
         CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., &
                              transposed=(transc_prv .NEQV. transc), &
                              move_data=.TRUE.)
         CALL dbt_tas_destroy(matrix_c_rs)
         DEALLOCATE (matrix_c_rs)
         IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
      ELSEIF (matrix_c%do_batched > 0) THEN
         IF (matrix_c%mm_storage%batched_out) THEN
            ! remember the transposition needed when the batched result is copied back later
            matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
         END IF
      END IF

      IF (PRESENT(move_data_a)) THEN
         IF (move_data_a) CALL dbt_tas_clear(matrix_a)
      END IF
      IF (PRESENT(move_data_b)) THEN
         IF (move_data_b) CALL dbt_tas_clear(matrix_b)
      END IF

      IF (PRESENT(flop)) THEN
         ! average number of flops per process, rounded up
         CALL mp_comm%sum(flop)
         flop = (flop + numproc - 1)/numproc
      END IF

      IF (PRESENT(optimize_dist)) THEN
         IF (optimize_dist) CALL comm_tmp%free()
      END IF
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
         WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
         WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
      END IF

      ! release the references taken with dbt_tas_info_hold in the multiplication branches
      CALL dbt_tas_release_info(info_a)
      CALL dbt_tas_release_info(info_b)
      CALL dbt_tas_release_info(info_c)

      CALL matrix_a%dist%info%mp_comm%sync()
      CALL timestop(handle2)
      CALL timestop(handle)

   END SUBROUTINE dbt_tas_multiply
779 :
! **************************************************************************************************
!> \brief Scale matrix_out by alpha and add matrix_in to it.
!> \param matrix_in matrix that is added to matrix_out
!> \param matrix_out matrix that is scaled (dbm_scale) and then receives matrix_in (dbm_add)
!> \param local_copy if present and .TRUE., matrix_in is added directly; otherwise matrix_in is
!>                   first redistributed (dbm_redistribute) into a temporary created from the
!>                   template matrix_out and that temporary is added
!> \param alpha factor applied to matrix_out before the addition; the scaling call is skipped
!>              if alpha is exactly 1
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
      TYPE(dbm_type), INTENT(IN)                         :: matrix_in
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: local_copy
      REAL(dp), INTENT(IN)                               :: alpha

      LOGICAL                                            :: add_directly
      TYPE(dbm_type)                                     :: redistributed

      ! Default: go through a redistribution step unless the caller asks for a local copy.
      add_directly = .FALSE.
      IF (PRESENT(local_copy)) add_directly = local_copy

      IF (alpha /= 1.0_dp) CALL dbm_scale(matrix_out, alpha)

      IF (add_directly) THEN
         CALL dbm_add(matrix_out, matrix_in)
      ELSE
         ! Bring matrix_in into a temporary built from matrix_out before summing.
         CALL dbm_create_from_template(redistributed, name="tmp", template=matrix_out)
         CALL dbm_redistribute(matrix_in, redistributed)
         CALL dbm_add(matrix_out, redistributed)
         CALL dbm_release(redistributed)
      END IF

   END SUBROUTINE
817 :
! **************************************************************************************************
!> \brief Make sure that smallest matrix involved in a multiplication is not split and bring it to
!>        the same process grid as the other 2 matrices.
!> \param mp_comm communicator that defines Cartesian topology
!> \param matrix_in matrix to be reshaped
!> \param matrix_out newly created, unsplit (nosplit=.TRUE.) matrix on the grid of mp_comm
!> \param transposed Whether matrix_out should be transposed
!> \param nodata Data of matrix_in should not be copied to matrix_out
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
      TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(IN)                                :: transposed
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data

      CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_small'

      INTEGER                                            :: handle
      INTEGER(KIND=int_8)                                :: ncols_out, nrows_out
      INTEGER, DIMENSION(2)                              :: grid_shape
      LOGICAL                                            :: skip_data
      TYPE(dbt_tas_dist_arb)                             :: col_dist_out, row_dist_out
      TYPE(dbt_tas_distribution_type)                    :: dist_out

      CALL timeset(routineN, handle)

      skip_data = .FALSE.
      IF (PRESENT(nodata)) skip_data = nodata

      grid_shape = mp_comm%num_pe_cart

      IF (transposed) THEN
         ! Rows and columns (counts and block sizes) of matrix_in exchange their roles.
         nrows_out = dbt_tas_nblkcols_total(matrix_in)
         ncols_out = dbt_tas_nblkrows_total(matrix_in)
         row_dist_out = dbt_tas_dist_arb_default(grid_shape(1), nrows_out, matrix_in%col_blk_size)
         col_dist_out = dbt_tas_dist_arb_default(grid_shape(2), ncols_out, matrix_in%row_blk_size)
         CALL dbt_tas_distribution_new(dist_out, mp_comm, row_dist_out, col_dist_out, nosplit=.TRUE.)
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_out, &
                             matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
      ELSE
         nrows_out = dbt_tas_nblkrows_total(matrix_in)
         ncols_out = dbt_tas_nblkcols_total(matrix_in)
         row_dist_out = dbt_tas_dist_arb_default(grid_shape(1), nrows_out, matrix_in%row_blk_size)
         col_dist_out = dbt_tas_dist_arb_default(grid_shape(2), ncols_out, matrix_in%col_blk_size)
         CALL dbt_tas_distribution_new(dist_out, mp_comm, row_dist_out, col_dist_out, nosplit=.TRUE.)
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_out, &
                             matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
      END IF

      IF (.NOT. skip_data) THEN
         CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
      END IF

      CALL timestop(handle)

   END SUBROUTINE
877 :
878 : ! **************************************************************************************************
879 : !> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
880 : !> with the same split factor.
881 : !> \param matrix1_in first input matrix
882 : !> \param matrix2_in second input matrix
883 : !> \param matrix1_out points to matrix1_in or to a newly allocated matrix (see new1)
884 : !> \param matrix2_out points to matrix2_in or to a newly allocated matrix (see new2)
885 : !> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
886 : !> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
887 : !> \param trans1 transpose flag of matrix1_in for multiplication (may be flipped on return)
888 : !> \param trans2 transpose flag of matrix2_in for multiplication (may be flipped on return)
889 : !> \param optimize_dist experimental: optimize matrix splitting and distribution
890 : !> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
891 : !> \param opt_nsplit forwarded to change_split (opt_nsplit argument) for matrix 1 or the reference matrix
892 : !> \param split_rc_1 Whether to split rows or columns for matrix 1
893 : !> \param split_rc_2 Whether to split rows or columns for matrix 2
894 : !> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
895 : !> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
896 : !> \param move_data_1 memory optimization: move data such that matrix1_in may be empty on return.
897 : !> \param move_data_2 memory optimization: move data such that matrix2_in may be empty on return.
898 : !> \param comm_new returns the new communicator only if optimize_dist
899 : !> \param unit_nr output unit
900 : !> \author Patrick Seewald
901 : ! **************************************************************************************************
902 202965 : SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
903 : optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
904 : move_data_1, move_data_2, comm_new, unit_nr)
905 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix1_in, matrix2_in
906 : TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix1_out, matrix2_out
907 : LOGICAL, INTENT(OUT) :: new1, new2
908 : LOGICAL, INTENT(INOUT) :: trans1, trans2
909 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
910 : INTEGER, INTENT(IN), OPTIONAL :: nsplit
911 : LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
912 : INTEGER, INTENT(INOUT) :: split_rc_1, split_rc_2
913 : LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
914 : LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
915 : TYPE(mp_cart_type), INTENT(OUT), OPTIONAL :: comm_new
916 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
917 :
918 : CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_compatible'
919 :
920 : INTEGER :: handle, nsplit_prv, ref, split_rc_ref, &
921 : unit_nr_prv
922 : INTEGER(KIND=int_8) :: d1, d2, nze1, nze2
923 : INTEGER(KIND=int_8), DIMENSION(2) :: dims1, dims2, dims_ref
924 : INTEGER, DIMENSION(2) :: pdims
925 : LOGICAL :: nodata1_prv, nodata2_prv, &
926 : optimize_dist_prv, trans1_newdist, &
927 : trans2_newdist
928 : TYPE(dbt_tas_dist_cyclic) :: col_dist_1, col_dist_2, row_dist_1, &
929 : row_dist_2
930 1826685 : TYPE(dbt_tas_distribution_type) :: dist_1, dist_2
931 1014825 : TYPE(dbt_tas_split_info) :: split_info
932 202965 : TYPE(mp_cart_type) :: mp_comm
933 :
934 202965 : CALL timeset(routineN, handle)
935 202965 : new1 = .FALSE.; new2 = .FALSE.
936 :
! Resolve optional flags to local copies; absent nodata flags default to .FALSE.
937 202965 : IF (PRESENT(nodata1)) THEN
938 0 : nodata1_prv = nodata1
939 : ELSE
940 : nodata1_prv = .FALSE.
941 : END IF
942 :
943 202965 : IF (PRESENT(nodata2)) THEN
944 141808 : nodata2_prv = nodata2
945 : ELSE
946 : nodata2_prv = .FALSE.
947 : END IF
948 :
949 202965 : unit_nr_prv = prep_output_unit(unit_nr)
950 :
951 202965 : NULLIFY (matrix1_out, matrix2_out)
952 :
953 202965 : IF (PRESENT(optimize_dist)) THEN
954 48 : optimize_dist_prv = optimize_dist
955 : ELSE
956 : optimize_dist_prv = .FALSE.
957 : END IF
958 :
! Global block dimensions and total number of non-zero elements of both inputs.
959 608895 : dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
960 608895 : dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
961 202965 : nze1 = dbt_tas_get_nze_total(matrix1_in)
962 202965 : nze2 = dbt_tas_get_nze_total(matrix2_in)
963 :
! For a transposed operand the requested split dimension refers to the other dimension of the
! stored matrix: MOD(x, 2) + 1 maps 1 -> 2 and 2 -> 1.
964 202965 : IF (trans1) split_rc_1 = MOD(split_rc_1, 2) + 1
965 :
966 202965 : IF (trans2) split_rc_2 = MOD(split_rc_2, 2) + 1
967 :
! The matrix with more non-zero elements (matrix 1 on a tie) becomes the reference.
! dims_ref is a copy taken here, i.e. before dims1/dims2 may be swapped further below.
968 202965 : IF (nze1 >= nze2) THEN
969 191865 : ref = 1
970 191865 : split_rc_ref = split_rc_1
971 191865 : dims_ref = dims1
972 : ELSE
973 11100 : ref = 2
974 11100 : split_rc_ref = split_rc_2
975 11100 : dims_ref = dims2
976 : END IF
977 :
! An absent nsplit is treated like nsplit = 0.
978 202965 : IF (PRESENT(nsplit)) THEN
979 202965 : nsplit_prv = nsplit
980 : ELSE
981 0 : nsplit_prv = 0
982 : END IF
983 :
! comm_new is assigned in the optimize_dist branch below, so it must be present there.
984 202965 : IF (optimize_dist_prv) THEN
985 48 : CPASSERT(PRESENT(comm_new))
986 : END IF
987 :
! Case 1: distributions are already compatible. Only the split factor is adjusted; matrix 2 is
! given the split factor that was actually obtained for matrix 1 (queried after the first call).
988 202917 : IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
989 : CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
990 190647 : move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
991 190647 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
992 : CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
993 190647 : move_data=move_data_2, nodata=nodata2, opt_nsplit=.FALSE.)
994 190647 : IF (unit_nr_prv > 0) THEN
995 10 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
996 10 : TRIM(dbm_get_name(matrix1_in%matrix)), &
997 20 : "and", TRIM(dbm_get_name(matrix2_in%matrix))
998 10 : IF (new1) THEN
999 0 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1000 0 : TRIM(dbm_get_name(matrix1_in%matrix)), ": Yes"
1001 : ELSE
1002 10 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1003 20 : TRIM(dbm_get_name(matrix1_in%matrix)), ": No"
1004 : END IF
1005 10 : IF (new2) THEN
1006 0 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1007 0 : TRIM(dbm_get_name(matrix2_in%matrix)), ": Yes"
1008 : ELSE
1009 10 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1010 20 : TRIM(dbm_get_name(matrix2_in%matrix)), ": No"
1011 : END IF
1012 : END IF
1013 : ELSE
1014 :
! Case 2: optimize_dist requested. Both matrices are recreated on a new communicator.
1015 12270 : IF (optimize_dist_prv) THEN
1016 48 : IF (unit_nr_prv > 0) THEN
1017 24 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
1018 24 : TRIM(dbm_get_name(matrix1_in%matrix)), &
1019 48 : "and", TRIM(dbm_get_name(matrix2_in%matrix))
1020 : END IF
1021 :
! The new communicator and split are always created with rowsplit (see below). A matrix
! that should be split by columns is therefore stored transposed: its dims are swapped and
! its transpose flag is flipped.
1022 48 : trans1_newdist = (split_rc_1 == colsplit)
1023 48 : trans2_newdist = (split_rc_2 == colsplit)
1024 :
1025 48 : IF (trans1_newdist) THEN
1026 24 : CALL swap(dims1)
1027 24 : trans1 = .NOT. trans1
1028 : END IF
1029 :
1030 48 : IF (trans2_newdist) THEN
1031 24 : CALL swap(dims2)
1032 24 : trans2 = .NOT. trans2
1033 : END IF
1034 :
! No split factor given: use ceil(d1/d2), where d1 is the extent of the reference matrix
! along its split dimension and d2 the extent along the other dimension.
! NOTE(review): the SELECT CASE has no default branch; d1 and d2 stay undefined if
! split_rc_ref is neither rowsplit nor colsplit - confirm this cannot happen.
1035 48 : IF (nsplit_prv == 0) THEN
1036 0 : SELECT CASE (split_rc_ref)
1037 : CASE (rowsplit)
1038 0 : d1 = dims_ref(1)
1039 0 : d2 = dims_ref(2)
1040 : CASE (colsplit)
1041 0 : d1 = dims_ref(2)
1042 0 : d2 = dims_ref(1)
1043 : END SELECT
1044 0 : nsplit_prv = INT((d1 - 1)/d2 + 1)
1045 : END IF
1046 :
1047 48 : CPASSERT(nsplit_prv > 0)
1048 :
! Build the new communicator (returned to the caller via comm_new) and the matching split
! info from the communicator of matrix 1.
1049 48 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
1050 48 : comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
1051 48 : CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
1052 :
1053 144 : pdims = comm_new%num_pe_cart
1054 :
1055 : ! use a very simple cyclic distribution that may not be load balanced if block
1056 : ! sizes are not equal. However we can not use arbitrary distributions
1057 : ! for large dimensions since this would require storing distribution vectors as arrays
1058 : ! which can not be stored for large dimensions.
1059 48 : row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
1060 48 : col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))
1061 :
1062 48 : row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
1063 48 : col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))
1064 :
! Both distributions are created from the same split_info, which is released right after.
1065 48 : CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
1066 48 : CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
1067 48 : CALL dbt_tas_release_info(split_info)
1068 :
! Row and column block sizes are exchanged for a matrix that is stored transposed.
1069 240 : ALLOCATE (matrix1_out)
1070 48 : IF (.NOT. trans1_newdist) THEN
1071 : CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1072 24 : matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)
1073 :
1074 : ELSE
1075 : CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1076 24 : matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
1077 : END IF
1078 :
1079 240 : ALLOCATE (matrix2_out)
1080 48 : IF (.NOT. trans2_newdist) THEN
1081 : CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1082 24 : matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
1083 : ELSE
1084 : CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1085 24 : matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
1086 : END IF
1087 :
! Copy the data into the new matrices unless the caller asked for empty outputs.
1088 48 : IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
1089 48 : IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
1090 48 : new1 = .TRUE.
1091 48 : new2 = .TRUE.
1092 :
1093 : ELSE
! Case 3: incompatible distributions without optimize_dist. The reference matrix keeps its
! distribution (only its split may change via change_split); the other matrix is recreated
! to match it via reshape_mm_template and is therefore always new.
1094 11560 : SELECT CASE (ref)
1095 : CASE (1)
1096 11560 : IF (unit_nr_prv > 0) THEN
1097 0 : WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1098 0 : TRIM(dbm_get_name(matrix2_in%matrix))
1099 : END IF
1100 :
1101 : CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
1102 11560 : move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
1103 :
1104 57800 : ALLOCATE (matrix2_out)
1105 : CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
1106 11560 : nodata=nodata2, move_data=move_data_2)
1107 11560 : new2 = .TRUE.
1108 : CASE (2)
1109 710 : IF (unit_nr_prv > 0) THEN
1110 0 : WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1111 0 : TRIM(dbm_get_name(matrix1_in%matrix))
1112 : END IF
1113 :
1114 : CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
1115 710 : move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
1116 :
1117 3550 : ALLOCATE (matrix1_out)
1118 : CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
1119 710 : nodata=nodata1, move_data=move_data_1)
1120 25250 : new1 = .TRUE.
1121 : END SELECT
1122 : END IF
1123 : END IF
1124 :
! For every operand that was replaced by a new matrix, set its move_data flag to .TRUE. on return.
1125 202965 : IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
1126 202965 : IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.
1127 :
1128 202965 : CALL timestop(handle)
1129 :
1130 608895 : END SUBROUTINE
1131 :
! **************************************************************************************************
!> \brief Change split factor without redistribution
!> \param matrix_in matrix whose split should be changed
!> \param matrix_out points to matrix_in if nothing had to change, otherwise to a newly allocated
!>                   matrix with the row/column distributions of matrix_in and the new split
!> \param nsplit new split factor, set to 0 to not change split of matrix_in
!> \param split_rowcol split rows or columns
!> \param is_new whether matrix_out is new or a pointer to matrix_in
!> \param opt_nsplit whether nsplit should be optimized for current process grid
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!>                  Set to .TRUE. on return if data was copied into a new matrix.
!> \param nodata Data of matrix_in should not be copied to matrix_out
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix_out
      INTEGER, INTENT(IN)                                :: nsplit, split_rowcol
      LOGICAL, INTENT(OUT)                               :: is_new
      LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
      LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata

      CHARACTER(LEN=*), PARAMETER :: routineN = 'change_split'

      CHARACTER(len=default_string_length)               :: mat_name
      INTEGER                                            :: handle, nsplit_cur, nsplit_got, &
                                                            nsplit_req, split_cur
      LOGICAL                                            :: copy_data
      CLASS(dbt_tas_distribution), ALLOCATABLE           :: col_dist, row_dist
      CLASS(dbt_tas_rowcol_data), ALLOCATABLE            :: col_bsize, row_bsize
      TYPE(dbt_tas_distribution_type)                    :: dist_new
      TYPE(dbt_tas_split_info)                           :: info_new
      TYPE(mp_cart_type)                                 :: comm

      NULLIFY (matrix_out)

      CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=comm, &
                                  split_rowcol=split_cur, nsplit=nsplit_cur)

      ! Nothing requested: keep the split factor and the split dimension already matches.
      ! This path returns before the timer is started.
      IF (nsplit == 0 .AND. split_rowcol == split_cur) THEN
         matrix_out => matrix_in
         is_new = .FALSE.
         RETURN
      END IF

      ! nsplit == 0 with a different split dimension falls back to a split factor of 1.
      nsplit_req = nsplit
      IF (nsplit == 0) nsplit_req = 1

      CALL timeset(routineN, handle)

      copy_data = .TRUE.
      IF (PRESENT(nodata)) copy_data = .NOT. nodata

      CALL dbt_tas_get_info(matrix_in, name=mat_name, &
                            row_blk_size=row_bsize, col_blk_size=col_bsize, &
                            proc_row_dist=row_dist, proc_col_dist=col_dist)

      ! Create the candidate split and read back the split factor that was actually chosen.
      CALL dbt_tas_create_split(info_new, comm, split_rowcol, nsplit_req, opt_nsplit=opt_nsplit)
      CALL dbt_tas_get_split_info(info_new, nsplit=nsplit_got)

      is_new = .NOT. (nsplit_cur == nsplit_got .AND. split_cur == split_rowcol)

      IF (is_new) THEN
         CALL dbt_tas_distribution_new(dist_new, comm, row_dist, col_dist, &
                                       split_info=info_new)
         CALL dbt_tas_release_info(info_new)

         ALLOCATE (matrix_out)
         CALL dbt_tas_create(matrix_out, mat_name, dist_new, row_bsize, col_bsize, own_dist=.TRUE.)

         IF (copy_data) THEN
            CALL dbt_tas_copy(matrix_out, matrix_in)
            IF (PRESENT(move_data)) THEN
               IF (move_data) CALL dbt_tas_clear(matrix_in)
               move_data = .TRUE.
            END IF
         END IF
      ELSE
         ! The resulting split equals the existing one: reuse matrix_in.
         matrix_out => matrix_in
         CALL dbt_tas_release_info(info_new)
      END IF

      CALL timestop(handle)
   END SUBROUTINE
1224 :
! **************************************************************************************************
!> \brief Check whether matrices have same distribution and same split.
!> \param mat_a first matrix
!> \param mat_b second matrix
!> \param split_rc_a split dimension required for mat_a
!> \param split_rc_b split dimension required for mat_b
!> \param unit_nr output unit for diagnostic messages
!> \return .TRUE. only if both matrices are split along the requested (and identical) dimension,
!>         their Cartesian grids have equal dimensions and, on every rank, the local rows
!>         (rowsplit) or local columns (colsplit) of both matrices are equal
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
      TYPE(dbt_tas_type), INTENT(IN)                     :: mat_a, mat_b
      INTEGER, INTENT(IN)                                :: split_rc_a, split_rc_b
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      LOGICAL                                            :: dist_compatible

      INTEGER                                            :: iounit, n_agree, nranks, split_a, &
                                                            split_b
      INTEGER(int_8), ALLOCATABLE, DIMENSION(:)          :: mine_a, mine_b
      INTEGER, DIMENSION(2)                              :: grid_a, grid_b
      LOGICAL                                            :: same_layout
      TYPE(dbt_tas_split_info)                           :: info_a, info_b

      iounit = prep_output_unit(unit_nr)

      dist_compatible = .FALSE.

      info_a = dbt_tas_info(mat_a)
      info_b = dbt_tas_info(mat_b)
      CALL dbt_tas_get_split_info(info_a, split_rowcol=split_a)
      CALL dbt_tas_get_split_info(info_b, split_rowcol=split_b)

      ! Both matrices must already be split along the requested, common dimension.
      same_layout = (split_b == split_rc_b) .AND. &
                    (split_a == split_rc_a) .AND. &
                    (split_rc_a == split_rc_b)
      IF (.NOT. same_layout) THEN
         IF (iounit > 0) THEN
            WRITE (iounit, *) "matrix layout a not compatible", split_a, split_rc_a
            WRITE (iounit, *) "matrix layout b not compatible", split_b, split_rc_b
         END IF
         RETURN
      END IF

      ! check if communicators are equivalent
      ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
      ! It's sufficient to check dimensions of global grid, subgrids will be determined later on (change_split)
      nranks = info_b%mp_comm%num_pe
      grid_a = info_a%mp_comm%num_pe_cart
      grid_b = info_b%mp_comm%num_pe_cart
      IF (.NOT. array_eq(grid_a, grid_b)) THEN
         IF (iounit > 0) THEN
            WRITE (iounit, *) "mp dims not compatible:", grid_a, "|", grid_b
         END IF
         RETURN
      END IF

      ! check that distribution is the same by comparing local rows / columns for each matrix
      IF (split_rc_a == rowsplit) THEN
         CALL dbt_tas_get_info(mat_a, local_rows=mine_a)
         CALL dbt_tas_get_info(mat_b, local_rows=mine_b)
      ELSE IF (split_rc_a == colsplit) THEN
         CALL dbt_tas_get_info(mat_a, local_cols=mine_a)
         CALL dbt_tas_get_info(mat_b, local_cols=mine_b)
      END IF

      ! Each rank contributes 1 if its local index lists agree; the summed count must
      ! equal the number of ranks.
      IF (array_eq(mine_a, mine_b)) THEN
         n_agree = 1
      ELSE
         n_agree = 0
      END IF

      CALL info_a%mp_comm%sum(n_agree)

      dist_compatible = (n_agree == nranks)

      IF ((.NOT. dist_compatible) .AND. (iounit > 0)) THEN
         WRITE (iounit, *) "local rowcols not compatible"
         WRITE (iounit, *) "local rowcols A", mine_a
         WRITE (iounit, *) "local rowcols B", mine_b
      END IF

   END FUNCTION
1302 :
! **************************************************************************************************
!> \brief Reshape matrix_in s.t. it has same process grid, distribution and split as template
!> \param template matrix whose split info, communicator and distribution along its split
!>                 dimension are reused for matrix_out
!> \param matrix_in matrix to be reshaped
!> \param matrix_out newly created matrix; receives the data of matrix_in unless nodata is set
!> \param trans transpose flag of matrix_in; flipped on return if matrix_out is stored transposed
!> \param split_rc dimension along which matrix_in should be split; if it differs from the
!>                 split dimension of template, matrix_out is created as the transpose of matrix_in
!> \param nodata if present and .TRUE., no data is copied from matrix_in to matrix_out
!> \param move_data forwarded to dbt_tas_reshape
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
      TYPE(dbt_tas_type), INTENT(IN)                     :: template
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(INOUT)                             :: trans
      INTEGER, INTENT(IN)                                :: split_rc
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data

      CLASS(dbt_tas_distribution), ALLOCATABLE           :: row_dist, col_dist

      TYPE(dbt_tas_distribution_type)                    :: dist_new
      TYPE(dbt_tas_split_info)                           :: info_template
      INTEGER                                            :: dim_split_template, handle
      INTEGER, DIMENSION(2)                              :: pdims
      LOGICAL                                            :: nodata_prv, transposed
      TYPE(mp_cart_type)                                 :: mp_comm
      CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_template'

      CALL timeset(routineN, handle)

      nodata_prv = .FALSE.
      IF (PRESENT(nodata)) nodata_prv = nodata

      info_template = dbt_tas_info(template)
      dim_split_template = info_template%split_rowcol

      ! If the requested split dimension differs from that of the template, the matrix is
      ! stored transposed so that its split dimension matches the template.
      transposed = dim_split_template /= split_rc
      IF (transposed) trans = .NOT. trans

      pdims = info_template%mp_comm%num_pe_cart

      ! Along the split dimension the distribution of the template is copied; along the other
      ! dimension a default distribution is generated from the block sizes of matrix_in
      ! (rows and columns of matrix_in exchange roles if transposed).
      SELECT CASE (dim_split_template)
      CASE (1)
         ALLOCATE (row_dist, source=template%dist%row_dist)
         IF (.NOT. transposed) THEN
            ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
         ELSE
            ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
         END IF
      CASE (2)
         IF (.NOT. transposed) THEN
            ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
         ELSE
            ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
         END IF
         ALLOCATE (col_dist, source=template%dist%col_dist)
      CASE DEFAULT
         ! Without this branch row_dist / col_dist would be passed on unallocated.
         CPABORT("reshape_mm_template: invalid split dimension")
      END SELECT

      CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
      CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
      IF (.NOT. transposed) THEN
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
                             matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
      ELSE
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
                             matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
      END IF

      IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)

      CALL timestop(handle)

   END SUBROUTINE
1386 :
! **************************************************************************************************
!> \brief Estimate sparsity pattern of C resulting from A x B = C
!>        by multiplying the block norms of A and B. Same dummy arguments as dbt_tas_multiply
!> \param transa transpose flag of matrix_a, forwarded to dbt_tas_multiply
!> \param transb transpose flag of matrix_b, forwarded to dbt_tas_multiply
!> \param transc transpose flag of matrix_c, forwarded to dbt_tas_multiply
!> \param matrix_a first factor
!> \param matrix_b second factor
!> \param matrix_c result matrix; only its block sizes (and, with retain_sparsity, its present
!>                 blocks) are used
!> \param estimated_nze estimated number of non-zero elements of C, summed over the communicator
!>                      of matrix_a
!> \param filter_eps forwarded to the multiplication of the block-norm matrices
!> \param unit_nr output unit, forwarded to dbt_tas_multiply
!> \param retain_sparsity if present and .TRUE., no multiplication is performed and the blocks
!>                        currently present in matrix_c are counted instead
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
                                          estimated_nze, filter_eps, unit_nr, retain_sparsity)
      LOGICAL, INTENT(IN)                                :: transa, transb, transc
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b, matrix_c
      INTEGER(int_8), INTENT(OUT)                        :: estimated_nze
      REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL                      :: retain_sparsity

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_estimate_result_nze'

      INTEGER                                            :: col_size, handle, row_size
      INTEGER(int_8)                                     :: col, row
      LOGICAL                                            :: retain_sparsity_prv
      TYPE(dbt_tas_iterator)                             :: iter
      TYPE(dbt_tas_type), POINTER                        :: matrix_a_bnorm, matrix_b_bnorm, &
                                                            matrix_c_bnorm
      TYPE(mp_cart_type)                                 :: mp_comm

      CALL timeset(routineN, handle)

      retain_sparsity_prv = .FALSE.
      IF (PRESENT(retain_sparsity)) retain_sparsity_prv = retain_sparsity

      IF (.NOT. retain_sparsity_prv) THEN
         ! Multiply the block-norm matrices of A and B; the block pattern of the product serves
         ! as the estimated block pattern of C.
         ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
         CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
         CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
         CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)

         CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
                               matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
                               filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
                               simple_split=.TRUE., unit_nr=unit_nr)
         CALL dbt_tas_destroy(matrix_a_bnorm)
         CALL dbt_tas_destroy(matrix_b_bnorm)

         DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
      ELSE
         ! Keep the existing pattern: iterate over the blocks already present in matrix_c.
         matrix_c_bnorm => matrix_c
      END IF

      ! Every block of the pattern contributes row_size*col_size elements, with the block sizes
      ! taken from matrix_c.
      estimated_nze = 0
!$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
!$OMP PRIVATE(iter,row,col,row_size,col_size)
      CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
      DO WHILE (dbt_tas_iterator_blocks_left(iter))
         CALL dbt_tas_iterator_next_block(iter, row, col)
         row_size = matrix_c%row_blk_size%data(row)
         col_size = matrix_c%col_blk_size%data(col)
         ! Promote before multiplying so that the product of two default integers can not
         ! overflow for large blocks.
         estimated_nze = estimated_nze + INT(row_size, KIND=int_8)*INT(col_size, KIND=int_8)
      END DO
      CALL dbt_tas_iterator_stop(iter)
!$OMP END PARALLEL

      CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
      CALL mp_comm%sum(estimated_nze)

      IF (.NOT. retain_sparsity_prv) THEN
         CALL dbt_tas_destroy(matrix_c_bnorm)
         DEALLOCATE (matrix_c_bnorm)
      END IF

      CALL timestop(handle)

   END SUBROUTINE
1471 :
! **************************************************************************************************
!> \brief Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
!>        This estimate is based on the minimization of communication volume whereby the
!>        communication of CARMA n-split step and CANNON-multiplication of submatrices are considered.
!> \param max_mm_dim selects which matrix counts as the small one: 1 -> B, 2 -> C, 3 -> A;
!>                   any other value aborts
!> \param nze_a number of non-zeroes in A
!> \param nze_b number of non-zeroes in B
!> \param nze_c number of non-zeroes in C
!> \param numnodes number of MPI ranks
!> \return estimated split factor: rounded ratio of the larger occupancy of the two remaining
!>         matrices to the occupancy of the small matrix, capped by numnodes; a result of 0 is
!>         replaced by 1
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
      INTEGER, INTENT(IN)                                :: max_mm_dim
      INTEGER(KIND=int_8), INTENT(IN)                    :: nze_a, nze_b, nze_c
      INTEGER, INTENT(IN)                                :: numnodes
      INTEGER                                            :: nsplit

      REAL(dp), PARAMETER                                :: tuning = 1.0_dp ! Could be further tuned.

      INTEGER(KIND=int_8)                                :: nze_large, nze_small, ratio

      ! All occupancies are clamped to at least 1 so that the ratio below is well defined.
      SELECT CASE (max_mm_dim)
      CASE (1)
         nze_small = MAX(nze_b, 1_int_8)
         nze_large = MAX(nze_a, nze_c, 1_int_8)
      CASE (2)
         nze_small = MAX(nze_c, 1_int_8)
         nze_large = MAX(nze_a, nze_b, 1_int_8)
      CASE (3)
         nze_small = MAX(nze_a, 1_int_8)
         nze_large = MAX(nze_b, nze_c, 1_int_8)
      CASE DEFAULT
         CPABORT("")
      END SELECT

      ratio = NINT(REAL(nze_large, dp)/(REAL(nze_small, dp)*tuning), KIND=int_8)
      nsplit = INT(MIN(INT(numnodes, KIND=int_8), ratio))
      IF (nsplit == 0) nsplit = 1

   END FUNCTION
1513 :
! **************************************************************************************************
!> \brief Create a matrix with block sizes one that contains the block norms of matrix_in
!>        matrix_out is created with the name and the distribution of matrix_in and with the same
!>        number of block rows / block columns. Unless nodata is set, every block of matrix_in is
!>        mapped to a 1x1 block holding NORM2 of the source block.
!> \param matrix_in matrix whose block norms are computed (must be valid)
!> \param matrix_out newly created matrix of block norms
!> \param nodata if present and .TRUE., only create matrix_out and do not reserve or fill any blocks
!>               (default: .FALSE.)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata

      CHARACTER(len=default_string_length)               :: mat_name
      INTEGER(KIND=int_8)                                :: icol, irow, ncols_blk, nrows_blk
      LOGICAL                                            :: skip_data
      REAL(dp), DIMENSION(1, 1)                          :: norm_blk
      REAL(dp), DIMENSION(:, :), POINTER                 :: src_blk
      TYPE(dbt_tas_blk_size_one)                         :: unit_col_sizes, unit_row_sizes
      TYPE(dbt_tas_iterator)                             :: blk_iter

      CPASSERT(matrix_in%valid)

      skip_data = .FALSE.
      IF (PRESENT(nodata)) skip_data = nodata

      ! same block grid as the input, but every block has size one
      CALL dbt_tas_get_info(matrix_in, name=mat_name, nblkrows_total=nrows_blk, nblkcols_total=ncols_blk)
      unit_row_sizes = dbt_tas_blk_size_one(nrows_blk)
      unit_col_sizes = dbt_tas_blk_size_one(ncols_blk)

      ! not sure if assumption that same distribution can be taken still holds
      CALL dbt_tas_create(matrix_out, mat_name, matrix_in%dist, unit_row_sizes, unit_col_sizes)

      IF (skip_data) RETURN

      ! reserve the blocks of matrix_in in matrix_out, then fill in one norm per block
      CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
!$OMP PRIVATE(blk_iter,irow,icol,src_blk,norm_blk)
      CALL dbt_tas_iterator_start(blk_iter, matrix_in)
      DO WHILE (dbt_tas_iterator_blocks_left(blk_iter))
         CALL dbt_tas_iterator_next_block(blk_iter, irow, icol, src_blk)
         norm_blk(1, 1) = NORM2(src_blk)
         CALL dbt_tas_put_block(matrix_out, irow, icol, norm_blk)
      END DO
      CALL dbt_tas_iterator_stop(blk_iter)
!$OMP END PARALLEL

   END SUBROUTINE create_block_norms_matrix
1566 :
! **************************************************************************************************
!> \brief Convert a DBM matrix to a new process grid
!>        If optimize_pgrid is .FALSE., matrix_out is created from matrix_in as template (and the
!>        data is copied unless nodata is set). Otherwise a new distribution is built on
!>        mp_comm_cart from default distribution vectors and the data is redistributed.
!> \param mp_comm_cart new process grid
!> \param matrix_in matrix to convert
!> \param matrix_out newly created matrix
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!>                  NOTE(review): only acted upon when the process grid is changed; the
!>                  template-copy path (optimize_pgrid = .FALSE.) leaves matrix_in untouched.
!> \param nodata Data of matrix_in should not be copied to matrix_out (default: .FALSE.)
!> \param optimize_pgrid Whether to change process grid (default: .TRUE.)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
      TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm_cart
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_in
      TYPE(dbm_type), INTENT(OUT)                        :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: move_data, nodata, optimize_pgrid

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_to_new_pgrid'

      CHARACTER(len=default_string_length)               :: mat_name
      INTEGER                                            :: handle, n_blk_cols, n_blk_rows
      INTEGER, CONTIGUOUS, DIMENSION(:), POINTER         :: col_distvec, col_sizes, row_distvec, &
                                                            row_sizes
      INTEGER, DIMENSION(2)                              :: grid_dims
      LOGICAL                                            :: change_grid, skip_data
      TYPE(dbm_distribution_obj)                         :: dist_in, dist_new

      NULLIFY (row_distvec, col_distvec, row_sizes, col_sizes)

      CALL timeset(routineN, handle)

      ! defaults of the optional switches
      change_grid = .TRUE.
      IF (PRESENT(optimize_pgrid)) change_grid = optimize_pgrid

      skip_data = .FALSE.
      IF (PRESENT(nodata)) skip_data = nodata

      mat_name = dbm_get_name(matrix_in)

      IF (change_grid) THEN
         row_sizes => dbm_get_row_block_sizes(matrix_in)
         col_sizes => dbm_get_col_block_sizes(matrix_in)
         n_blk_rows = SIZE(row_sizes)
         n_blk_cols = SIZE(col_sizes)
         ! NOTE(review): the distribution of matrix_in is queried but not used below.
         dist_in = dbm_get_distribution(matrix_in)
         grid_dims = mp_comm_cart%num_pe_cart

         ! default distribution vectors for the dimensions of the new grid
         ALLOCATE (row_distvec(n_blk_rows), col_distvec(n_blk_cols))
         CALL dbt_tas_default_distvec(n_blk_rows, grid_dims(1), row_sizes, row_distvec)
         CALL dbt_tas_default_distvec(n_blk_cols, grid_dims(2), col_sizes, col_distvec)

         CALL dbm_distribution_new(dist_new, mp_comm_cart, row_distvec, col_distvec)
         DEALLOCATE (row_distvec, col_distvec)

         CALL dbm_create(matrix_out, mat_name, dist_new, row_sizes, col_sizes)
         CALL dbm_distribution_release(dist_new)

         IF (.NOT. skip_data) THEN
            CALL dbm_redistribute(matrix_in, matrix_out)
            ! PRESENT must be tested separately: an absent optional may not be referenced
            IF (PRESENT(move_data)) THEN
               IF (move_data) CALL dbm_clear(matrix_in)
            END IF
         END IF
      ELSE
         ! keep the process grid: matrix_out is created from matrix_in as template
         CALL dbm_create_from_template(matrix_out, name=mat_name, template=matrix_in)
         IF (.NOT. skip_data) CALL dbm_copy(matrix_out, matrix_in)
      END IF

      CALL timestop(handle)
   END SUBROUTINE convert_to_new_pgrid
1643 :
! **************************************************************************************************
!> \brief Put a matrix into batched multiplication mode: allocate its mm_storage with the
!>        batched_out flag cleared and set the batched state to 1.
!> \param matrix matrix to prepare for batched multiplication
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_init(matrix)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix

      ! fresh storage, no batched result accumulated yet
      ALLOCATE (matrix%mm_storage)
      matrix%mm_storage%batched_out = .FALSE.

      ! state 1: batched MM but mm_storage not yet initialized
      CALL dbt_tas_set_batched_state(matrix, state=1)
   END SUBROUTINE dbt_tas_batched_mm_init
1656 :
! **************************************************************************************************
!> \brief Leave batched multiplication mode: apply the deferred scaling of a batched result,
!>        complete the pending batched multiplication, release mm_storage and reset the batched
!>        state to 0. Does nothing (apart from synchronization and timing) if the matrix is not
!>        in batched mode.
!> \param matrix matrix for which batched multiplication is finalized
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix

      INTEGER                                            :: handle

      CALL matrix%dist%info%mp_comm%sync()
      CALL timeset("dbt_tas_total", handle)

      IF (matrix%do_batched == 0) THEN
         ! nothing to finalize, but the timer started above has to be stopped before returning
         CALL timestop(handle)
         RETURN
      END IF

      ! scaling of a batched result by beta was deferred, apply it before merging the result
      IF (matrix%mm_storage%batched_out) THEN
         CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
      END IF

      CALL dbt_tas_batched_mm_complete(matrix)

      ! storage is released as a whole, no need to reset individual flags beforehand
      DEALLOCATE (matrix%mm_storage)
      CALL dbt_tas_set_batched_state(matrix, state=0)

      CALL matrix%dist%info%mp_comm%sync()
      CALL timestop(handle)

   END SUBROUTINE dbt_tas_batched_mm_finalize
1687 :
! **************************************************************************************************
!> \brief set state flags during batched multiplication
!>        Besides storing the state, the first strict_split flag of the split info is updated:
!>        it is forced to .TRUE. whenever opt_grid is passed and for states 2 and 3; for states
!>        0 and 1 it is .TRUE. if the matrix has an optimized process grid and otherwise reset to
!>        the value kept in the second strict_split flag.
!> \param matrix matrix whose flags are updated
!> \param state 0 no batched MM
!>              1 batched MM but mm_storage not yet initialized
!>              2 batched MM and mm_storage requires update
!>              3 batched MM and mm_storage initialized
!>              any other value aborts
!> \param opt_grid whether process grid was already optimized and should not be changed
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      INTEGER, INTENT(IN), OPTIONAL                      :: state
      LOGICAL, INTENT(IN), OPTIONAL                      :: opt_grid

      ASSOCIATE (split_flags => matrix%dist%info%strict_split)

         IF (PRESENT(opt_grid)) THEN
            matrix%has_opt_pgrid = opt_grid
            split_flags(1) = .TRUE.
         END IF

         IF (PRESENT(state)) THEN
            matrix%do_batched = state
            IF (state == 0 .OR. state == 1) THEN
               ! reset to default
               split_flags(1) = matrix%has_opt_pgrid .OR. split_flags(2)
            ELSE IF (state == 2 .OR. state == 3) THEN
               split_flags(1) = .TRUE.
            ELSE
               CPABORT("should not happen")
            END IF
         END IF

      END ASSOCIATE
   END SUBROUTINE dbt_tas_set_batched_state
1725 :
! **************************************************************************************************
!> \brief Complete a pending batched multiplication: if a batched result is held in mm_storage
!>        (state 3 with batched_out set), merge it and add it into matrix, then release the
!>        stored matrices. On return the batched state is 2 (mm_storage requires update).
!>        Does nothing if the matrix is not in batched mode.
!> \param matrix matrix taking part in a batched multiplication
!> \param warn if present and .TRUE., emit a warning when mm_storage was initialized (state 3)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      LOGICAL, INTENT(IN), OPTIONAL                      :: warn

      LOGICAL                                            :: warn_prv

      IF (matrix%do_batched == 0) RETURN

      warn_prv = .FALSE.
      IF (PRESENT(warn)) warn_prv = warn

      IF (warn_prv .AND. matrix%do_batched == 3) THEN
         CALL cp_warn(__LOCATION__, &
                      "Optimizations for batched multiplication are disabled because of conflicting data access")
      END IF

      IF (matrix%mm_storage%batched_out .AND. matrix%do_batched == 3) THEN
         ! merge the replicated batched result, then sum it into matrix
         CALL dbt_tas_merge(matrix%mm_storage%store_batched%matrix, &
                            matrix%mm_storage%store_batched_repl, move_data=.TRUE.)

         CALL dbt_tas_reshape(matrix%mm_storage%store_batched, matrix, summation=.TRUE., &
                              transposed=matrix%mm_storage%batched_trans, move_data=.TRUE.)
         CALL dbt_tas_destroy(matrix%mm_storage%store_batched)
         DEALLOCATE (matrix%mm_storage%store_batched)
      END IF

      IF (ASSOCIATED(matrix%mm_storage%store_batched_repl)) THEN
         CALL dbt_tas_destroy(matrix%mm_storage%store_batched_repl)
         DEALLOCATE (matrix%mm_storage%store_batched_repl)
      END IF

      CALL dbt_tas_set_batched_state(matrix, state=2)

   END SUBROUTINE dbt_tas_batched_mm_complete
1764 :
1765 1670268 : END MODULE
|