Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Matrix multiplication for tall-and-skinny matrices.
10 : !> This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
11 : !> as long as the two smaller dimensions have the same size.
12 : !> Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
13 : !> submatrices uses DBM Cannon algorithm. Due to unknown sparsity pattern of result matrix,
14 : !> parameters (group sizes and process grid dimensions) can not be derived from matrix
15 : !> dimensions and need to be set manually.
16 : !> \author Patrick Seewald
17 : ! **************************************************************************************************
18 : MODULE dbt_tas_mm
19 : USE dbm_api, ONLY: &
20 : dbm_add, dbm_clear, dbm_copy, dbm_create, dbm_create_from_template, dbm_distribution_new, &
21 : dbm_distribution_obj, dbm_distribution_release, dbm_get_col_block_sizes, &
22 : dbm_get_distribution, dbm_get_name, dbm_get_nze, dbm_get_row_block_sizes, dbm_multiply, &
23 : dbm_redistribute, dbm_release, dbm_scale, dbm_type, dbm_zero
24 : USE dbt_tas_base, ONLY: &
25 : dbt_tas_clear, dbt_tas_copy, dbt_tas_create, dbt_tas_destroy, dbt_tas_distribution_new, &
26 : dbt_tas_filter, dbt_tas_get_info, dbt_tas_get_nze_total, dbt_tas_info, &
27 : dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
28 : dbt_tas_iterator_stop, dbt_tas_nblkcols_total, dbt_tas_nblkrows_total, dbt_tas_put_block, &
29 : dbt_tas_reserve_blocks
30 : USE dbt_tas_global, ONLY: dbt_tas_blk_size_one,&
31 : dbt_tas_default_distvec,&
32 : dbt_tas_dist_arb,&
33 : dbt_tas_dist_arb_default,&
34 : dbt_tas_dist_cyclic,&
35 : dbt_tas_distribution,&
36 : dbt_tas_rowcol_data
37 : USE dbt_tas_io, ONLY: dbt_tas_write_dist,&
38 : dbt_tas_write_matrix_info,&
39 : dbt_tas_write_split_info,&
40 : prep_output_unit
41 : USE dbt_tas_reshape_ops, ONLY: dbt_tas_merge,&
42 : dbt_tas_replicate,&
43 : dbt_tas_reshape
44 : USE dbt_tas_split, ONLY: &
45 : accept_pgrid_dims, colsplit, dbt_tas_create_split, dbt_tas_get_split_info, &
46 : dbt_tas_info_hold, dbt_tas_mp_comm, dbt_tas_release_info, default_nsplit_accept_ratio, &
47 : rowsplit
48 : USE dbt_tas_types, ONLY: dbt_tas_distribution_type,&
49 : dbt_tas_iterator,&
50 : dbt_tas_split_info,&
51 : dbt_tas_type
52 : USE dbt_tas_util, ONLY: array_eq,&
53 : swap
54 : USE kinds, ONLY: default_string_length,&
55 : dp,&
56 : int_8
57 : USE message_passing, ONLY: mp_cart_type
58 : #include "../../base/base_uses.f90"
59 :
60 : IMPLICIT NONE
61 : PRIVATE
62 :
63 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
64 :
65 : PUBLIC :: &
66 : dbt_tas_multiply, &
67 : dbt_tas_batched_mm_init, &
68 : dbt_tas_batched_mm_finalize, &
69 : dbt_tas_set_batched_state, &
70 : dbt_tas_batched_mm_complete
71 :
72 : CONTAINS
73 :
74 : ! **************************************************************************************************
75 : !> \brief tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical
76 : !> to arguments of dbm_multiply (see dbm_mm, dbm_multiply_generic).
77 : !> \param transa ...
78 : !> \param transb ...
79 : !> \param transc ...
80 : !> \param alpha ...
81 : !> \param matrix_a ...
82 : !> \param matrix_b ...
83 : !> \param beta ...
84 : !> \param matrix_c ...
85 : !> \param optimize_dist Whether distribution should be optimized internally. In the current
86 : !> implementation this guarantees optimal parameters only for dense matrices.
87 : !> \param split_opt optionally return split info containing optimal grid and split parameters.
88 : !> This can be used to choose optimal process grids for subsequent matrix
89 : !> multiplications with matrices of similar shape and sparsity.
90 : !> \param filter_eps ...
91 : !> \param flop ...
92 : !> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
93 : !> (for internal use only)
94 : !> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
95 : !> (for internal use only)
96 : !> \param retain_sparsity ...
97 : !> \param simple_split ...
98 : !> \param unit_nr unit number for logging output
99 : !> \param log_verbose only for testing: verbose output
100 : !> \author Patrick Seewald
101 : ! **************************************************************************************************
102 849060 : RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
103 : optimize_dist, split_opt, filter_eps, flop, move_data_a, &
104 : move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
105 :
106 : LOGICAL, INTENT(IN) :: transa, transb, transc
107 : REAL(dp), INTENT(IN) :: alpha
108 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b
109 : REAL(dp), INTENT(IN) :: beta
110 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_c
111 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
112 : TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL :: split_opt
113 : REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
114 : INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
115 : LOGICAL, INTENT(IN), OPTIONAL :: move_data_a, move_data_b, &
116 : retain_sparsity, simple_split
117 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
118 : LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
119 :
120 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_multiply'
121 :
122 : INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
123 : nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
124 : unit_nr_prv
125 : INTEGER(KIND=int_8) :: nze_a, nze_b, nze_c, nze_c_sum
126 : INTEGER(KIND=int_8), DIMENSION(2) :: dims_a, dims_b, dims_c
127 : INTEGER(KIND=int_8), DIMENSION(3) :: dims
128 : INTEGER, DIMENSION(2) :: pdims, pdims_sub
129 : LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
130 : simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
131 : REAL(KIND=dp) :: filter_eps_prv
132 : TYPE(dbm_type) :: matrix_a_mm, matrix_b_mm, matrix_c_mm
133 3432045 : TYPE(dbt_tas_split_info) :: info, info_a, info_b, info_c
134 : TYPE(dbt_tas_type), POINTER :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
135 : matrix_b_rs, matrix_c_rep, matrix_c_rs
136 201885 : TYPE(mp_cart_type) :: comm_tmp, mp_comm, mp_comm_group, &
137 201885 : mp_comm_mm, mp_comm_opt
138 :
139 201885 : CALL timeset(routineN, handle)
140 201885 : CALL matrix_a%dist%info%mp_comm%sync()
141 201885 : CALL timeset("dbt_tas_total", handle2)
142 :
143 201885 : NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
144 :
145 201885 : unit_nr_prv = prep_output_unit(unit_nr)
146 :
147 201885 : IF (PRESENT(simple_split)) THEN
148 60327 : simple_split_prv = simple_split
149 : ELSE
150 141558 : simple_split_prv = .FALSE.
151 :
152 424674 : info_a = dbt_tas_info(matrix_a); info_b = dbt_tas_info(matrix_b); info_c = dbt_tas_info(matrix_c)
153 141558 : IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .TRUE.
154 : END IF
155 :
156 201885 : nodata_3 = .TRUE.
157 201885 : IF (PRESENT(retain_sparsity)) THEN
158 4762 : IF (retain_sparsity) nodata_3 = .FALSE.
159 : END IF
160 :
161 : ! get prestored info for multiplication strategy in case of batched mm
162 201885 : batched_repl = 0
163 201885 : do_batched = .FALSE.
164 201885 : IF (matrix_a%do_batched > 0) THEN
165 40322 : do_batched = .TRUE.
166 40322 : IF (matrix_a%do_batched == 3) THEN
167 : CPASSERT(batched_repl == 0)
168 13249 : batched_repl = 1
169 : CALL dbt_tas_get_split_info( &
170 : dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
171 13249 : nsplit=nsplit_batched)
172 13249 : CPASSERT(nsplit_batched > 0)
173 : max_mm_dim_batched = 3
174 : END IF
175 : END IF
176 :
177 201885 : IF (matrix_b%do_batched > 0) THEN
178 14380 : do_batched = .TRUE.
179 14380 : IF (matrix_b%do_batched == 3) THEN
180 2960 : CPASSERT(batched_repl == 0)
181 2960 : batched_repl = 2
182 : CALL dbt_tas_get_split_info( &
183 : dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
184 2960 : nsplit=nsplit_batched)
185 2960 : CPASSERT(nsplit_batched > 0)
186 : max_mm_dim_batched = 1
187 : END IF
188 : END IF
189 :
190 201885 : IF (matrix_c%do_batched > 0) THEN
191 31805 : do_batched = .TRUE.
192 31805 : IF (matrix_c%do_batched == 3) THEN
193 6328 : CPASSERT(batched_repl == 0)
194 6328 : batched_repl = 3
195 : CALL dbt_tas_get_split_info( &
196 : dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
197 6328 : nsplit=nsplit_batched)
198 6328 : CPASSERT(nsplit_batched > 0)
199 : max_mm_dim_batched = 2
200 : END IF
201 : END IF
202 :
203 201885 : move_a = .FALSE.
204 201885 : move_b = .FALSE.
205 :
206 201885 : IF (PRESENT(move_data_a)) move_a = move_data_a
207 201885 : IF (PRESENT(move_data_b)) move_b = move_data_b
208 :
209 201885 : transa_prv = transa; transb_prv = transb; transc_prv = transc
210 :
211 605655 : dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
212 605655 : dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
213 605655 : dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]
214 :
215 201885 : IF (unit_nr_prv > 0) THEN
216 34 : WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
217 : WRITE (unit_nr_prv, "(A)") &
218 : "DBT TAS MATRIX MULTIPLICATION: "// &
219 : TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
220 : TRIM(dbm_get_name(matrix_b%matrix))//" = "// &
221 34 : TRIM(dbm_get_name(matrix_c%matrix))
222 34 : WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
223 : END IF
224 201885 : IF (do_batched) THEN
225 84067 : IF (unit_nr_prv > 0) THEN
226 : WRITE (unit_nr_prv, "(T2,A)") &
227 0 : "BATCHED PROCESSING OF MATMUL"
228 0 : IF (batched_repl > 0) THEN
229 0 : WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
230 : END IF
231 : END IF
232 : END IF
233 :
234 201885 : IF (transa_prv) THEN
235 61294 : CALL swap(dims_a)
236 : END IF
237 :
238 201885 : IF (transb_prv) THEN
239 102801 : CALL swap(dims_b)
240 : END IF
241 :
242 605655 : dims_c = [dims_a(1), dims_b(2)]
243 :
244 201885 : IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
245 0 : CPABORT("inconsistent matrix dimensions")
246 : END IF
247 :
248 807540 : dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
249 :
250 201885 : IF (unit_nr_prv > 0) THEN
251 34 : WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
252 : END IF
253 :
254 201885 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
255 201885 : numproc = mp_comm%num_pe
256 :
257 : ! derive optimal matrix layout and split factor from occupancies
258 201885 : nze_a = dbt_tas_get_nze_total(matrix_a)
259 201885 : nze_b = dbt_tas_get_nze_total(matrix_b)
260 :
261 201885 : IF (.NOT. simple_split_prv) THEN
262 : CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
263 : estimated_nze=nze_c, filter_eps=filter_eps, &
264 60443 : retain_sparsity=retain_sparsity)
265 :
266 241772 : max_mm_dim = MAXLOC(dims, 1)
267 60443 : nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
268 60443 : nsplit_opt = nsplit
269 :
270 60443 : IF (unit_nr_prv > 0) THEN
271 : WRITE (unit_nr_prv, "(T2,A)") &
272 34 : "MM PARAMETERS"
273 34 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
274 68 : (nze_c + numproc - 1)/numproc
275 :
276 34 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
277 : END IF
278 :
279 141442 : ELSEIF (batched_repl > 0) THEN
280 22537 : nsplit = nsplit_batched
281 22537 : nsplit_opt = nsplit
282 22537 : max_mm_dim = max_mm_dim_batched
283 22537 : IF (unit_nr_prv > 0) THEN
284 : WRITE (unit_nr_prv, "(T2,A)") &
285 0 : "MM PARAMETERS"
286 0 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
287 : END IF
288 :
289 : ELSE
290 118905 : nsplit = 0
291 475620 : max_mm_dim = MAXLOC(dims, 1)
292 : END IF
293 :
294 : ! reshape matrices to the optimal layout and split factor
295 201885 : split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
296 61922 : SELECT CASE (max_mm_dim)
297 : CASE (1)
298 :
299 : split_a = rowsplit; split_c = rowsplit
300 : CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
301 : new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
302 : nsplit=nsplit, &
303 : opt_nsplit=batched_repl == 0, &
304 : split_rc_1=split_a, split_rc_2=split_c, &
305 : nodata2=nodata_3, comm_new=comm_tmp, &
306 61922 : move_data_1=move_a, unit_nr=unit_nr_prv)
307 :
308 61922 : info = dbt_tas_info(matrix_a_rs)
309 61922 : CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
310 :
311 61922 : new_b = .FALSE.
312 61922 : IF (matrix_b%do_batched <= 2) THEN
313 294810 : ALLOCATE (matrix_b_rs)
314 58962 : CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
315 58962 : transb_prv = .FALSE.
316 58962 : new_b = .TRUE.
317 : END IF
318 :
319 61922 : tr_case = transa_prv
320 :
321 123855 : IF (unit_nr_prv > 0) THEN
322 11 : IF (.NOT. tr_case) THEN
323 11 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
324 : ELSE
325 0 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
326 : END IF
327 : END IF
328 :
329 : CASE (2)
330 :
331 60613 : split_a = colsplit; split_b = rowsplit
332 : CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
333 : optimize_dist=optimize_dist, &
334 : nsplit=nsplit, &
335 : opt_nsplit=batched_repl == 0, &
336 : split_rc_1=split_a, split_rc_2=split_b, &
337 : comm_new=comm_tmp, &
338 60613 : move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
339 :
340 60613 : info = dbt_tas_info(matrix_a_rs)
341 60613 : CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
342 :
343 60613 : IF (matrix_c%do_batched == 1) THEN
344 24039 : matrix_c%mm_storage%batched_beta = beta
345 36574 : ELSEIF (matrix_c%do_batched > 1) THEN
346 7484 : matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
347 : END IF
348 :
349 60613 : IF (matrix_c%do_batched <= 2) THEN
350 271425 : ALLOCATE (matrix_c_rs)
351 54285 : CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
352 54285 : transc_prv = .FALSE.
353 :
354 : ! just leave sparsity structure for retain sparsity but no values
355 54285 : IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)
356 :
357 54285 : IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
358 6328 : ELSEIF (matrix_c%do_batched == 3) THEN
359 6328 : matrix_c_rs => matrix_c%mm_storage%store_batched
360 : END IF
361 :
362 60613 : new_c = matrix_c%do_batched == 0
363 60613 : tr_case = transa_prv
364 :
365 121239 : IF (unit_nr_prv > 0) THEN
366 13 : IF (.NOT. tr_case) THEN
367 2 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
368 : ELSE
369 11 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
370 : END IF
371 : END IF
372 :
373 : CASE (3)
374 :
375 79350 : split_b = colsplit; split_c = colsplit
376 : CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
377 : transc_prv, optimize_dist=optimize_dist, &
378 : nsplit=nsplit, &
379 : opt_nsplit=batched_repl == 0, &
380 : split_rc_1=split_b, split_rc_2=split_c, &
381 : nodata2=nodata_3, comm_new=comm_tmp, &
382 79350 : move_data_1=move_b, unit_nr=unit_nr_prv)
383 79350 : info = dbt_tas_info(matrix_b_rs)
384 79350 : CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
385 :
386 79350 : new_a = .FALSE.
387 79350 : IF (matrix_a%do_batched <= 2) THEN
388 330505 : ALLOCATE (matrix_a_rs)
389 66101 : CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
390 66101 : transa_prv = .FALSE.
391 66101 : new_a = .TRUE.
392 : END IF
393 :
394 79350 : tr_case = transb_prv
395 :
396 360585 : IF (unit_nr_prv > 0) THEN
397 10 : IF (.NOT. tr_case) THEN
398 0 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
399 : ELSE
400 10 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
401 : END IF
402 : END IF
403 :
404 : END SELECT
405 :
406 201885 : CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
407 :
408 201885 : numproc = mp_comm%num_pe
409 605655 : pdims_sub = mp_comm_group%num_pe_cart
410 :
411 201885 : opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)
412 :
413 201885 : IF (PRESENT(filter_eps)) THEN
414 151555 : filter_eps_prv = filter_eps
415 : ELSE
416 50330 : filter_eps_prv = 0.0_dp
417 : END IF
418 :
419 201885 : IF (unit_nr_prv /= 0) THEN
420 45730 : IF (unit_nr_prv > 0) THEN
421 34 : WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
422 : END IF
423 45730 : CALL dbt_tas_write_split_info(info, unit_nr_prv)
424 45730 : IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
425 45730 : IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
426 45730 : IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
427 45730 : IF (unit_nr_prv > 0) THEN
428 34 : IF (opt_pgrid) THEN
429 0 : WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
430 : ELSE
431 34 : WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
432 : END IF
433 : END IF
434 : END IF
435 :
436 201885 : pdims = 0
437 201885 : CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
438 :
439 : ! Convert DBM submatrices to optimized process grids and multiply
440 61922 : SELECT CASE (max_mm_dim)
441 : CASE (1)
442 61922 : IF (matrix_b%do_batched <= 2) THEN
443 294810 : ALLOCATE (matrix_b_rep)
444 58962 : CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
445 58962 : IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
446 8478 : matrix_b%mm_storage%store_batched_repl => matrix_b_rep
447 8478 : CALL dbt_tas_set_batched_state(matrix_b, state=3)
448 : END IF
449 2960 : ELSEIF (matrix_b%do_batched == 3) THEN
450 2960 : matrix_b_rep => matrix_b%mm_storage%store_batched_repl
451 : END IF
452 :
453 61922 : IF (new_b) THEN
454 58962 : CALL dbt_tas_destroy(matrix_b_rs)
455 58962 : DEALLOCATE (matrix_b_rs)
456 : END IF
457 61922 : IF (unit_nr_prv /= 0) THEN
458 418 : CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
459 418 : CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
460 : END IF
461 :
462 61922 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
463 :
464 : ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
465 61922 : info_a = dbt_tas_info(matrix_a_rs)
466 61922 : CALL dbt_tas_info_hold(info_a)
467 :
468 61922 : IF (new_a) THEN
469 5608 : CALL dbt_tas_destroy(matrix_a_rs)
470 5608 : DEALLOCATE (matrix_a_rs)
471 : END IF
472 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
473 61922 : move_data=matrix_b%do_batched == 0)
474 :
475 61922 : info_b = dbt_tas_info(matrix_b_rep)
476 61922 : CALL dbt_tas_info_hold(info_b)
477 :
478 61922 : IF (matrix_b%do_batched == 0) THEN
479 50484 : CALL dbt_tas_destroy(matrix_b_rep)
480 50484 : DEALLOCATE (matrix_b_rep)
481 : END IF
482 :
483 61922 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
484 :
485 61922 : info_c = dbt_tas_info(matrix_c_rs)
486 61922 : CALL dbt_tas_info_hold(info_c)
487 :
488 61922 : CALL matrix_a%dist%info%mp_comm%sync()
489 61922 : CALL timeset("dbt_tas_dbm", handle4)
490 61922 : IF (.NOT. tr_case) THEN
491 56562 : CALL timeset("dbt_tas_mm_1N", handle3)
492 :
493 : CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
494 : matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
495 56562 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
496 56562 : CALL timestop(handle3)
497 : ELSE
498 5360 : CALL timeset("dbt_tas_mm_1T", handle3)
499 : CALL dbm_multiply(transa=.TRUE., transb=.FALSE., alpha=alpha, &
500 : matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
501 5360 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
502 :
503 5360 : CALL timestop(handle3)
504 : END IF
505 61922 : CALL matrix_a%dist%info%mp_comm%sync()
506 61922 : CALL timestop(handle4)
507 :
508 61922 : CALL dbm_release(matrix_a_mm)
509 61922 : CALL dbm_release(matrix_b_mm)
510 :
511 61922 : nze_c = dbm_get_nze(matrix_c_mm)
512 :
513 61922 : IF (.NOT. new_c) THEN
514 56506 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
515 : ELSE
516 5416 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
517 : END IF
518 :
519 61922 : CALL dbm_release(matrix_c_mm)
520 :
521 61922 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
522 :
523 248106 : IF (unit_nr_prv /= 0) THEN
524 418 : CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
525 : END IF
526 :
527 : CASE (2)
528 60613 : IF (matrix_c%do_batched <= 1) THEN
529 265645 : ALLOCATE (matrix_c_rep)
530 53129 : CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
531 53129 : IF (matrix_c%do_batched == 1) THEN
532 24039 : matrix_c%mm_storage%store_batched_repl => matrix_c_rep
533 24039 : CALL dbt_tas_set_batched_state(matrix_c, state=3)
534 : END IF
535 7484 : ELSEIF (matrix_c%do_batched == 2) THEN
536 5780 : ALLOCATE (matrix_c_rep)
537 1156 : CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
538 : ! just leave sparsity structure for retain sparsity but no values
539 1156 : IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
540 1156 : matrix_c%mm_storage%store_batched_repl => matrix_c_rep
541 1156 : CALL dbt_tas_set_batched_state(matrix_c, state=3)
542 6328 : ELSEIF (matrix_c%do_batched == 3) THEN
543 6328 : matrix_c_rep => matrix_c%mm_storage%store_batched_repl
544 : END IF
545 :
546 60613 : IF (unit_nr_prv /= 0) THEN
547 20208 : CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
548 20208 : CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
549 : END IF
550 :
551 60613 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
552 :
553 : ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
554 60613 : info_a = dbt_tas_info(matrix_a_rs)
555 60613 : CALL dbt_tas_info_hold(info_a)
556 :
557 60613 : IF (new_a) THEN
558 486 : CALL dbt_tas_destroy(matrix_a_rs)
559 486 : DEALLOCATE (matrix_a_rs)
560 : END IF
561 :
562 60613 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
563 :
564 60613 : info_b = dbt_tas_info(matrix_b_rs)
565 60613 : CALL dbt_tas_info_hold(info_b)
566 :
567 60613 : IF (new_b) THEN
568 634 : CALL dbt_tas_destroy(matrix_b_rs)
569 634 : DEALLOCATE (matrix_b_rs)
570 : END IF
571 :
572 60613 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
573 :
574 60613 : info_c = dbt_tas_info(matrix_c_rep)
575 60613 : CALL dbt_tas_info_hold(info_c)
576 :
577 60613 : CALL matrix_a%dist%info%mp_comm%sync()
578 60613 : CALL timeset("dbt_tas_dbm", handle4)
579 60613 : CALL timeset("dbt_tas_mm_2", handle3)
580 : CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
581 : matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
582 60613 : filter_eps=filter_eps_prv/REAL(nsplit, KIND=dp), retain_sparsity=retain_sparsity, flop=flop)
583 60613 : CALL matrix_a%dist%info%mp_comm%sync()
584 60613 : CALL timestop(handle3)
585 60613 : CALL timestop(handle4)
586 :
587 60613 : CALL dbm_release(matrix_a_mm)
588 60613 : CALL dbm_release(matrix_b_mm)
589 :
590 60613 : nze_c = dbm_get_nze(matrix_c_mm)
591 :
592 60613 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
593 60613 : nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)
594 :
595 60613 : CALL dbm_release(matrix_c_mm)
596 :
597 60613 : IF (unit_nr_prv /= 0) THEN
598 20208 : CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
599 : END IF
600 :
601 60613 : IF (matrix_c%do_batched == 0) THEN
602 29090 : CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
603 : ELSE
604 31523 : matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
605 : END IF
606 :
607 60613 : IF (matrix_c%do_batched == 0) THEN
608 29090 : CALL dbt_tas_destroy(matrix_c_rep)
609 29090 : DEALLOCATE (matrix_c_rep)
610 : END IF
611 :
612 60613 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
613 :
614 : ! set upper limit on memory consumption for replicated matrix and complete batched mm
615 : ! if limit is exceeded
616 304669 : IF (nze_c_sum > default_nsplit_accept_ratio*MAX(nze_a, nze_b)) THEN
617 1604 : CALL dbt_tas_batched_mm_complete(matrix_c)
618 : END IF
619 :
620 : CASE (3)
621 79350 : IF (matrix_a%do_batched <= 2) THEN
622 330505 : ALLOCATE (matrix_a_rep)
623 66101 : CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
624 66101 : IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
625 22817 : matrix_a%mm_storage%store_batched_repl => matrix_a_rep
626 22817 : CALL dbt_tas_set_batched_state(matrix_a, state=3)
627 : END IF
628 13249 : ELSEIF (matrix_a%do_batched == 3) THEN
629 13249 : matrix_a_rep => matrix_a%mm_storage%store_batched_repl
630 : END IF
631 :
632 79350 : IF (new_a) THEN
633 66101 : CALL dbt_tas_destroy(matrix_a_rs)
634 66101 : DEALLOCATE (matrix_a_rs)
635 : END IF
636 79350 : IF (unit_nr_prv /= 0) THEN
637 25104 : CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
638 25104 : CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
639 : END IF
640 :
641 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
642 79350 : move_data=matrix_a%do_batched == 0)
643 :
644 : ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
645 79350 : info_a = dbt_tas_info(matrix_a_rep)
646 79350 : CALL dbt_tas_info_hold(info_a)
647 :
648 79350 : IF (matrix_a%do_batched == 0) THEN
649 43284 : CALL dbt_tas_destroy(matrix_a_rep)
650 43284 : DEALLOCATE (matrix_a_rep)
651 : END IF
652 :
653 79350 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
654 :
655 79350 : info_b = dbt_tas_info(matrix_b_rs)
656 79350 : CALL dbt_tas_info_hold(info_b)
657 :
658 79350 : IF (new_b) THEN
659 16 : CALL dbt_tas_destroy(matrix_b_rs)
660 16 : DEALLOCATE (matrix_b_rs)
661 : END IF
662 79350 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
663 :
664 79350 : info_c = dbt_tas_info(matrix_c_rs)
665 79350 : CALL dbt_tas_info_hold(info_c)
666 :
667 79350 : CALL matrix_a%dist%info%mp_comm%sync()
668 79350 : CALL timeset("dbt_tas_dbm", handle4)
669 79350 : IF (.NOT. tr_case) THEN
670 37024 : CALL timeset("dbt_tas_mm_3N", handle3)
671 : CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
672 : matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
673 37024 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
674 37024 : CALL timestop(handle3)
675 : ELSE
676 42326 : CALL timeset("dbt_tas_mm_3T", handle3)
677 : CALL dbm_multiply(transa=.FALSE., transb=.TRUE., alpha=alpha, &
678 : matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
679 42326 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
680 42326 : CALL timestop(handle3)
681 : END IF
682 79350 : CALL matrix_a%dist%info%mp_comm%sync()
683 79350 : CALL timestop(handle4)
684 :
685 79350 : CALL dbm_release(matrix_a_mm)
686 79350 : CALL dbm_release(matrix_b_mm)
687 :
688 79350 : nze_c = dbm_get_nze(matrix_c_mm)
689 :
690 79350 : IF (.NOT. new_c) THEN
691 73768 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
692 : ELSE
693 5582 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
694 : END IF
695 :
696 79350 : CALL dbm_release(matrix_c_mm)
697 :
698 79350 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
699 :
700 519285 : IF (unit_nr_prv /= 0) THEN
701 25104 : CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
702 : END IF
703 : END SELECT
704 :
705 201885 : CALL mp_comm_mm%free()
706 :
707 201885 : CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)
708 :
709 201885 : IF (PRESENT(split_opt)) THEN
710 93838 : SELECT CASE (max_mm_dim)
711 : CASE (1, 3)
712 93838 : CALL mp_comm%sum(nze_c)
713 : CASE (2)
714 47672 : CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
715 47672 : CALL mp_comm%sum(nze_c)
716 189182 : CALL mp_comm%max(nze_c)
717 :
718 : END SELECT
719 141510 : nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
720 : ! ideally we should rederive the split factor from the actual sparsity of C, but
721 : ! due to parameter beta, we can not get the sparsity of AxB from DBM if not new_c
722 141510 : mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
723 141510 : CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
724 141510 : IF (unit_nr_prv > 0) THEN
725 : WRITE (unit_nr_prv, "(T2,A)") &
726 10 : "MM PARAMETERS"
727 10 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
728 20 : (nze_c + numproc - 1)/numproc
729 :
730 10 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
731 : END IF
732 :
733 : END IF
734 :
735 201885 : IF (new_c) THEN
736 40088 : CALL dbm_scale(matrix_c%matrix, beta)
737 : CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., &
738 : transposed=(transc_prv .NEQV. transc), &
739 40088 : move_data=.TRUE.)
740 40088 : CALL dbt_tas_destroy(matrix_c_rs)
741 40088 : DEALLOCATE (matrix_c_rs)
742 40088 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
743 161797 : ELSEIF (matrix_c%do_batched > 0) THEN
744 31797 : IF (matrix_c%mm_storage%batched_out) THEN
745 31523 : matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
746 : END IF
747 : END IF
748 :
749 201885 : IF (PRESENT(move_data_a)) THEN
750 201837 : IF (move_data_a) CALL dbt_tas_clear(matrix_a)
751 : END IF
752 201885 : IF (PRESENT(move_data_b)) THEN
753 201837 : IF (move_data_b) CALL dbt_tas_clear(matrix_b)
754 : END IF
755 :
756 201885 : IF (PRESENT(flop)) THEN
757 90111 : CALL mp_comm%sum(flop)
758 90111 : flop = (flop + numproc - 1)/numproc
759 : END IF
760 :
761 201885 : IF (PRESENT(optimize_dist)) THEN
762 48 : IF (optimize_dist) CALL comm_tmp%free()
763 : END IF
764 201885 : IF (unit_nr_prv > 0) THEN
765 34 : WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
766 34 : WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
767 34 : WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
768 : END IF
769 :
770 201885 : CALL dbt_tas_release_info(info_a)
771 201885 : CALL dbt_tas_release_info(info_b)
772 201885 : CALL dbt_tas_release_info(info_c)
773 :
774 201885 : CALL matrix_a%dist%info%mp_comm%sync()
775 201885 : CALL timestop(handle2)
776 201885 : CALL timestop(handle)
777 :
778 403770 : END SUBROUTINE
779 :
! **************************************************************************************************
!> \brief Accumulate one DBM matrix into another: matrix_out := alpha*matrix_out + matrix_in
!> \param matrix_in matrix to be added (not modified)
!> \param matrix_out accumulator; scaled by alpha before matrix_in is added to it
!> \param local_copy if .TRUE., matrix_in is added to matrix_out directly (presumably this requires
!>                   both matrices to share the same distribution - confirm with callers);
!>                   if .FALSE. or absent, matrix_in is first redistributed into a temporary
!>                   matrix created with matrix_out as template
!> \param alpha scaling factor for matrix_out (no scaling call is made if alpha is exactly 1)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
      TYPE(dbm_type), INTENT(IN)                         :: matrix_in
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: local_copy
      REAL(dp), INTENT(IN)                               :: alpha

      LOGICAL                                            :: add_in_place
      TYPE(dbm_type)                                     :: matrix_redist

      add_in_place = .FALSE.
      IF (PRESENT(local_copy)) add_in_place = local_copy

      ! Scaling by exactly one would not change matrix_out, so the call is skipped.
      IF (alpha /= 1.0_dp) CALL dbm_scale(matrix_out, alpha)

      IF (add_in_place) THEN
         CALL dbm_add(matrix_out, matrix_in)
      ELSE
         ! Bring matrix_in to the layout of matrix_out via a temporary before summation.
         CALL dbm_create_from_template(matrix_redist, name="tmp", template=matrix_out)
         CALL dbm_redistribute(matrix_in, matrix_redist)
         CALL dbm_add(matrix_out, matrix_redist)
         CALL dbm_release(matrix_redist)
      END IF

   END SUBROUTINE
817 :
! **************************************************************************************************
!> \brief Make sure that smallest matrix involved in a multiplication is not split and bring it to
!>        the same process grid as the other 2 matrices.
!> \param mp_comm communicator that defines Cartesian topology
!> \param matrix_in matrix to be reshaped
!> \param matrix_out new matrix on mp_comm with default arbitrary distributions and no split;
!>        its block rows/columns are those of matrix_in, swapped if transposed
!> \param transposed Whether matrix_out should be transposed
!> \param nodata Data of matrix_in should not be copied to matrix_out
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
      TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(IN)                                :: transposed
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reshape_mm_small'

      INTEGER                                            :: handle
      INTEGER(KIND=int_8), DIMENSION(2)                  :: nblks_out
      INTEGER, DIMENSION(2)                              :: grid_dims
      LOGICAL                                            :: skip_data
      TYPE(dbt_tas_dist_arb)                             :: out_col_dist, out_row_dist
      TYPE(dbt_tas_distribution_type)                    :: dist

      CALL timeset(routineN, handle)

      skip_data = .FALSE.
      IF (PRESENT(nodata)) skip_data = nodata

      grid_dims = mp_comm%num_pe_cart

      ! Block-row / block-column counts of matrix_out and the matching default distributions;
      ! for a transposed result the roles of rows and columns of matrix_in are exchanged.
      IF (transposed) THEN
         nblks_out = [dbt_tas_nblkcols_total(matrix_in), dbt_tas_nblkrows_total(matrix_in)]
         out_row_dist = dbt_tas_dist_arb_default(grid_dims(1), nblks_out(1), matrix_in%col_blk_size)
         out_col_dist = dbt_tas_dist_arb_default(grid_dims(2), nblks_out(2), matrix_in%row_blk_size)
      ELSE
         nblks_out = [dbt_tas_nblkrows_total(matrix_in), dbt_tas_nblkcols_total(matrix_in)]
         out_row_dist = dbt_tas_dist_arb_default(grid_dims(1), nblks_out(1), matrix_in%row_blk_size)
         out_col_dist = dbt_tas_dist_arb_default(grid_dims(2), nblks_out(2), matrix_in%col_blk_size)
      END IF

      ! nosplit: the small matrix is not split among process subgroups
      CALL dbt_tas_distribution_new(dist, mp_comm, out_row_dist, out_col_dist, nosplit=.TRUE.)

      IF (transposed) THEN
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
                             matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
      ELSE
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
                             matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
      END IF

      IF (.NOT. skip_data) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)

      CALL timestop(handle)

   END SUBROUTINE
877 :
! **************************************************************************************************
!> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
!>        with the same split factor.
!> \param matrix1_in first input matrix
!> \param matrix2_in second input matrix
!> \param matrix1_out points to matrix1_in or to a newly allocated matrix (see new1)
!> \param matrix2_out points to matrix2_in or to a newly allocated matrix (see new2)
!> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
!> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
!> \param trans1 transpose flag of matrix1_in for multiplication; toggled on return if
!>        matrix1_out stores matrix1_in transposed
!> \param trans2 transpose flag of matrix2_in for multiplication; toggled on return if
!>        matrix2_out stores matrix2_in transposed
!> \param optimize_dist experimental: optimize matrix splitting and distribution
!> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
!> \param opt_nsplit passed to change_split for the reference matrix: whether nsplit should be
!>        optimized for the current process grid
!> \param split_rc_1 Whether to split rows or columns for matrix 1 (flipped if trans1 on entry)
!> \param split_rc_2 Whether to split rows or columns for matrix 2 (flipped if trans2 on entry)
!> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
!> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
!> \param move_data_1 memory optimization: move data such that matrix1_in may be empty on return.
!>        Set to .TRUE. on return if matrix1_out is a new matrix.
!> \param move_data_2 memory optimization: move data such that matrix2_in may be empty on return.
!>        Set to .TRUE. on return if matrix2_out is a new matrix.
!> \param comm_new returns the new communicator only if optimize_dist (must be present then)
!> \param unit_nr output unit
!> \note Three strategies are used:
!>       1. distributions already compatible (and no optimize_dist): only the split factor is
!>          adapted via change_split;
!>       2. optimize_dist: both matrices are redistributed to cyclic distributions on a new
!>          row-split grid;
!>       3. otherwise: the matrix with fewer non-zero elements is reshaped to the layout of the
!>          other (reference) matrix.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
                                    optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
                                    move_data_1, move_data_2, comm_new, unit_nr)
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix1_in, matrix2_in
      TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix1_out, matrix2_out
      LOGICAL, INTENT(OUT)                               :: new1, new2
      LOGICAL, INTENT(INOUT)                             :: trans1, trans2
      LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
      INTEGER, INTENT(IN), OPTIONAL                      :: nsplit
      LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
      INTEGER, INTENT(INOUT)                             :: split_rc_1, split_rc_2
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata1, nodata2
      LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data_1, move_data_2
      TYPE(mp_cart_type), INTENT(OUT), OPTIONAL          :: comm_new
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reshape_mm_compatible'

      INTEGER                                            :: handle, nsplit_prv, ref, split_rc_ref, &
                                                            unit_nr_prv
      INTEGER(KIND=int_8)                                :: d1, d2, nze1, nze2
      INTEGER(KIND=int_8), DIMENSION(2)                  :: dims1, dims2, dims_ref
      INTEGER, DIMENSION(2)                              :: pdims
      LOGICAL                                            :: nodata1_prv, nodata2_prv, &
                                                            optimize_dist_prv, trans1_newdist, &
                                                            trans2_newdist
      TYPE(dbt_tas_dist_cyclic)                          :: col_dist_1, col_dist_2, row_dist_1, &
                                                            row_dist_2
      TYPE(dbt_tas_distribution_type)                    :: dist_1, dist_2
      TYPE(dbt_tas_split_info)                           :: split_info
      TYPE(mp_cart_type)                                 :: mp_comm

      CALL timeset(routineN, handle)
      new1 = .FALSE.; new2 = .FALSE.

      IF (PRESENT(nodata1)) THEN
         nodata1_prv = nodata1
      ELSE
         nodata1_prv = .FALSE.
      END IF

      IF (PRESENT(nodata2)) THEN
         nodata2_prv = nodata2
      ELSE
         nodata2_prv = .FALSE.
      END IF

      unit_nr_prv = prep_output_unit(unit_nr)

      NULLIFY (matrix1_out, matrix2_out)

      IF (PRESENT(optimize_dist)) THEN
         optimize_dist_prv = optimize_dist
      ELSE
         optimize_dist_prv = .FALSE.
      END IF

      dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
      dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
      nze1 = dbt_tas_get_nze_total(matrix1_in)
      nze2 = dbt_tas_get_nze_total(matrix2_in)

      ! A transposed matrix is split along the other dimension (1 <-> 2)
      IF (trans1) split_rc_1 = MOD(split_rc_1, 2) + 1

      IF (trans2) split_rc_2 = MOD(split_rc_2, 2) + 1

      ! The matrix with more non-zero elements is the reference whose layout is kept
      IF (nze1 >= nze2) THEN
         ref = 1
         split_rc_ref = split_rc_1
         dims_ref = dims1
      ELSE
         ref = 2
         split_rc_ref = split_rc_2
         dims_ref = dims2
      END IF

      IF (PRESENT(nsplit)) THEN
         nsplit_prv = nsplit
      ELSE
         nsplit_prv = 0
      END IF

      IF (optimize_dist_prv) THEN
         CPASSERT(PRESENT(comm_new))
      END IF

      IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
         ! Strategy 1: same distribution, only adapt the split factor. Matrix 2 follows the split
         ! factor that was chosen for matrix 1.
         CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
                           move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
         CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
         CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
                           move_data=move_data_2, nodata=nodata2, opt_nsplit=.FALSE.)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
               TRIM(dbm_get_name(matrix1_in%matrix)), &
               "and", TRIM(dbm_get_name(matrix2_in%matrix))
            IF (new1) THEN
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
                  TRIM(dbm_get_name(matrix1_in%matrix)), ": Yes"
            ELSE
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
                  TRIM(dbm_get_name(matrix1_in%matrix)), ": No"
            END IF
            IF (new2) THEN
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
                  TRIM(dbm_get_name(matrix2_in%matrix)), ": Yes"
            ELSE
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
                  TRIM(dbm_get_name(matrix2_in%matrix)), ": No"
            END IF
         END IF
      ELSE

         IF (optimize_dist_prv) THEN
            ! Strategy 2: redistribute both matrices onto a new row-split grid. A matrix split
            ! along columns is stored transposed so that it becomes row-split.
            IF (unit_nr_prv > 0) THEN
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
                  TRIM(dbm_get_name(matrix1_in%matrix)), &
                  "and", TRIM(dbm_get_name(matrix2_in%matrix))
            END IF

            trans1_newdist = (split_rc_1 == colsplit)
            trans2_newdist = (split_rc_2 == colsplit)

            IF (trans1_newdist) THEN
               CALL swap(dims1)
               trans1 = .NOT. trans1
            END IF

            IF (trans2_newdist) THEN
               CALL swap(dims2)
               trans2 = .NOT. trans2
            END IF

            IF (nsplit_prv == 0) THEN
               ! Default split factor: ratio of the split (long) dimension to the other dimension
               ! of the reference matrix, rounded up.
               SELECT CASE (split_rc_ref)
               CASE (rowsplit)
                  d1 = dims_ref(1)
                  d2 = dims_ref(2)
               CASE (colsplit)
                  d1 = dims_ref(2)
                  d2 = dims_ref(1)
               CASE DEFAULT
                  CPABORT("Invalid split dimension of reference matrix")
               END SELECT
               ! Guard against empty dimensions: d2 == 0 would divide by zero and d1 == 0 would
               ! yield a split factor of 0. For d1, d2 >= 1 this equals (d1 - 1)/d2 + 1.
               d2 = MAX(d2, 1_int_8)
               nsplit_prv = INT(MAX((d1 + d2 - 1_int_8)/d2, 1_int_8))
            END IF

            CPASSERT(nsplit_prv > 0)

            CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
            comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
            CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)

            pdims = comm_new%num_pe_cart

            ! use a very simple cyclic distribution that may not be load balanced if block
            ! sizes are not equal. However we can not use arbitrary distributions
            ! for large dimensions since this would require storing distribution vectors as arrays
            ! which can not be stored for large dimensions.
            row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
            col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))

            row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
            col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))

            CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
            CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
            CALL dbt_tas_release_info(split_info)

            ALLOCATE (matrix1_out)
            IF (.NOT. trans1_newdist) THEN
               CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
                                   matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)

            ELSE
               CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
                                   matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
            END IF

            ALLOCATE (matrix2_out)
            IF (.NOT. trans2_newdist) THEN
               CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
                                   matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
            ELSE
               CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
                                   matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
            END IF

            IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
            IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
            new1 = .TRUE.
            new2 = .TRUE.

         ELSE
            ! Strategy 3: keep the layout of the reference matrix (adapting its split factor) and
            ! reshape the other matrix onto it.
            SELECT CASE (ref)
            CASE (1)
               IF (unit_nr_prv > 0) THEN
                  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
                     TRIM(dbm_get_name(matrix2_in%matrix))
               END IF

               CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
                                 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)

               ALLOCATE (matrix2_out)
               CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
                                        nodata=nodata2, move_data=move_data_2)
               new2 = .TRUE.
            CASE (2)
               IF (unit_nr_prv > 0) THEN
                  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
                     TRIM(dbm_get_name(matrix1_in%matrix))
               END IF

               CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
                                 move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)

               ALLOCATE (matrix1_out)
               CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
                                        nodata=nodata1, move_data=move_data_1)
               new1 = .TRUE.
            END SELECT
         END IF
      END IF

      ! A newly created output matrix is owned here, so its data may always be moved later on.
      IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
      IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.

      CALL timestop(handle)

   END SUBROUTINE
1131 :
! **************************************************************************************************
!> \brief Change split factor without redistribution
!> \param matrix_in matrix whose split should be changed
!> \param matrix_out either matrix_in itself (is_new = .FALSE.) or a newly allocated matrix with
!>        the same block sizes and process distribution but the new split
!> \param nsplit new split factor, set to 0 to not change split of matrix_in
!>        (if 0 but split_rowcol differs from the current split, a split factor of 1 is requested)
!> \param split_rowcol split rows or columns
!> \param is_new whether matrix_out is new or a pointer to matrix_in
!> \param opt_nsplit whether nsplit should be optimized for current process grid
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!>        Set to .TRUE. on return if data was copied to a new matrix_out.
!> \param nodata Data of matrix_in should not be copied to matrix_out
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT), POINTER           :: matrix_out
      INTEGER, INTENT(IN)                                :: nsplit, split_rowcol
      LOGICAL, INTENT(OUT)                               :: is_new
      LOGICAL, INTENT(IN), OPTIONAL                      :: opt_nsplit
      LOGICAL, INTENT(INOUT), OPTIONAL                   :: move_data
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'change_split'

      CHARACTER(len=default_string_length)               :: name
      INTEGER                                            :: handle, nsplit_actual, nsplit_cur, &
                                                            nsplit_req, split_rc_cur
      LOGICAL                                            :: copy_data
      TYPE(dbt_tas_distribution_type)                    :: dist
      TYPE(dbt_tas_split_info)                           :: split_info
      TYPE(mp_cart_type)                                 :: mp_comm
      CLASS(dbt_tas_distribution), ALLOCATABLE           :: rdist, cdist
      CLASS(dbt_tas_rowcol_data), ALLOCATABLE            :: rbsize, cbsize

      NULLIFY (matrix_out)
      is_new = .TRUE.

      CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=mp_comm, &
                                  split_rowcol=split_rc_cur, nsplit=nsplit_cur)

      ! Keep the current split factor and the split dimension does not change: reuse matrix_in.
      IF (nsplit == 0 .AND. split_rowcol == split_rc_cur) THEN
         matrix_out => matrix_in
         is_new = .FALSE.
         RETURN
      END IF

      ! A split factor of 0 combined with a new split dimension requests a split factor of 1.
      nsplit_req = nsplit
      IF (nsplit_req == 0) nsplit_req = 1

      CALL timeset(routineN, handle)

      copy_data = .TRUE.
      IF (PRESENT(nodata)) copy_data = .NOT. nodata

      CALL dbt_tas_get_info(matrix_in, name=name, &
                            row_blk_size=rbsize, col_blk_size=cbsize, &
                            proc_row_dist=rdist, proc_col_dist=cdist)

      CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_req, opt_nsplit=opt_nsplit)
      CALL dbt_tas_get_split_info(split_info, nsplit=nsplit_actual)

      ! The (possibly optimized) split coincides with the current one: reuse matrix_in.
      IF (nsplit_actual == nsplit_cur .AND. split_rowcol == split_rc_cur) THEN
         matrix_out => matrix_in
         is_new = .FALSE.
         CALL dbt_tas_release_info(split_info)
         CALL timestop(handle)
         RETURN
      END IF

      CALL dbt_tas_distribution_new(dist, mp_comm, rdist, cdist, split_info=split_info)
      CALL dbt_tas_release_info(split_info)

      ALLOCATE (matrix_out)
      CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.TRUE.)

      IF (copy_data) THEN
         CALL dbt_tas_copy(matrix_out, matrix_in)
         IF (PRESENT(move_data)) THEN
            IF (move_data) CALL dbt_tas_clear(matrix_in)
            ! NOTE(review): presumably set because matrix_out is a new copy whose data the caller
            ! may always move - confirm against callers.
            move_data = .TRUE.
         END IF
      END IF

      CALL timestop(handle)
   END SUBROUTINE
1224 :
! **************************************************************************************************
!> \brief Check whether matrices have same distribution and same split.
!> \param mat_a first matrix
!> \param mat_b second matrix
!> \param split_rc_a expected split dimension (rowsplit or colsplit) of mat_a
!> \param split_rc_b expected split dimension (rowsplit or colsplit) of mat_b
!> \param unit_nr output unit for diagnostic messages
!> \return .TRUE. if both matrices are currently split along the expected (and identical)
!>         dimension, live on process grids of equal dimensions, and every rank holds the same
!>         local rows (rowsplit) or columns (colsplit) of both matrices
!> \note The last check performs a sum over the communicator of mat_a, i.e. it is a collective call.
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
      TYPE(dbt_tas_type), INTENT(IN)                     :: mat_a, mat_b
      INTEGER, INTENT(IN)                                :: split_rc_a, split_rc_b
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      LOGICAL                                            :: dist_compatible

      INTEGER                                            :: n_agree, numproc, out_unit, split_a, &
                                                            split_b
      INTEGER(int_8), ALLOCATABLE, DIMENSION(:)          :: rowcols_a, rowcols_b
      INTEGER, DIMENSION(2)                              :: grid_a, grid_b
      LOGICAL                                            :: same_split
      TYPE(dbt_tas_split_info)                           :: info_a, info_b

      out_unit = prep_output_unit(unit_nr)
      dist_compatible = .FALSE.

      info_a = dbt_tas_info(mat_a)
      info_b = dbt_tas_info(mat_b)
      CALL dbt_tas_get_split_info(info_a, split_rowcol=split_a)
      CALL dbt_tas_get_split_info(info_b, split_rowcol=split_b)

      ! 1) Both matrices must be split along the requested dimension, and that dimension must agree.
      same_split = (split_a == split_rc_a) .AND. (split_b == split_rc_b) .AND. (split_rc_a == split_rc_b)
      IF (.NOT. same_split) THEN
         IF (out_unit > 0) THEN
            WRITE (out_unit, *) "matrix layout a not compatible", split_a, split_rc_a
            WRITE (out_unit, *) "matrix layout b not compatible", split_b, split_rc_b
         END IF
         RETURN
      END IF

      ! 2) Same global process grid dimensions. Comparing communicators alone would not cover the
      !    Cartesian topology; subgrids are handled later on (change_split).
      numproc = info_b%mp_comm%num_pe
      grid_a = info_a%mp_comm%num_pe_cart
      grid_b = info_b%mp_comm%num_pe_cart
      IF (.NOT. array_eq(grid_a, grid_b)) THEN
         IF (out_unit > 0) WRITE (out_unit, *) "mp dims not compatible:", grid_a, "|", grid_b
         RETURN
      END IF

      ! 3) Same distribution: compare local rows / columns along the split dimension on each rank.
      IF (split_rc_a == rowsplit) THEN
         CALL dbt_tas_get_info(mat_a, local_rows=rowcols_a)
         CALL dbt_tas_get_info(mat_b, local_rows=rowcols_b)
      ELSE IF (split_rc_a == colsplit) THEN
         CALL dbt_tas_get_info(mat_a, local_cols=rowcols_a)
         CALL dbt_tas_get_info(mat_b, local_cols=rowcols_b)
      END IF

      ! Count the ranks on which both matrices agree; all ranks must agree.
      n_agree = MERGE(1, 0, array_eq(rowcols_a, rowcols_b))
      CALL info_a%mp_comm%sum(n_agree)

      dist_compatible = (n_agree == numproc)

      IF (.NOT. dist_compatible) THEN
         IF (out_unit > 0) THEN
            WRITE (out_unit, *) "local rowcols not compatible"
            WRITE (out_unit, *) "local rowcols A", rowcols_a
            WRITE (out_unit, *) "local rowcols B", rowcols_b
         END IF
      END IF

   END FUNCTION
1302 :
! **************************************************************************************************
!> \brief Reshape matrix_in s.t. it has same process grid, distribution and split as template
!> \param template matrix whose process grid, split info and distribution along the split
!>        dimension are reused
!> \param matrix_in matrix to be reshaped
!> \param matrix_out new matrix with the layout of template; it stores matrix_in transposed if
!>        the split dimension of template differs from split_rc
!> \param trans transpose flag of matrix_in for multiplication; toggled if matrix_out is transposed
!> \param split_rc split dimension (rowsplit or colsplit) requested for matrix_in
!> \param nodata Data of matrix_in should not be copied to matrix_out
!> \param move_data memory optimization: move data such that matrix_in may be empty on return
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
      TYPE(dbt_tas_type), INTENT(IN)                     :: template
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(INOUT)                             :: trans
      INTEGER, INTENT(IN)                                :: split_rc
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'reshape_mm_template'

      CLASS(dbt_tas_distribution), ALLOCATABLE           :: row_dist, col_dist
      INTEGER                                            :: dim_split_template, handle
      INTEGER, DIMENSION(2)                              :: pdims
      LOGICAL                                            :: nodata_prv, transposed
      TYPE(dbt_tas_distribution_type)                    :: dist_new
      TYPE(dbt_tas_split_info)                           :: info_template
      TYPE(mp_cart_type)                                 :: mp_comm

      CALL timeset(routineN, handle)

      nodata_prv = .FALSE.
      IF (PRESENT(nodata)) nodata_prv = nodata

      info_template = dbt_tas_info(template)
      dim_split_template = info_template%split_rowcol

      ! If matrix_in is to be split along a different dimension than template, store it
      ! transposed so that both matrices are split along the same dimension.
      transposed = dim_split_template /= split_rc
      IF (transposed) trans = .NOT. trans

      pdims = info_template%mp_comm%num_pe_cart

      ! Along the split dimension the distribution of template is reused; along the other
      ! dimension a default distribution over the corresponding process grid dimension is built.
      SELECT CASE (dim_split_template)
      CASE (rowsplit)
         ALLOCATE (row_dist, source=template%dist%row_dist)
         IF (.NOT. transposed) THEN
            ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
         ELSE
            ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
         END IF
      CASE (colsplit)
         IF (.NOT. transposed) THEN
            ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
         ELSE
            ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
         END IF
         ALLOCATE (col_dist, source=template%dist%col_dist)
      CASE DEFAULT
         ! Without this, row_dist/col_dist would be unallocated in dbt_tas_distribution_new.
         CPABORT("Invalid split dimension of template")
      END SELECT

      CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
      CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
      IF (.NOT. transposed) THEN
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
                             matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
      ELSE
         CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
                             matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
      END IF

      IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)

      CALL timestop(handle)

   END SUBROUTINE
1386 :
! **************************************************************************************************
!> \brief Estimate the number of non-zero elements of C resulting from A x B = C
!>        by multiplying the block norms of A and B. Same dummy arguments as dbt_tas_multiply
!> \param transa transpose flag of A
!> \param transb transpose flag of B
!> \param transc transpose flag of C
!> \param matrix_a ...
!> \param matrix_b ...
!> \param matrix_c only its block structure (and, with retain_sparsity, its sparsity) is used
!> \param estimated_nze estimated number of non-zero elements of C, summed over the
!>        communicator of matrix_a
!> \param filter_eps filter threshold passed to the block-norm multiplication
!> \param unit_nr output unit passed to the block-norm multiplication
!> \param retain_sparsity if .TRUE., the existing block sparsity of matrix_c is counted instead of
!>        estimating it from A and B
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
                                          estimated_nze, filter_eps, unit_nr, retain_sparsity)
      LOGICAL, INTENT(IN)                                :: transa, transb, transc
      TYPE(dbt_tas_type), INTENT(INOUT), TARGET          :: matrix_a, matrix_b, matrix_c
      INTEGER(int_8), INTENT(OUT)                        :: estimated_nze
      REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL                      :: retain_sparsity

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_estimate_result_nze'

      INTEGER                                            :: col_size, handle, row_size
      INTEGER(int_8)                                     :: col, row
      LOGICAL                                            :: retain_sparsity_prv
      TYPE(dbt_tas_iterator)                             :: iter
      TYPE(dbt_tas_type), POINTER                        :: matrix_a_bnorm, matrix_b_bnorm, &
                                                            matrix_c_bnorm
      TYPE(mp_cart_type)                                 :: mp_comm

      CALL timeset(routineN, handle)

      IF (PRESENT(retain_sparsity)) THEN
         retain_sparsity_prv = retain_sparsity
      ELSE
         retain_sparsity_prv = .FALSE.
      END IF

      IF (.NOT. retain_sparsity_prv) THEN
         ! Multiply matrices of block norms (block size one): the blocks present in the product
         ! give the estimated block sparsity pattern of C.
         ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
         CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
         CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
         CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)

         CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
                               matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
                               filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
                               simple_split=.TRUE., unit_nr=unit_nr)
         CALL dbt_tas_destroy(matrix_a_bnorm)
         CALL dbt_tas_destroy(matrix_b_bnorm)

         DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
      ELSE
         matrix_c_bnorm => matrix_c
      END IF

      ! Count the elements of C covered by the present blocks, using the real block sizes of C.
      ! NOTE(review): each thread starts its own iterator; presumably the iterator hands out
      ! disjoint blocks per thread so that the reduction does not overcount - confirm.
      estimated_nze = 0
!$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
!$OMP PRIVATE(iter,row,col,row_size,col_size)
      CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
      DO WHILE (dbt_tas_iterator_blocks_left(iter))
         CALL dbt_tas_iterator_next_block(iter, row, col)
         row_size = matrix_c%row_blk_size%data(row)
         col_size = matrix_c%col_blk_size%data(col)
         ! Promote before multiplying so the block element count cannot overflow default INTEGER.
         estimated_nze = estimated_nze + INT(row_size, KIND=int_8)*INT(col_size, KIND=int_8)
      END DO
      CALL dbt_tas_iterator_stop(iter)
!$OMP END PARALLEL

      CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
      CALL mp_comm%sum(estimated_nze)

      IF (.NOT. retain_sparsity_prv) THEN
         CALL dbt_tas_destroy(matrix_c_bnorm)
         DEALLOCATE (matrix_c_bnorm)
      END IF

      CALL timestop(handle)

   END SUBROUTINE
1471 :
! **************************************************************************************************
!> \brief Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
!>        This estimate is based on the minimization of communication volume whereby the
!>        communication of CARMA n-split step and CANNON-multiplication of submatrices are considered.
!> \param max_mm_dim selects the small (non-split) matrix: 1 -> B, 2 -> C, 3 -> A
!>        (presumably the index of the largest of the dimensions m, k, n - confirm with caller)
!> \param nze_a number of non-zeroes in A
!> \param nze_b number of non-zeroes in B
!> \param nze_c number of non-zeroes in C
!> \param numnodes number of MPI ranks
!> \return estimated split factor: ratio of the largest occupancy to the occupancy of the small
!>         matrix, rounded to nearest, capped by numnodes and at least 1
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
      INTEGER, INTENT(IN)                                :: max_mm_dim
      INTEGER(KIND=int_8), INTENT(IN)                    :: nze_a, nze_b, nze_c
      INTEGER, INTENT(IN)                                :: numnodes
      INTEGER                                            :: nsplit

      ! Could be further tuned.
      REAL(dp), PARAMETER                                :: s_opt_factor = 1.0_dp

      INTEGER(KIND=int_8)                                :: max_nze, min_nze

      ! Occupancies are clamped to at least 1 to avoid a division by zero below.
      SELECT CASE (max_mm_dim)
      CASE (1)
         min_nze = MAX(nze_b, 1_int_8)
         max_nze = MAX(nze_a, nze_c, 1_int_8)
      CASE (2)
         min_nze = MAX(nze_c, 1_int_8)
         max_nze = MAX(nze_a, nze_b, 1_int_8)
      CASE (3)
         min_nze = MAX(nze_a, 1_int_8)
         max_nze = MAX(nze_b, nze_c, 1_int_8)
      CASE DEFAULT
         CPABORT("split_factor_estimate: max_mm_dim must be 1, 2 or 3")
      END SELECT

      nsplit = INT(MIN(INT(numnodes, KIND=int_8), NINT(REAL(max_nze, dp)/(REAL(min_nze, dp)*s_opt_factor), KIND=int_8)))
      ! The ratio may round to 0; a split factor must be at least 1.
      nsplit = MAX(nsplit, 1)

   END FUNCTION
1513 :
! **************************************************************************************************
!> \brief Create a matrix with block sizes one that contains the block norms of matrix_in
!> \param matrix_in source matrix; must be valid (asserted)
!> \param matrix_out new matrix with the block layout of matrix_in but all block sizes equal to one,
!>        created on matrix_in%dist; entry (i, j) holds NORM2 of block (i, j) of matrix_in
!> \param nodata if .TRUE., matrix_out is only created and left without blocks (default .FALSE.)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata

      CHARACTER(len=default_string_length)               :: mat_name
      INTEGER(KIND=int_8)                                :: blk_col, blk_row, n_cols, n_rows
      LOGICAL                                            :: skip_data
      REAL(dp), DIMENSION(1, 1)                          :: norm_blk
      REAL(dp), DIMENSION(:, :), POINTER                 :: src_blk
      TYPE(dbt_tas_blk_size_one)                         :: unit_col_sizes, unit_row_sizes
      TYPE(dbt_tas_iterator)                             :: it

      CPASSERT(matrix_in%valid)

      skip_data = .FALSE.
      IF (PRESENT(nodata)) skip_data = nodata

      CALL dbt_tas_get_info(matrix_in, name=mat_name, nblkrows_total=n_rows, nblkcols_total=n_cols)
      unit_row_sizes = dbt_tas_blk_size_one(n_rows)
      unit_col_sizes = dbt_tas_blk_size_one(n_cols)

      ! NOTE(review): reuses the distribution of matrix_in; not sure if the assumption that the
      ! same distribution can be taken still holds.
      CALL dbt_tas_create(matrix_out, mat_name, matrix_in%dist, unit_row_sizes, unit_col_sizes)

      IF (skip_data) RETURN

      ! Reserve the same block pattern as matrix_in, then fill each 1x1 block with the block norm.
      CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
!$OMP PRIVATE(it,blk_row,blk_col,src_blk,norm_blk)
      CALL dbt_tas_iterator_start(it, matrix_in)
      DO WHILE (dbt_tas_iterator_blocks_left(it))
         CALL dbt_tas_iterator_next_block(it, blk_row, blk_col, src_blk)
         norm_blk(1, 1) = NORM2(src_blk)
         CALL dbt_tas_put_block(matrix_out, blk_row, blk_col, norm_blk)
      END DO
      CALL dbt_tas_iterator_stop(it)
!$OMP END PARALLEL

   END SUBROUTINE
1566 :
! **************************************************************************************************
!> \brief Convert a DBM matrix to a new process grid
!> \param mp_comm_cart new process grid
!> \param matrix_in source matrix
!> \param matrix_out created here with the block sizes and name of matrix_in
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!>        Only honoured when the process grid is changed (optimize_pgrid = .TRUE.).
!> \param nodata Data of matrix_in should not be copied to matrix_out (default .FALSE.)
!> \param optimize_pgrid Whether to change process grid (default .TRUE.). If .FALSE., matrix_out
!>        is created from matrix_in as a template and mp_comm_cart is not used.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
      TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm_cart
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_in
      TYPE(dbm_type), INTENT(OUT)                        :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: move_data, nodata, optimize_pgrid

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_to_new_pgrid'

      CHARACTER(len=default_string_length)               :: name
      INTEGER                                            :: handle, nbcols, nbrows
      INTEGER, CONTIGUOUS, DIMENSION(:), POINTER         :: col_dist, rbsize, rcsize, row_dist
      INTEGER, DIMENSION(2)                              :: pdims
      LOGICAL                                            :: nodata_prv, optimize_pgrid_prv
      TYPE(dbm_distribution_obj)                         :: dist

      NULLIFY (row_dist, col_dist, rbsize, rcsize)

      CALL timeset(routineN, handle)

      optimize_pgrid_prv = .TRUE.
      IF (PRESENT(optimize_pgrid)) optimize_pgrid_prv = optimize_pgrid

      nodata_prv = .FALSE.
      IF (PRESENT(nodata)) nodata_prv = nodata

      name = dbm_get_name(matrix_in)

      IF (.NOT. optimize_pgrid_prv) THEN
         ! Keep the existing process grid: clone the layout of matrix_in and copy its data.
         ! NOTE(review): move_data is not honoured on this path; matrix_in keeps its blocks.
         CALL dbm_create_from_template(matrix_out, name=name, template=matrix_in)
         IF (.NOT. nodata_prv) CALL dbm_copy(matrix_out, matrix_in)
         CALL timestop(handle)
         RETURN
      END IF

      rbsize => dbm_get_row_block_sizes(matrix_in)
      rcsize => dbm_get_col_block_sizes(matrix_in)
      nbrows = SIZE(rbsize)
      nbcols = SIZE(rcsize)
      pdims = mp_comm_cart%num_pe_cart

      ! Default row/column block distribution over the dimensions of the new process grid
      ALLOCATE (row_dist(nbrows), col_dist(nbcols))
      CALL dbt_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
      CALL dbt_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)

      CALL dbm_distribution_new(dist, mp_comm_cart, row_dist, col_dist)
      DEALLOCATE (row_dist, col_dist)

      CALL dbm_create(matrix_out, name, dist, rbsize, rcsize)
      CALL dbm_distribution_release(dist)

      IF (.NOT. nodata_prv) THEN
         CALL dbm_redistribute(matrix_in, matrix_out)
         IF (PRESENT(move_data)) THEN
            IF (move_data) CALL dbm_clear(matrix_in)
         END IF
      END IF

      CALL timestop(handle)
   END SUBROUTINE
1643 :
! **************************************************************************************************
!> \brief Enter batched-multiplication mode: set batched state 1 on matrix and allocate its
!>        mm_storage with batched_out = .FALSE.
!> \param matrix matrix to be used in batched multiplications
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_init(matrix)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix

      ! state 1: batched MM but mm_storage not yet initialized
      CALL dbt_tas_set_batched_state(matrix, state=1)

      ALLOCATE (matrix%mm_storage)
      ASSOCIATE (new_storage => matrix%mm_storage)
         new_storage%batched_out = .FALSE.
      END ASSOCIATE
   END SUBROUTINE
1656 :
! **************************************************************************************************
!> \brief Leave batched-multiplication mode for matrix.
!>        If batched output was accumulated, matrix%matrix is first scaled by batched_beta and the
!>        stored batched result is then merged back via dbt_tas_batched_mm_complete. Afterwards
!>        mm_storage is deallocated and the batched state is reset to 0.
!>        Does nothing (apart from synchronization and timing) if batched mode is not active.
!>        Synchronizes on the matrix communicator.
!> \param matrix ...
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix

      INTEGER                                            :: handle

      CALL matrix%dist%info%mp_comm%sync()
      CALL timeset("dbt_tas_total", handle)

      ! Guard with an IF block instead of an early RETURN so that timestop is always reached
      ! and timeset/timestop stay balanced.
      IF (matrix%do_batched /= 0) THEN
         IF (matrix%mm_storage%batched_out) THEN
            CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
         END IF

         CALL dbt_tas_batched_mm_complete(matrix)

         matrix%mm_storage%batched_out = .FALSE.

         DEALLOCATE (matrix%mm_storage)
         CALL dbt_tas_set_batched_state(matrix, state=0)

         CALL matrix%dist%info%mp_comm%sync()
      END IF

      CALL timestop(handle)

   END SUBROUTINE
1687 :
! **************************************************************************************************
!> \brief set state flags during batched multiplication
!> \param matrix matrix whose do_batched / has_opt_pgrid / strict_split(1) flags are updated
!> \param state 0 no batched MM
!>              1 batched MM but mm_storage not yet initialized
!>              2 batched MM and mm_storage requires update
!>              3 batched MM and mm_storage initialized
!>        Any other value aborts.
!> \param opt_grid whether process grid was already optimized and should not be changed.
!>        When present, strict_split(1) is set to .TRUE. regardless of its value.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      INTEGER, INTENT(IN), OPTIONAL                      :: state
      LOGICAL, INTENT(IN), OPTIONAL                      :: opt_grid

      IF (PRESENT(opt_grid)) THEN
         matrix%has_opt_pgrid = opt_grid
         matrix%dist%info%strict_split(1) = .TRUE.
      END IF

      IF (.NOT. PRESENT(state)) RETURN

      matrix%do_batched = state

      SELECT CASE (state)
      CASE (2, 3)
         ! mm_storage in use: strict_split(1) is forced on
         matrix%dist%info%strict_split(1) = .TRUE.
      CASE (0, 1)
         ! reset to default: forced on for an optimized grid, otherwise follow strict_split(2)
         matrix%dist%info%strict_split(1) = matrix%has_opt_pgrid .OR. matrix%dist%info%strict_split(2)
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT
   END SUBROUTINE
1725 :
! **************************************************************************************************
!> \brief Flush the storage of a batched multiplication back into matrix.
!>        No-op when batched mode is off (do_batched == 0). Otherwise:
!>        - if batched output exists and the state is 3, store_batched_repl is merged into
!>          store_batched%matrix, which is then reshaped with summation into matrix and released;
!>        - store_batched_repl is released if still associated;
!>        - the batched state is set to 2 (mm_storage requires update).
!> \param matrix ...
!> \param warn if .TRUE. and the state is 3, emit a warning that batched optimizations are disabled
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      LOGICAL, INTENT(IN), OPTIONAL                      :: warn

      LOGICAL                                            :: storage_ready

      IF (matrix%do_batched == 0) RETURN

      ! State 3: batched MM and mm_storage initialized (evaluated before any call that gets matrix)
      storage_ready = (matrix%do_batched == 3)

      ASSOCIATE (mm => matrix%mm_storage)
         IF (PRESENT(warn)) THEN
            IF (warn .AND. storage_ready) THEN
               CALL cp_warn(__LOCATION__, &
                            "Optimizations for batched multiplication are disabled because of conflicting data access")
            END IF
         END IF

         IF (storage_ready .AND. mm%batched_out) THEN
            CALL dbt_tas_merge(mm%store_batched%matrix, &
                               mm%store_batched_repl, move_data=.TRUE.)

            CALL dbt_tas_reshape(mm%store_batched, matrix, summation=.TRUE., &
                                 transposed=mm%batched_trans, move_data=.TRUE.)
            CALL dbt_tas_destroy(mm%store_batched)
            DEALLOCATE (mm%store_batched)
         END IF

         IF (ASSOCIATED(mm%store_batched_repl)) THEN
            CALL dbt_tas_destroy(mm%store_batched_repl)
            DEALLOCATE (mm%store_batched_repl)
         END IF
      END ASSOCIATE

      CALL dbt_tas_set_batched_state(matrix, state=2)

   END SUBROUTINE
1764 :
1765 1660428 : END MODULE
|