Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief DBT tensor framework for block-sparse tensor contraction.
10 : !> Representation of n-rank tensors as DBT tall-and-skinny matrices.
11 : !> Support for arbitrary redistribution between different representations.
12 : !> Support for arbitrary tensor contractions
13 : !> \todo implement checks and error messages
14 : !> \author Patrick Seewald
15 : ! **************************************************************************************************
16 : MODULE dbt_methods
17 : #:include "dbt_macros.fypp"
18 : #:set maxdim = maxrank
19 : #:set ndims = range(2,maxdim+1)
20 :
21 : USE cp_dbcsr_api, ONLY: &
22 : dbcsr_type, dbcsr_release, &
23 : dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
24 : dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
25 : USE dbt_allocate_wrap, ONLY: &
26 : allocate_any
27 : USE dbt_array_list_methods, ONLY: &
28 : get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, &
29 : create_array_list, destroy_array_list, sizes_of_arrays
30 : USE dbm_api, ONLY: &
31 : dbm_clear
32 : USE dbt_tas_types, ONLY: &
33 : dbt_tas_split_info
34 : USE dbt_tas_base, ONLY: &
35 : dbt_tas_copy, dbt_tas_finalize, dbt_tas_get_info, dbt_tas_info
36 : USE dbt_tas_mm, ONLY: &
37 : dbt_tas_multiply, dbt_tas_batched_mm_init, dbt_tas_batched_mm_finalize, &
38 : dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
39 : USE dbt_block, ONLY: &
40 : dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
41 : dbt_iterator_blocks_left, dbt_iterator_stop, dbt_iterator_next_block, &
42 : ndims_iterator, dbt_reserve_blocks, block_nd, destroy_block, checker_tr
43 : USE dbt_index, ONLY: &
44 : dbt_get_mapping_info, nd_to_2d_mapping, dbt_inverse_order, permute_index, get_nd_indices_tensor, &
45 : ndims_mapping_row, ndims_mapping_column, ndims_mapping
46 : USE dbt_types, ONLY: &
47 : dbt_create, dbt_type, ndims_tensor, dims_tensor, &
48 : dbt_distribution_type, dbt_distribution, dbt_nd_mp_comm, dbt_destroy, &
49 : dbt_distribution_destroy, dbt_distribution_new_expert, dbt_get_stored_coordinates, &
50 : blk_dims_tensor, dbt_hold, dbt_pgrid_type, mp_environ_pgrid, dbt_filter, &
51 : dbt_clear, dbt_finalize, dbt_get_num_blocks, dbt_scale, &
52 : dbt_get_num_blocks_total, dbt_get_info, ndims_matrix_row, ndims_matrix_column, &
53 : dbt_max_nblks_local, dbt_default_distvec, dbt_contraction_storage, dbt_nblks_total, &
54 : dbt_distribution_new, dbt_copy_contraction_storage, dbt_pgrid_destroy
55 : USE kinds, ONLY: &
56 : dp, default_string_length, int_8, dp
57 : USE message_passing, ONLY: &
58 : mp_cart_type
59 : USE util, ONLY: &
60 : sort
61 : USE dbt_reshape_ops, ONLY: &
62 : dbt_reshape
63 : USE dbt_tas_split, ONLY: &
64 : dbt_tas_mp_comm, rowsplit, colsplit, dbt_tas_info_hold, dbt_tas_release_info, default_nsplit_accept_ratio, &
65 : default_pdims_accept_ratio, dbt_tas_create_split
66 : USE dbt_split, ONLY: &
67 : dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
68 : USE dbt_io, ONLY: &
69 : dbt_write_tensor_info, dbt_write_tensor_dist, prep_output_unit, dbt_write_split_info
70 : USE message_passing, ONLY: mp_comm_type
71 :
72 : #include "../base/base_uses.f90"
73 :
74 : IMPLICIT NONE
75 : PRIVATE
76 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'
77 :
78 : PUBLIC :: &
79 : dbt_contract, &
80 : dbt_copy, &
81 : dbt_get_block, &
82 : dbt_get_stored_coordinates, &
83 : dbt_inverse_order, &
84 : dbt_iterator_blocks_left, &
85 : dbt_iterator_next_block, &
86 : dbt_iterator_start, &
87 : dbt_iterator_stop, &
88 : dbt_iterator_type, &
89 : dbt_put_block, &
90 : dbt_reserve_blocks, &
91 : dbt_copy_matrix_to_tensor, &
92 : dbt_copy_tensor_to_matrix, &
93 : dbt_batched_contract_init, &
94 : dbt_batched_contract_finalize
95 :
96 : CONTAINS
97 :
98 : ! **************************************************************************************************
99 : !> \brief Copy tensor data.
100 : !> Redistributes tensor data according to distributions of target and source tensor.
101 : !> Permutes tensor index according to `order` argument (if present).
102 : !> Source and target tensor formats are arbitrary as long as the following requirements are met:
103 : !> * source and target tensors have the same rank and the same sizes in each dimension in terms
104 : !> of tensor elements (block sizes don't need to be the same).
105 : !> If `order` argument is present, sizes must match after index permutation.
106 : !> OR
107 : !> * target tensor is not yet created, in this case an exact copy of source tensor is returned.
108 : !> \param tensor_in Source
109 : !> \param tensor_out Target
110 : !> \param order Permutation of target tensor index.
111 : !> Exact same convention as order argument of RESHAPE intrinsic.
112 : !> \param bounds crop tensor data: start and end index for each tensor dimension
113 : !> \author Patrick Seewald
114 : ! **************************************************************************************************
115 780712 : SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
116 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
117 : INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
118 : INTENT(IN), OPTIONAL :: order
119 : LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
120 : INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
121 : INTENT(IN), OPTIONAL :: bounds
122 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
123 : INTEGER :: handle
124 :
125 390356 : CALL tensor_in%pgrid%mp_comm_2d%sync()
126 390356 : CALL timeset("dbt_total", handle)
127 :
128 : ! make sure that it is safe to use dbt_copy during a batched contraction
129 390356 : CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
130 390356 : CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
131 :
132 390356 : CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
133 390356 : CALL tensor_in%pgrid%mp_comm_2d%sync()
134 390356 : CALL timestop(handle)
135 501506 : END SUBROUTINE
136 :
137 : ! **************************************************************************************************
138 : !> \brief expert routine for copying a tensor. For internal use only.
139 : !> \author Patrick Seewald
140 : ! **************************************************************************************************
141 423589 : SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
142 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
143 : INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
144 : INTENT(IN), OPTIONAL :: order
145 : LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
146 : INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
147 : INTENT(IN), OPTIONAL :: bounds
148 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
149 :
150 : TYPE(dbt_type), POINTER :: in_tmp_1, in_tmp_2, &
151 : in_tmp_3, out_tmp_1
152 : INTEGER :: handle, unit_nr_prv
153 423589 : INTEGER, DIMENSION(:), ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2
154 :
155 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
156 : LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
157 : summation_prv, new_in_1, new_in_2, &
158 : new_in_3, new_out_1, block_compatible, &
159 : move_prv
160 423589 : TYPE(array_list) :: blk_sizes_in
161 :
162 423589 : CALL timeset(routineN, handle)
163 :
164 423589 : CPASSERT(tensor_out%valid)
165 :
166 423589 : unit_nr_prv = prep_output_unit(unit_nr)
167 :
168 423589 : IF (PRESENT(move_data)) THEN
169 314031 : move_prv = move_data
170 : ELSE
171 109558 : move_prv = .FALSE.
172 : END IF
173 :
174 423589 : dist_compatible_tas = .FALSE.
175 423589 : dist_compatible_tensor = .FALSE.
176 423589 : block_compatible = .FALSE.
177 423589 : new_in_1 = .FALSE.
178 423589 : new_in_2 = .FALSE.
179 423589 : new_in_3 = .FALSE.
180 423589 : new_out_1 = .FALSE.
181 :
182 423589 : IF (PRESENT(summation)) THEN
183 111343 : summation_prv = summation
184 : ELSE
185 : summation_prv = .FALSE.
186 : END IF
187 :
188 423589 : IF (PRESENT(bounds)) THEN
189 39424 : ALLOCATE (in_tmp_1)
190 5632 : CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
191 5632 : new_in_1 = .TRUE.
192 5632 : move_prv = .TRUE.
193 : ELSE
194 : in_tmp_1 => tensor_in
195 : END IF
196 :
197 423589 : IF (PRESENT(order)) THEN
198 111150 : CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
199 111150 : block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
200 : ELSE
201 312439 : block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
202 : END IF
203 :
204 423589 : IF (.NOT. block_compatible) THEN
205 925249 : ALLOCATE (in_tmp_2, out_tmp_1)
206 : CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
207 71173 : nodata2=.NOT. summation_prv, move_data=move_prv)
208 71173 : new_in_2 = .TRUE.; new_out_1 = .TRUE.
209 71173 : move_prv = .TRUE.
210 : ELSE
211 : in_tmp_2 => in_tmp_1
212 : out_tmp_1 => tensor_out
213 : END IF
214 :
215 423589 : IF (PRESENT(order)) THEN
216 778050 : ALLOCATE (in_tmp_3)
217 111150 : CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
218 111150 : new_in_3 = .TRUE.
219 : ELSE
220 : in_tmp_3 => in_tmp_2
221 : END IF
222 :
223 1270767 : ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
224 1270767 : ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
225 423589 : CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)
226 :
227 1270767 : ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
228 1270767 : ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
229 423589 : CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)
230 :
231 423589 : IF (.NOT. PRESENT(order)) THEN
232 312439 : IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
233 264758 : dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
234 619885 : ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
235 22490 : dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
236 : END IF
237 : END IF
238 :
239 264758 : IF (dist_compatible_tas) THEN
240 217130 : CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
241 217130 : IF (move_prv) CALL dbt_clear(in_tmp_3)
242 206459 : ELSEIF (dist_compatible_tensor) THEN
243 14878 : CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
244 14878 : IF (move_prv) CALL dbt_clear(in_tmp_3)
245 : ELSE
246 191581 : CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
247 : END IF
248 :
249 423589 : IF (new_in_1) THEN
250 5632 : CALL dbt_destroy(in_tmp_1)
251 5632 : DEALLOCATE (in_tmp_1)
252 : END IF
253 :
254 423589 : IF (new_in_2) THEN
255 71173 : CALL dbt_destroy(in_tmp_2)
256 71173 : DEALLOCATE (in_tmp_2)
257 : END IF
258 :
259 423589 : IF (new_in_3) THEN
260 111150 : CALL dbt_destroy(in_tmp_3)
261 111150 : DEALLOCATE (in_tmp_3)
262 : END IF
263 :
264 423589 : IF (new_out_1) THEN
265 71173 : IF (unit_nr_prv /= 0) THEN
266 0 : CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
267 : END IF
268 71173 : CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
269 71173 : CALL dbt_destroy(out_tmp_1)
270 71173 : DEALLOCATE (out_tmp_1)
271 : END IF
272 :
273 423589 : CALL timestop(handle)
274 :
275 847178 : END SUBROUTINE
276 :
277 : ! **************************************************************************************************
278 : !> \brief copy without communication, requires that both tensors have same process grid and distribution
279 : !> \param summation Whether to sum matrices b = a + b
280 : !> \author Patrick Seewald
281 : ! **************************************************************************************************
282 14878 : SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
283 : TYPE(dbt_type), INTENT(INOUT) :: tensor_in
284 : TYPE(dbt_type), INTENT(INOUT) :: tensor_out
285 : LOGICAL, INTENT(IN), OPTIONAL :: summation
286 : TYPE(dbt_iterator_type) :: iter
287 14878 : INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
288 14878 : TYPE(block_nd) :: blk_data
289 : LOGICAL :: found
290 :
291 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
292 : INTEGER :: handle
293 :
294 14878 : CALL timeset(routineN, handle)
295 14878 : CPASSERT(tensor_out%valid)
296 :
297 14878 : IF (PRESENT(summation)) THEN
298 7356 : IF (.NOT. summation) CALL dbt_clear(tensor_out)
299 : ELSE
300 7522 : CALL dbt_clear(tensor_out)
301 : END IF
302 :
303 14878 : CALL dbt_reserve_blocks(tensor_in, tensor_out)
304 :
305 : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
306 14878 : !$OMP PRIVATE(iter,ind_nd,blk_data,found)
307 : CALL dbt_iterator_start(iter, tensor_in)
308 : DO WHILE (dbt_iterator_blocks_left(iter))
309 : CALL dbt_iterator_next_block(iter, ind_nd)
310 : CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
311 : CPASSERT(found)
312 : CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
313 : CALL destroy_block(blk_data)
314 : END DO
315 : CALL dbt_iterator_stop(iter)
316 : !$OMP END PARALLEL
317 :
318 14878 : CALL timestop(handle)
319 29756 : END SUBROUTINE
320 :
321 : ! **************************************************************************************************
322 : !> \brief copy matrix to tensor.
323 : !> \param summation tensor_out = tensor_out + matrix_in
324 : !> \author Patrick Seewald
325 : ! **************************************************************************************************
326 64766 : SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
327 : TYPE(dbcsr_type), TARGET, INTENT(IN) :: matrix_in
328 : TYPE(dbt_type), INTENT(INOUT) :: tensor_out
329 : LOGICAL, INTENT(IN), OPTIONAL :: summation
330 : TYPE(dbcsr_type), POINTER :: matrix_in_desym
331 :
332 : INTEGER, DIMENSION(2) :: ind_2d
333 64766 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: block_arr
334 64766 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: block
335 : TYPE(dbcsr_iterator_type) :: iter
336 :
337 : INTEGER :: handle
338 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'
339 :
340 64766 : CALL timeset(routineN, handle)
341 64766 : CPASSERT(tensor_out%valid)
342 :
343 64766 : NULLIFY (block)
344 :
345 64766 : IF (dbcsr_has_symmetry(matrix_in)) THEN
346 4602 : ALLOCATE (matrix_in_desym)
347 4602 : CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
348 : ELSE
349 : matrix_in_desym => matrix_in
350 : END IF
351 :
352 64766 : IF (PRESENT(summation)) THEN
353 0 : IF (.NOT. summation) CALL dbt_clear(tensor_out)
354 : ELSE
355 64766 : CALL dbt_clear(tensor_out)
356 : END IF
357 :
358 64766 : CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)
359 :
360 : !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
361 64766 : !$OMP PRIVATE(iter,ind_2d,block,block_arr)
362 : CALL dbcsr_iterator_start(iter, matrix_in_desym)
363 : DO WHILE (dbcsr_iterator_blocks_left(iter))
364 : CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block)
365 : CALL allocate_any(block_arr, source=block)
366 : CALL dbt_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
367 : DEALLOCATE (block_arr)
368 : END DO
369 : CALL dbcsr_iterator_stop(iter)
370 : !$OMP END PARALLEL
371 :
372 64766 : IF (dbcsr_has_symmetry(matrix_in)) THEN
373 4602 : CALL dbcsr_release(matrix_in_desym)
374 4602 : DEALLOCATE (matrix_in_desym)
375 : END IF
376 :
377 64766 : CALL timestop(handle)
378 :
379 129532 : END SUBROUTINE
380 :
381 : ! **************************************************************************************************
382 : !> \brief copy tensor to matrix
383 : !> \param summation matrix_out = matrix_out + tensor_in
384 : !> \author Patrick Seewald
385 : ! **************************************************************************************************
386 42244 : SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
387 : TYPE(dbt_type), INTENT(INOUT) :: tensor_in
388 : TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
389 : LOGICAL, INTENT(IN), OPTIONAL :: summation
390 : TYPE(dbt_iterator_type) :: iter
391 : INTEGER :: handle
392 : INTEGER, DIMENSION(2) :: ind_2d
393 42244 : REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: block
394 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
395 : LOGICAL :: found
396 :
397 42244 : CALL timeset(routineN, handle)
398 :
399 42244 : IF (PRESENT(summation)) THEN
400 5876 : IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
401 : ELSE
402 36368 : CALL dbcsr_clear(matrix_out)
403 : END IF
404 :
405 42244 : CALL dbt_reserve_blocks(tensor_in, matrix_out)
406 :
407 : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
408 42244 : !$OMP PRIVATE(iter,ind_2d,block,found)
409 : CALL dbt_iterator_start(iter, tensor_in)
410 : DO WHILE (dbt_iterator_blocks_left(iter))
411 : CALL dbt_iterator_next_block(iter, ind_2d)
412 : IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE
413 :
414 : CALL dbt_get_block(tensor_in, ind_2d, block, found)
415 : CPASSERT(found)
416 :
417 : IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
418 : CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
419 : ELSE
420 : CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
421 : END IF
422 : DEALLOCATE (block)
423 : END DO
424 : CALL dbt_iterator_stop(iter)
425 : !$OMP END PARALLEL
426 :
427 42244 : CALL timestop(handle)
428 :
429 84488 : END SUBROUTINE
430 :
431 : ! **************************************************************************************************
432 : !> \brief Contract tensors by multiplying matrix representations.
433 : !> tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
434 : !> * tensor_2(contract_2, notcontract_2)
435 : !> + beta * tensor_3(map_1, map_2)
436 : !>
437 : !> \note
438 : !> note 1: block sizes of the corresponding indices need to be the same in all tensors.
439 : !>
440 : !> note 2: for best performance the tensors should have been created in matrix layouts
441 : !> compatible with the contraction, e.g. tensor_1 should have been created with either
442 : !> map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
443 : !> map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
444 : !> tensor_3 and map_1 / map_2).
445 : !> Furthermore the two largest tensors involved in the contraction should map both to either
446 : !> tall or short matrices: the largest matrix dimension should be "on the same side"
447 : !> and should have identical distribution (which is always the case if the distributions were
448 : !> obtained with dbt_default_distvec).
449 : !>
450 : !> note 3: if the same tensor occurs in multiple contractions, a different tensor object should
451 : !> be created for each contraction and the data should be copied between the tensors by use of
452 : !> dbt_copy. If the same tensor object is used in multiple contractions,
453 : !> matrix layouts are not compatible for all contractions (see note 2).
454 : !>
455 : !> note 4: automatic optimizations are enabled by using the feature of batched contraction, see
456 : !> dbt_batched_contract_init, dbt_batched_contract_finalize.
457 : !> The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
458 : !>
459 : !> \param tensor_1 first tensor (in)
460 : !> \param tensor_2 second tensor (in)
461 : !> \param contract_1 indices of tensor_1 to contract
462 : !> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
463 : !> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
464 : !> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
465 : !> \param notcontract_1 indices of tensor_1 not to contract
466 : !> \param notcontract_2 indices of tensor_2 not to contract
467 : !> \param tensor_3 contracted tensor (out)
468 : !> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
469 : !> start and end index of an index range over which to contract.
470 : !> For use in batched contraction.
471 : !> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
472 : !> For use in batched contraction.
473 : !> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
474 : !> For use in batched contraction.
475 : !> \param optimize_dist Whether distribution should be optimized internally. In the current
476 : !> implementation this guarantees optimal parameters only for dense matrices.
477 : !> \param pgrid_opt_1 Optionally return optimal process grid for tensor_1.
478 : !> This can be used to choose optimal process grids for subsequent tensor
479 : !> contractions with tensors of similar shape and sparsity. Under some conditions,
480 : !> pgrid_opt_1 can not be returned, in this case the pointer is not associated.
481 : !> \param pgrid_opt_2 Optionally return optimal process grid for tensor_2.
482 : !> \param pgrid_opt_3 Optionally return optimal process grid for tensor_3.
483 : !> \param filter_eps As in DBM mm
484 : !> \param flop As in DBM mm
485 : !> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
486 : !> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
487 : !> \param unit_nr output unit for logging
488 : !> set it to -1 on ranks that should not write (and any valid unit number on
489 : !> ranks that should write output) if 0 on ALL ranks, no output is written
490 : !> \param log_verbose verbose logging (for testing only)
491 : !> \author Patrick Seewald
492 : ! **************************************************************************************************
493 341212 : SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
494 170606 : contract_1, notcontract_1, &
495 170606 : contract_2, notcontract_2, &
496 170606 : map_1, map_2, &
497 116968 : bounds_1, bounds_2, bounds_3, &
498 : optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
499 : filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
500 : REAL(dp), INTENT(IN) :: alpha
501 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
502 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
503 : REAL(dp), INTENT(IN) :: beta
504 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
505 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
506 : INTEGER, DIMENSION(:), INTENT(IN) :: map_1
507 : INTEGER, DIMENSION(:), INTENT(IN) :: map_2
508 : INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
509 : INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
510 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
511 : INTEGER, DIMENSION(2, SIZE(contract_1)), &
512 : INTENT(IN), OPTIONAL :: bounds_1
513 : INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
514 : INTENT(IN), OPTIONAL :: bounds_2
515 : INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
516 : INTENT(IN), OPTIONAL :: bounds_3
517 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
518 : TYPE(dbt_pgrid_type), INTENT(OUT), &
519 : POINTER, OPTIONAL :: pgrid_opt_1
520 : TYPE(dbt_pgrid_type), INTENT(OUT), &
521 : POINTER, OPTIONAL :: pgrid_opt_2
522 : TYPE(dbt_pgrid_type), INTENT(OUT), &
523 : POINTER, OPTIONAL :: pgrid_opt_3
524 : REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
525 : INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
526 : LOGICAL, INTENT(IN), OPTIONAL :: move_data
527 : LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
528 : INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
529 : LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
530 :
531 : INTEGER :: handle
532 :
533 170606 : CALL tensor_1%pgrid%mp_comm_2d%sync()
534 170606 : CALL timeset("dbt_total", handle)
535 : CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
536 : contract_1, notcontract_1, &
537 : contract_2, notcontract_2, &
538 : map_1, map_2, &
539 : bounds_1=bounds_1, &
540 : bounds_2=bounds_2, &
541 : bounds_3=bounds_3, &
542 : optimize_dist=optimize_dist, &
543 : pgrid_opt_1=pgrid_opt_1, &
544 : pgrid_opt_2=pgrid_opt_2, &
545 : pgrid_opt_3=pgrid_opt_3, &
546 : filter_eps=filter_eps, &
547 : flop=flop, &
548 : move_data=move_data, &
549 : retain_sparsity=retain_sparsity, &
550 : unit_nr=unit_nr, &
551 170606 : log_verbose=log_verbose)
552 170606 : CALL tensor_1%pgrid%mp_comm_2d%sync()
553 170606 : CALL timestop(handle)
554 :
555 242232 : END SUBROUTINE
556 :
557 : ! **************************************************************************************************
558 : !> \brief expert routine for tensor contraction. For internal use only.
559 : !> \param nblks_local number of local blocks on this MPI rank
560 : !> \author Patrick Seewald
561 : ! **************************************************************************************************
562 170606 : SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
563 170606 : contract_1, notcontract_1, &
564 170606 : contract_2, notcontract_2, &
565 170606 : map_1, map_2, &
566 170606 : bounds_1, bounds_2, bounds_3, &
567 : optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
568 : filter_eps, flop, move_data, retain_sparsity, &
569 : nblks_local, unit_nr, log_verbose)
570 : REAL(dp), INTENT(IN) :: alpha
571 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
572 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
573 : REAL(dp), INTENT(IN) :: beta
574 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
575 : INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
576 : INTEGER, DIMENSION(:), INTENT(IN) :: map_1
577 : INTEGER, DIMENSION(:), INTENT(IN) :: map_2
578 : INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
579 : INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
580 : TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
581 : INTEGER, DIMENSION(2, SIZE(contract_1)), &
582 : INTENT(IN), OPTIONAL :: bounds_1
583 : INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
584 : INTENT(IN), OPTIONAL :: bounds_2
585 : INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
586 : INTENT(IN), OPTIONAL :: bounds_3
587 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
588 : TYPE(dbt_pgrid_type), INTENT(OUT), &
589 : POINTER, OPTIONAL :: pgrid_opt_1
590 : TYPE(dbt_pgrid_type), INTENT(OUT), &
591 : POINTER, OPTIONAL :: pgrid_opt_2
592 : TYPE(dbt_pgrid_type), INTENT(OUT), &
593 : POINTER, OPTIONAL :: pgrid_opt_3
594 : REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
595 : INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
596 : LOGICAL, INTENT(IN), OPTIONAL :: move_data
597 : LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
598 : INTEGER, INTENT(OUT), OPTIONAL :: nblks_local
599 : INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
600 : LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
601 :
602 : TYPE(dbt_type), POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3
603 3241514 : TYPE(dbt_type), TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3
604 : TYPE(dbt_type), POINTER :: tensor_crop_1, tensor_crop_2
605 : TYPE(dbt_type), POINTER :: tensor_small, tensor_large
606 :
607 : LOGICAL :: assert_stmt, tensors_remapped
608 : INTEGER :: max_mm_dim, max_tensor, &
609 : unit_nr_prv, ref_tensor, handle
610 170606 : TYPE(mp_cart_type) :: mp_comm_opt
611 341212 : INTEGER, DIMENSION(SIZE(contract_1)) :: contract_1_mod
612 341212 : INTEGER, DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod
613 341212 : INTEGER, DIMENSION(SIZE(contract_2)) :: contract_2_mod
614 341212 : INTEGER, DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod
615 341212 : INTEGER, DIMENSION(SIZE(map_1)) :: map_1_mod
616 341212 : INTEGER, DIMENSION(SIZE(map_2)) :: map_2_mod
617 : LOGICAL :: trans_1, trans_2, trans_3
618 : LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2
619 : INTEGER :: ndims1, ndims2, ndims3
620 : INTEGER :: occ_1, occ_2
621 170606 : INTEGER, DIMENSION(:), ALLOCATABLE :: dims1, dims2, dims3
622 :
623 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
624 170606 : CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, &
625 170606 : indchar2_mod, indchar3_mod
626 : CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
627 : ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
628 341212 : INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
629 341212 : INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
630 : LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
631 : pgrid_changed_any, do_change_pgrid(2)
632 2217878 : TYPE(dbt_tas_split_info) :: split_opt, split, split_opt_avg
633 : INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
634 : REAL(dp) :: pdim_ratio, pdim_ratio_opt
635 :
636 170606 : NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
637 170606 : tensor_small)
638 :
639 170606 : CALL timeset(routineN, handle)
640 :
641 170606 : CPASSERT(tensor_1%valid)
642 170606 : CPASSERT(tensor_2%valid)
643 170606 : CPASSERT(tensor_3%valid)
644 :
645 170606 : assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
646 170606 : CPASSERT(assert_stmt)
647 :
648 170606 : assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
649 170606 : CPASSERT(assert_stmt)
650 :
651 170606 : assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
652 170606 : CPASSERT(assert_stmt)
653 :
654 170606 : assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
655 170606 : CPASSERT(assert_stmt)
656 :
657 170606 : assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
658 170606 : CPASSERT(assert_stmt)
659 :
660 170606 : assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
661 170606 : CPASSERT(assert_stmt)
662 :
663 170606 : unit_nr_prv = prep_output_unit(unit_nr)
664 :
665 170606 : IF (PRESENT(flop)) flop = 0
666 170606 : IF (PRESENT(nblks_local)) nblks_local = 0
667 :
668 170606 : IF (PRESENT(move_data)) THEN
669 35423 : move_data_1 = move_data
670 35423 : move_data_2 = move_data
671 : ELSE
672 135183 : move_data_1 = .FALSE.
673 135183 : move_data_2 = .FALSE.
674 : END IF
675 :
676 170606 : nodata_3 = .TRUE.
677 170606 : IF (PRESENT(retain_sparsity)) THEN
678 4762 : IF (retain_sparsity) nodata_3 = .FALSE.
679 : END IF
680 :
681 : CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
682 : contract_1, notcontract_1, &
683 : contract_2, notcontract_2, &
684 : bounds_t1, bounds_t2, &
685 : bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
686 170606 : do_crop_1=do_crop_1, do_crop_2=do_crop_2)
687 :
688 170606 : IF (do_crop_1) THEN
689 484638 : ALLOCATE (tensor_crop_1)
690 69234 : CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
691 69234 : move_data_1 = .TRUE.
692 : ELSE
693 : tensor_crop_1 => tensor_1
694 : END IF
695 :
696 170606 : IF (do_crop_2) THEN
697 469294 : ALLOCATE (tensor_crop_2)
698 67042 : CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
699 67042 : move_data_2 = .TRUE.
700 : ELSE
701 : tensor_crop_2 => tensor_2
702 : END IF
703 :
704 : ! shortcut for empty tensors
705 : ! this is needed to avoid unnecessary work in case user contracts different portions of a
706 : ! tensor consecutively to save memory
707 : ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
708 170606 : occ_1 = dbt_get_num_blocks(tensor_crop_1)
709 170606 : CALL mp_comm%max(occ_1)
710 170606 : occ_2 = dbt_get_num_blocks(tensor_crop_2)
711 170606 : CALL mp_comm%max(occ_2)
712 : END ASSOCIATE
713 :
714 170606 : IF (occ_1 == 0 .OR. occ_2 == 0) THEN
715 27976 : CALL dbt_scale(tensor_3, beta)
716 27976 : IF (do_crop_1) THEN
717 2738 : CALL dbt_destroy(tensor_crop_1)
718 2738 : DEALLOCATE (tensor_crop_1)
719 : END IF
720 27976 : IF (do_crop_2) THEN
721 2752 : CALL dbt_destroy(tensor_crop_2)
722 2752 : DEALLOCATE (tensor_crop_2)
723 : END IF
724 :
725 27976 : CALL timestop(handle)
726 27976 : RETURN
727 : END IF
728 :
729 142630 : IF (unit_nr_prv /= 0) THEN
730 46842 : IF (unit_nr_prv > 0) THEN
731 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
732 10 : WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
733 20 : TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name)
734 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
735 : END IF
736 46842 : CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
737 46842 : CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
738 46842 : CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
739 46842 : CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
740 : END IF
741 :
742 : ! align tensor index with data, tensor data is not modified
743 142630 : ndims1 = ndims_tensor(tensor_crop_1)
744 142630 : ndims2 = ndims_tensor(tensor_crop_2)
745 142630 : ndims3 = ndims_tensor(tensor_3)
746 570520 : ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
747 570520 : ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
748 570520 : ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
749 :
750 : ! labeling tensor index with letters
751 :
752 1226584 : indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
753 343708 : indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
754 333044 : indchar2(contract_2) = indchar1(contract_1)
755 313534 : indchar3(map_1) = indchar1(notcontract_1)
756 343708 : indchar3(map_2) = indchar2(notcontract_2)
757 :
758 142630 : IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
759 : tensor_crop_2, indchar2, &
760 46842 : tensor_3, indchar3, unit_nr_prv)
761 142630 : IF (unit_nr_prv > 0) THEN
762 10 : WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
763 : END IF
764 :
765 : CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
766 142630 : tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
767 :
768 : CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
769 142630 : tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
770 :
771 : CALL align_tensor(tensor_3, map_1, map_2, &
772 142630 : tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
773 :
774 142630 : IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
775 : tensor_algn_2, indchar2_mod, &
776 46842 : tensor_algn_3, indchar3_mod, unit_nr_prv)
777 :
778 427890 : ALLOCATE (dims1(ndims1))
779 427890 : ALLOCATE (dims2(ndims2))
780 427890 : ALLOCATE (dims3(ndims3))
781 :
782 : ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most
783 : ! cases and is more elegant. Note that we can not easily consider occupancy since it is unknown for result tensor
784 142630 : CALL blk_dims_tensor(tensor_crop_1, dims1)
785 142630 : CALL blk_dims_tensor(tensor_crop_2, dims2)
786 142630 : CALL blk_dims_tensor(tensor_3, dims3)
787 :
788 : max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), &
789 : PRODUCT(INT(dims1(contract_1), int_8)), &
790 1132916 : PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1)
791 1695312 : max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1)
792 36488 : SELECT CASE (max_mm_dim)
793 : CASE (1)
794 36488 : IF (unit_nr_prv > 0) THEN
795 3 : WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
796 3 : WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
797 : END IF
798 36488 : CALL index_linked_sort(contract_1_mod, contract_2_mod)
799 36488 : CALL index_linked_sort(map_2_mod, notcontract_2_mod)
800 36060 : SELECT CASE (max_tensor)
801 : CASE (1)
802 36060 : CALL index_linked_sort(notcontract_1_mod, map_1_mod)
803 : CASE (3)
804 428 : CALL index_linked_sort(map_1_mod, notcontract_1_mod)
805 : CASE DEFAULT
806 36488 : CPABORT("should not happen")
807 : END SELECT
808 :
809 : CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
810 : contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
811 : trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
812 36488 : move_data_1=move_data_1, unit_nr=unit_nr_prv)
813 :
814 : CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
815 36488 : new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
816 :
817 36060 : SELECT CASE (ref_tensor)
818 : CASE (1)
819 36060 : tensor_large => tensor_contr_1
820 : CASE (2)
821 36488 : tensor_large => tensor_contr_3
822 : END SELECT
823 36488 : tensor_small => tensor_contr_2
824 :
825 : CASE (2)
826 48236 : IF (unit_nr_prv > 0) THEN
827 5 : WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
828 5 : WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
829 : END IF
830 :
831 48236 : CALL index_linked_sort(notcontract_1_mod, map_1_mod)
832 48236 : CALL index_linked_sort(notcontract_2_mod, map_2_mod)
833 47436 : SELECT CASE (max_tensor)
834 : CASE (1)
835 47436 : CALL index_linked_sort(contract_1_mod, contract_2_mod)
836 : CASE (2)
837 800 : CALL index_linked_sort(contract_2_mod, contract_1_mod)
838 : CASE DEFAULT
839 48236 : CPABORT("should not happen")
840 : END SELECT
841 :
842 : CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
843 : notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
844 : trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
845 48236 : move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
846 48236 : trans_1 = .NOT. trans_1
847 :
848 : CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
849 48236 : new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
850 :
851 47436 : SELECT CASE (ref_tensor)
852 : CASE (1)
853 47436 : tensor_large => tensor_contr_1
854 : CASE (2)
855 48236 : tensor_large => tensor_contr_2
856 : END SELECT
857 48236 : tensor_small => tensor_contr_3
858 :
859 : CASE (3)
860 57906 : IF (unit_nr_prv > 0) THEN
861 2 : WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
862 2 : WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
863 : END IF
864 57906 : CALL index_linked_sort(map_1_mod, notcontract_1_mod)
865 57906 : CALL index_linked_sort(contract_2_mod, contract_1_mod)
866 57524 : SELECT CASE (max_tensor)
867 : CASE (2)
868 57524 : CALL index_linked_sort(notcontract_2_mod, map_2_mod)
869 : CASE (3)
870 382 : CALL index_linked_sort(map_2_mod, notcontract_2_mod)
871 : CASE DEFAULT
872 57906 : CPABORT("should not happen")
873 : END SELECT
874 :
875 : CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
876 : contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
877 : trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
878 57906 : move_data_1=move_data_2, unit_nr=unit_nr_prv)
879 :
880 57906 : trans_2 = .NOT. trans_2
881 57906 : trans_3 = .NOT. trans_3
882 :
883 : CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
884 57906 : trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
885 :
886 57524 : SELECT CASE (ref_tensor)
887 : CASE (1)
888 57524 : tensor_large => tensor_contr_2
889 : CASE (2)
890 57906 : tensor_large => tensor_contr_3
891 : END SELECT
892 200536 : tensor_small => tensor_contr_1
893 :
894 : END SELECT
895 :
896 142630 : IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
897 : tensor_contr_2, indchar2_mod, &
898 46842 : tensor_contr_3, indchar3_mod, unit_nr_prv)
899 142630 : IF (unit_nr_prv /= 0) THEN
900 46842 : IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
901 46842 : IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
902 46842 : IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
903 46842 : IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
904 : END IF
905 :
906 : CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
907 : tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
908 : beta, &
909 : tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
910 : unit_nr=unit_nr_prv, log_verbose=log_verbose, &
911 : split_opt=split_opt, &
912 142630 : move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
913 :
914 142630 : IF (PRESENT(pgrid_opt_1)) THEN
915 0 : IF (.NOT. new_1) THEN
916 0 : ALLOCATE (pgrid_opt_1)
917 0 : pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
918 : END IF
919 : END IF
920 :
921 142630 : IF (PRESENT(pgrid_opt_2)) THEN
922 0 : IF (.NOT. new_2) THEN
923 0 : ALLOCATE (pgrid_opt_2)
924 0 : pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
925 : END IF
926 : END IF
927 :
928 142630 : IF (PRESENT(pgrid_opt_3)) THEN
929 0 : IF (.NOT. new_3) THEN
930 0 : ALLOCATE (pgrid_opt_3)
931 0 : pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
932 : END IF
933 : END IF
934 :
935 142630 : do_batched = tensor_small%matrix_rep%do_batched > 0
936 :
937 142630 : tensors_remapped = .FALSE.
938 142630 : IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE.
939 :
940 142630 : IF (tensors_remapped .AND. do_batched) THEN
941 : CALL cp_warn(__LOCATION__, &
942 0 : "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
943 : END IF
944 :
945 : ! optimize process grid during batched contraction
946 142630 : do_change_pgrid(:) = .FALSE.
947 142630 : IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
948 : ASSOCIATE (storage => tensor_small%contraction_storage)
949 0 : CPASSERT(storage%static)
950 80187 : split = dbt_tas_info(tensor_large%matrix_rep)
951 : do_change_pgrid(:) = &
952 80187 : update_contraction_storage(storage, split_opt, split)
953 :
954 318816 : IF (ANY(do_change_pgrid)) THEN
955 966 : mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
956 : CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
957 966 : NINT(storage%nsplit_avg), own_comm=.TRUE.)
958 2898 : pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
959 : END IF
960 :
961 : END ASSOCIATE
962 :
963 80187 : IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
964 : ! check if new grid has better subgrid, if not there is no need to change process grid
965 2898 : pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
966 2898 : pdims_sub = split%mp_comm_group%num_pe_cart
967 :
968 4830 : pdim_ratio = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
969 4830 : pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, dp))/MINVAL(pdims_sub_opt)
970 966 : IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
971 0 : do_change_pgrid(1) = .FALSE.
972 0 : CALL dbt_tas_release_info(split_opt_avg)
973 : END IF
974 : END IF
975 : END IF
976 :
977 142630 : IF (unit_nr_prv /= 0) THEN
978 46842 : do_write_3 = .TRUE.
979 46842 : IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
980 20748 : IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE.
981 : END IF
982 : IF (do_write_3) THEN
983 26132 : CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
984 26132 : CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
985 : END IF
986 : END IF
987 :
988 142630 : IF (new_3) THEN
989 : ! need redistribute if we created new tensor for tensor 3
990 14914 : CALL dbt_scale(tensor_algn_3, beta)
991 14914 : CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.)
992 14914 : IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
993 : ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
994 : ! pointer to data of tensor_3
995 : END IF
996 :
997 : ! transfer contraction storage
998 142630 : CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
999 142630 : CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
1000 142630 : CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
1001 :
1002 142630 : IF (unit_nr_prv /= 0) THEN
1003 46842 : IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
1004 46842 : IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
1005 : END IF
1006 :
1007 142630 : CALL dbt_destroy(tensor_algn_1)
1008 142630 : CALL dbt_destroy(tensor_algn_2)
1009 142630 : CALL dbt_destroy(tensor_algn_3)
1010 :
1011 142630 : IF (do_crop_1) THEN
1012 66496 : CALL dbt_destroy(tensor_crop_1)
1013 66496 : DEALLOCATE (tensor_crop_1)
1014 : END IF
1015 :
1016 142630 : IF (do_crop_2) THEN
1017 64290 : CALL dbt_destroy(tensor_crop_2)
1018 64290 : DEALLOCATE (tensor_crop_2)
1019 : END IF
1020 :
1021 142630 : IF (new_1) THEN
1022 14990 : CALL dbt_destroy(tensor_contr_1)
1023 14990 : DEALLOCATE (tensor_contr_1)
1024 : END IF
1025 142630 : IF (new_2) THEN
1026 2561 : CALL dbt_destroy(tensor_contr_2)
1027 2561 : DEALLOCATE (tensor_contr_2)
1028 : END IF
1029 142630 : IF (new_3) THEN
1030 14914 : CALL dbt_destroy(tensor_contr_3)
1031 14914 : DEALLOCATE (tensor_contr_3)
1032 : END IF
1033 :
1034 142630 : IF (PRESENT(move_data)) THEN
1035 31647 : IF (move_data) THEN
1036 27809 : CALL dbt_clear(tensor_1)
1037 27809 : CALL dbt_clear(tensor_2)
1038 : END IF
1039 : END IF
1040 :
1041 142630 : IF (unit_nr_prv > 0) THEN
1042 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1043 10 : WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
1044 10 : WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1045 : END IF
1046 :
1047 425958 : IF (ANY(do_change_pgrid)) THEN
1048 966 : pgrid_changed_any = .FALSE.
1049 264 : SELECT CASE (max_mm_dim)
1050 : CASE (1)
1051 264 : IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1052 : CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1053 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1054 : pgrid_changed=pgrid_changed, &
1055 0 : unit_nr=unit_nr_prv)
1056 0 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1057 : CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1058 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1059 : pgrid_changed=pgrid_changed, &
1060 0 : unit_nr=unit_nr_prv)
1061 0 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1062 : END IF
1063 0 : IF (pgrid_changed_any) THEN
1064 0 : IF (tensor_2%matrix_rep%do_batched == 3) THEN
1065 : ! set flag that process grid has been optimized to make sure that no grid optimizations are done
1066 : ! in TAS multiply algorithm
1067 0 : CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
1068 : END IF
1069 : END IF
1070 : CASE (2)
1071 174 : IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
1072 : CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1073 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1074 : pgrid_changed=pgrid_changed, &
1075 174 : unit_nr=unit_nr_prv)
1076 174 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1077 : CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1078 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1079 : pgrid_changed=pgrid_changed, &
1080 174 : unit_nr=unit_nr_prv)
1081 174 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1082 : END IF
1083 8 : IF (pgrid_changed_any) THEN
1084 174 : IF (tensor_3%matrix_rep%do_batched == 3) THEN
1085 162 : CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
1086 : END IF
1087 : END IF
1088 : CASE (3)
1089 528 : IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1090 : CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1091 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1092 : pgrid_changed=pgrid_changed, &
1093 214 : unit_nr=unit_nr_prv)
1094 214 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1095 : CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1096 : nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1097 : pgrid_changed=pgrid_changed, &
1098 214 : unit_nr=unit_nr_prv)
1099 214 : IF (pgrid_changed) pgrid_changed_any = .TRUE.
1100 : END IF
1101 966 : IF (pgrid_changed_any) THEN
1102 214 : IF (tensor_1%matrix_rep%do_batched == 3) THEN
1103 214 : CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
1104 : END IF
1105 : END IF
1106 : END SELECT
1107 966 : CALL dbt_tas_release_info(split_opt_avg)
1108 : END IF
1109 :
1110 142630 : IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
1111 : ! freeze TAS process grids if tensor grids were optimized
1112 80187 : CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
1113 80187 : CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
1114 80187 : CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
1115 : END IF
1116 :
1117 142630 : CALL dbt_tas_release_info(split_opt)
1118 :
1119 142630 : CALL timestop(handle)
1120 :
1121 483842 : END SUBROUTINE
1122 :
! **************************************************************************************************
!> \brief Align the index of tensor_in with its data (via dbt_align_index), storing the result in
!>        tensor_out, and translate the index lists and index labels into tensor_out's index order.
!> \param tensor_in input tensor
!> \param contract_in indices of tensor_in that are contracted
!> \param notcontract_in indices of tensor_in that are not contracted
!> \param tensor_out tensor with aligned index, as produced by dbt_align_index
!> \param contract_out contract_in translated to the index numbering of tensor_out
!> \param notcontract_out notcontract_in translated to the index numbering of tensor_out
!> \param indp_in one-character label for each index of tensor_in
!> \param indp_out labels of indp_in moved to the index positions of tensor_out
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
                           tensor_out, contract_out, notcontract_out, indp_in, indp_out)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN)                  :: contract_in, notcontract_in
      TYPE(dbt_type), INTENT(OUT)                        :: tensor_out
      INTEGER, DIMENSION(SIZE(contract_in)), &
         INTENT(OUT)                                     :: contract_out
      INTEGER, DIMENSION(SIZE(notcontract_in)), &
         INTENT(OUT)                                     :: notcontract_out
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
      ! new_pos(i): index position in tensor_out that corresponds to index i of tensor_in
      ! (this is how the permutation returned by dbt_align_index is applied below)
      INTEGER, DIMENSION(ndims_tensor(tensor_in))        :: new_pos
      INTEGER                                            :: i

      CALL dbt_align_index(tensor_in, tensor_out, order=new_pos)

      ! gather: renumber the contracted / non-contracted index lists
      DO i = 1, SIZE(contract_in)
         contract_out(i) = new_pos(contract_in(i))
      END DO
      DO i = 1, SIZE(notcontract_in)
         notcontract_out(i) = new_pos(notcontract_in(i))
      END DO

      ! scatter: each label follows its index to the new position
      DO i = 1, SIZE(indp_in)
         indp_out(new_pos(i)) = indp_in(i)
      END DO

   END SUBROUTINE align_tensor
1146 :
! **************************************************************************************************
!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
!>        matrix multiplication. This routine reshapes the two largest of the three tensors.
!>        Redistribution is avoided if tensors already in a consistent layout.
!>        The tensor with the larger product of block dimensions (as returned by blk_dims_tensor,
!>        ties go to tensor 1) is the reference: it is remapped on its own if it is not
!>        contraction compatible, and the other tensor is then remapped such that its linked
!>        indices adopt the distribution of the reference tensor's linked indices.
!> \param tensor1 first tensor
!> \param tensor2 second tensor
!> \param tensor1_out tensor 1 in contraction-compatible layout (points to tensor1 if not new1)
!> \param tensor2_out tensor 2 in contraction-compatible layout (points to tensor2 if not new2)
!> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
!> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
!> \param ind2_free indices of tensor 2 that are "free" (not linked to any index of tensor 1)
!> \param ind2_linked indices of tensor 2 that are linked to indices of tensor 1
!>                    1:1 correspondence with ind1_linked
!> \param trans1 transpose flag of matrix rep. tensor 1
!> \param trans2 transpose flag of matrix rep. tensor 2
!> \param new1 whether a new tensor 1 was created
!> \param new2 whether a new tensor 2 was created
!> \param ref_tensor which tensor (1 or 2) served as reference for the redistribution
!> \param nodata1 don't copy data of tensor 1
!> \param nodata2 don't copy data of tensor 2
!> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
!> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
!> \param optimize_dist experimental: optimize distribution
!> \param unit_nr output unit
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
                                    ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
                                    nodata1, nodata2, move_data_1, &
                                    move_data_2, optimize_dist, unit_nr)
      TYPE(dbt_type), TARGET, INTENT(INOUT)            :: tensor1
      TYPE(dbt_type), TARGET, INTENT(INOUT)            :: tensor2
      TYPE(dbt_type), POINTER, INTENT(OUT)             :: tensor1_out, tensor2_out
      INTEGER, DIMENSION(:), INTENT(IN)                :: ind1_free, ind2_free
      INTEGER, DIMENSION(:), INTENT(IN)                :: ind1_linked, ind2_linked
      LOGICAL, INTENT(OUT)                             :: trans1, trans2
      LOGICAL, INTENT(OUT)                             :: new1, new2
      INTEGER, INTENT(OUT)                             :: ref_tensor
      LOGICAL, INTENT(IN), OPTIONAL                    :: nodata1, nodata2
      LOGICAL, INTENT(INOUT), OPTIONAL                 :: move_data_1, move_data_2
      LOGICAL, INTENT(IN), OPTIONAL                    :: optimize_dist
      INTEGER, INTENT(IN), OPTIONAL                    :: unit_nr
      INTEGER                                          :: compat1, compat1_old, compat2, compat2_old, &
                                                          unit_nr_prv
      TYPE(mp_cart_type)                               :: comm_2d
      TYPE(array_list)                                 :: dist_list
      INTEGER, DIMENSION(:), ALLOCATABLE               :: mp_dims
      TYPE(dbt_distribution_type)                      :: dist_in
      INTEGER(KIND=int_8)                              :: nblkrows, nblkcols
      LOGICAL                                          :: optimize_dist_prv
      INTEGER, DIMENSION(ndims_tensor(tensor1))        :: dims1
      INTEGER, DIMENSION(ndims_tensor(tensor2))        :: dims2

      NULLIFY (tensor1_out, tensor2_out)

      unit_nr_prv = prep_output_unit(unit_nr)

      CALL blk_dims_tensor(tensor1, dims1)
      CALL blk_dims_tensor(tensor2, dims2)

      ! the tensor with the larger product of block dimensions is the reference (ties: tensor 1)
      IF (PRODUCT(INT(dims1, int_8)) .GE. PRODUCT(INT(dims2, int_8))) THEN
         ref_tensor = 1
      ELSE
         ref_tensor = 2
      END IF

      IF (PRESENT(optimize_dist)) THEN
         optimize_dist_prv = optimize_dist
      ELSE
         optimize_dist_prv = .FALSE.
      END IF

      ! 0: not compatible, 1: linked indices map to first matrix dim, 2: to second matrix dim
      compat1 = compat_map(tensor1%nd_index, ind1_linked)
      compat2 = compat_map(tensor2%nd_index, ind2_linked)
      compat1_old = compat1
      compat2_old = compat2

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":"
         SELECT CASE (compat1)
         CASE (0)
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         CASE (1)
            WRITE (unit_nr_prv, '(A)') "Normal"
         CASE (2)
            WRITE (unit_nr_prv, '(A)') "Transposed"
         END SELECT
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":"
         SELECT CASE (compat2)
         CASE (0)
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         CASE (1)
            WRITE (unit_nr_prv, '(A)') "Normal"
         CASE (2)
            WRITE (unit_nr_prv, '(A)') "Transposed"
         END SELECT
      END IF

      new1 = .FALSE.
      new2 = .FALSE.

      IF (compat1 == 0 .OR. optimize_dist_prv) THEN
         new1 = .TRUE.
      END IF

      IF (compat2 == 0 .OR. optimize_dist_prv) THEN
         new2 = .TRUE.
      END IF

      IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name)
            nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8))
            nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8))
            comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
            ALLOCATE (tensor1_out)
            CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
                           nodata=nodata1, move_data=move_data_1)
            CALL comm_2d%free()
            compat1 = 1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
            tensor1_out => tensor1
         END IF
         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
               TRIM(tensor2%name), "compatible with", TRIM(tensor1%name)
            dist_in = dbt_distribution(tensor1_out)
            dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
            IF (compat1 == 1) THEN ! linked index is first 2d dimension
               ! get distribution of linked index, tensor 2 must adopt this distribution
               ! get grid dimensions of linked index
               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
               ALLOCATE (tensor2_out)
               CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
            ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
               ! get distribution of linked index, tensor 2 must adopt this distribution
               ! get grid dimensions of linked index
               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
               ALLOCATE (tensor2_out)
               CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
            ELSE
               CPABORT("should not happen")
            END IF
            compat2 = compat1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
            tensor2_out => tensor2
         END IF
      ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name)
            nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8))
            nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8))
            comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
            ALLOCATE (tensor2_out)
            ! use the communicator optimized for the matrix shape of tensor 2, as in the
            ! ref_tensor == 1 branch (comm_2d used to be created and freed without being used)
            CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=comm_2d, &
                           nodata=nodata2, move_data=move_data_2)
            CALL comm_2d%free()
            compat2 = 1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
            tensor2_out => tensor2
         END IF
         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), &
               "compatible with", TRIM(tensor2%name)
            dist_in = dbt_distribution(tensor2_out)
            dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
            IF (compat2 == 1) THEN ! linked index is first 2d dimension
               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
               ALLOCATE (tensor1_out)
               CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
            ELSEIF (compat2 == 2) THEN ! linked index is second 2d dimension
               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
               ALLOCATE (tensor1_out)
               CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
            ELSE
               CPABORT("should not happen")
            END IF
            compat1 = compat2
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
            tensor1_out => tensor1
         END IF
      END IF

      ! after redistribution both tensors must be contraction compatible (compat 1 or 2)
      SELECT CASE (compat1)
      CASE (1)
         trans1 = .FALSE.
      CASE (2)
         trans1 = .TRUE.
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT

      SELECT CASE (compat2)
      CASE (1)
         trans2 = .FALSE.
      CASE (2)
         trans2 = .TRUE.
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT

      IF (unit_nr_prv > 0) THEN
         IF (compat1 .NE. compat1_old) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":"
            SELECT CASE (compat1)
            CASE (0)
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            CASE (1)
               WRITE (unit_nr_prv, '(A)') "Normal"
            CASE (2)
               WRITE (unit_nr_prv, '(A)') "Transposed"
            END SELECT
         END IF
         IF (compat2 .NE. compat2_old) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":"
            SELECT CASE (compat2)
            CASE (0)
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            CASE (1)
               WRITE (unit_nr_prv, '(A)') "Normal"
            CASE (2)
               WRITE (unit_nr_prv, '(A)') "Transposed"
            END SELECT
         END IF
      END IF

      ! a newly created tensor is a private temporary, so its data may always be moved downstream
      IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE.
      IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE.

   END SUBROUTINE reshape_mm_compatible
1381 :
! **************************************************************************************************
!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
!>        matrix multiplication. This routine reshapes the smallest of the three tensors.
!> \param tensor_in tensor to be prepared
!> \param ind1 index that should be mapped to first matrix dimension
!> \param ind2 index that should be mapped to second matrix dimension
!> \param tensor_out points to tensor_in if its index mapping is already compatible with ind1/ind2,
!>                   otherwise to a newly allocated tensor remapped with dbt_remap
!> \param trans transpose flag of matrix rep.
!> \param new whether a new tensor was created for tensor_out
!> \param nodata don't copy tensor data
!> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
!> \param unit_nr output unit
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
      TYPE(dbt_type), TARGET, INTENT(INOUT)              :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN)                  :: ind1, ind2
      TYPE(dbt_type), POINTER, INTENT(OUT)               :: tensor_out
      LOGICAL, INTENT(OUT)                               :: trans
      LOGICAL, INTENT(OUT)                               :: new
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      INTEGER                                            :: compat1, compat2, compat1_old, compat2_old, &
                                                            unit_nr_prv

      NULLIFY (tensor_out)

      unit_nr_prv = prep_output_unit(unit_nr)

      new = .FALSE.
      ! compat_map: 1 = matches row mapping, 2 = matches column mapping, 0 = matches neither
      compat1 = compat_map(tensor_in%nd_index, ind1)
      compat2 = compat_map(tensor_in%nd_index, ind2)
      compat1_old = compat1; compat2_old = compat2
      IF (unit_nr_prv > 0) CALL write_compat_info(tensor_in%name, compat1, compat2)

      IF (compat1 == 0 .OR. compat2 == 0) THEN ! index mapping not compatible with contract index

         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name)

         ALLOCATE (tensor_out)
         ! optional arguments are forwarded unchanged (an absent nodata/move_data stays absent)
         CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
         CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
         ! the remapped tensor has ind1 as row and ind2 as column index by construction
         compat1 = 1
         compat2 = 2
         new = .TRUE.
      ELSE
         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
         tensor_out => tensor_in
      END IF

      IF (compat1 == 1 .AND. compat2 == 2) THEN
         trans = .FALSE.
      ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
         trans = .TRUE.
      ELSE
         CPABORT("this should not happen")
      END IF

      IF (unit_nr_prv > 0) THEN
         IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
            CALL write_compat_info(tensor_out%name, compat1, compat2)
         END IF
      END IF

   CONTAINS

! **************************************************************************************************
!> \brief Print the matrix compatibility of a tensor's index mapping to unit_nr_prv
!> \param tensor_name name of the tensor
!> \param c1 compatibility code of ind1
!> \param c2 compatibility code of ind2
! **************************************************************************************************
      SUBROUTINE write_compat_info(tensor_name, c1, c2)
         CHARACTER(LEN=*), INTENT(IN)                    :: tensor_name
         INTEGER, INTENT(IN)                             :: c1, c2

         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_name), ":"
         IF (c1 == 1 .AND. c2 == 2) THEN
            WRITE (unit_nr_prv, '(A)') "Normal"
         ELSEIF (c1 == 2 .AND. c2 == 1) THEN
            WRITE (unit_nr_prv, '(A)') "Transposed"
         ELSE
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         END IF
      END SUBROUTINE write_compat_info

   END SUBROUTINE
1465 :
! **************************************************************************************************
!> \brief update contraction storage that keeps track of process grids during a batched contraction
!>        and decide if tensor process grid needs to be optimized
!> \param storage contraction storage; its batch counter is incremented and the running average of
!>        the optimal split factor (nsplit_avg) is updated
!> \param split_opt optimized TAS process grid (ngroup_opt must be allocated)
!> \param split current TAS process grid
!> \return do_change_pgrid(1): process grid of the current sub-group is too far from square;
!>         do_change_pgrid(2): current split factor deviates too much from the running average
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
      TYPE(dbt_contraction_storage), INTENT(INOUT)       :: storage
      TYPE(dbt_tas_split_info), INTENT(IN)               :: split_opt
      TYPE(dbt_tas_split_info), INTENT(IN)               :: split
      LOGICAL, DIMENSION(2)                              :: do_change_pgrid
      INTEGER, DIMENSION(2)                              :: pdims_sub
      REAL(kind=dp)                                      :: change_criterion
      INTEGER                                            :: nsplit_opt, nsplit

      CPASSERT(ALLOCATED(split_opt%ngroup_opt))
      nsplit_opt = split_opt%ngroup_opt
      nsplit = split%ngroup

      storage%ibatch = storage%ibatch + 1

      ! running average of the optimal split factor over all batches processed so far
      storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, dp) + REAL(nsplit_opt, dp)) &
                           /REAL(storage%ibatch, dp)

      do_change_pgrid(:) = .FALSE.

      ! check for process grid dimensions (aspect ratio of the sub-group grid)
      pdims_sub = split%mp_comm_group%num_pe_cart
      change_criterion = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
      IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.

      ! check for split factor (symmetric ratio between current and averaged optimal split)
      change_criterion = MAX(REAL(nsplit, dp)/storage%nsplit_avg, REAL(storage%nsplit_avg, dp)/nsplit)
      IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.

   END FUNCTION
1512 :
! **************************************************************************************************
!> \brief Check if 2d index is compatible with tensor index
!> \param nd_index nd-to-2d index mapping of a tensor
!> \param compat_ind list of tensor indices to compare against the mapping
!> \return 1 if compat_ind equals the row (first matrix dimension) index list of nd_index,
!>         2 if it equals the column (second matrix dimension) index list, 0 otherwise
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION compat_map(nd_index, compat_ind) RESULT(compat)
      TYPE(nd_to_2d_mapping), INTENT(IN)                 :: nd_index
      INTEGER, DIMENSION(:), INTENT(IN)                  :: compat_ind
      INTEGER                                            :: compat
      INTEGER, DIMENSION(ndims_mapping_row(nd_index))    :: row_map
      INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: col_map

      CALL dbt_get_mapping_info(nd_index, map1_2d=row_map, map2_2d=col_map)

      IF (array_eq_i(row_map, compat_ind)) THEN
         compat = 1
      ELSE IF (array_eq_i(col_map, compat_ind)) THEN
         compat = 2
      ELSE
         compat = 0
      END IF

   END FUNCTION
1534 :
! **************************************************************************************************
!> \brief Sort ind_ref in place (using the module's `sort` routine, presumably ascending) and
!>        apply the same permutation to ind_dep
!> \param ind_ref reference index list, sorted on return
!> \param ind_dep dependent index list, reordered like ind_ref (assumed to have the same size)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE index_linked_sort(ind_ref, ind_dep)
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: ind_ref, ind_dep
      INTEGER                                            :: n
      INTEGER, DIMENSION(SIZE(ind_ref))                  :: perm

      n = SIZE(ind_ref)
      ! perm receives the original positions of the sorted ind_ref entries
      CALL sort(ind_ref, n, perm)
      ind_dep(:) = ind_dep(perm)

   END SUBROUTINE
1547 :
! **************************************************************************************************
!> \brief Build a process grid for tensor on the communicator of a TAS split info
!> \param tensor tensor providing the process-grid index mapping and the block dimensions
!> \param tas_split_info TAS split info whose communicator is used; a copy of it is stored in the
!>        result and registered with dbt_tas_info_hold
!> \return new process grid
!> \author Patrick Seewald
! **************************************************************************************************
   FUNCTION opt_pgrid(tensor, tas_split_info) RESULT(pgrid_opt)
      TYPE(dbt_type), INTENT(IN)                         :: tensor
      TYPE(dbt_tas_split_info), INTENT(IN)               :: tas_split_info
      TYPE(dbt_pgrid_type)                               :: pgrid_opt
      INTEGER, DIMENSION(ndims_matrix_row(tensor))       :: row_map
      INTEGER, DIMENSION(ndims_matrix_column(tensor))    :: col_map
      INTEGER, DIMENSION(ndims_tensor(tensor))           :: blk_dims

      CALL blk_dims_tensor(tensor, blk_dims)
      CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=row_map, map2_2d=col_map)
      pgrid_opt = dbt_nd_mp_comm(tas_split_info%mp_comm, row_map, col_map, tdims=blk_dims)

      ALLOCATE (pgrid_opt%tas_split_info, SOURCE=tas_split_info)
      CALL dbt_tas_info_hold(pgrid_opt%tas_split_info)
   END FUNCTION
1567 :
! **************************************************************************************************
!> \brief Copy tensor to tensor with modified index mapping
!> \param tensor_in source tensor
!> \param map1_2d new index mapping (tensor indices mapped to the first matrix dimension)
!> \param map2_2d new index mapping (tensor indices mapped to the second matrix dimension)
!> \param tensor_out tensor created with the new mapping
!> \param comm_2d 2d communicator used for the new process grid (default: that of tensor_in)
!> \param dist1 distribution vectors for the indices in map1_2d; the i-th array belongs to map1_2d(i).
!>              Requires mp_dims_1.
!> \param dist2 distribution vectors for the indices in map2_2d; the i-th array belongs to map2_2d(i).
!>              Requires mp_dims_2.
!> \param mp_dims_1 process grid dimensions for the indices in map1_2d (passed to dbt_nd_mp_comm)
!> \param mp_dims_2 process grid dimensions for the indices in map2_2d (passed to dbt_nd_mp_comm)
!> \param name name of tensor_out (default: name of tensor_in)
!> \param nodata if .TRUE., tensor_out is created but no data is copied
!> \param move_data passed to dbt_copy_expert
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
                        mp_dims_1, mp_dims_2, name, nodata, move_data)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN)                  :: map1_2d, map2_2d
      TYPE(dbt_type), INTENT(OUT)                        :: tensor_out
      CHARACTER(len=*), INTENT(IN), OPTIONAL             :: name
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata, move_data
      CLASS(mp_comm_type), INTENT(IN), OPTIONAL          :: comm_2d
      TYPE(array_list), INTENT(IN), OPTIONAL             :: dist1, dist2
      INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL        :: mp_dims_1
      INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL        :: mp_dims_2
      CHARACTER(len=default_string_length)               :: name_tmp
      INTEGER, DIMENSION(:), ALLOCATABLE                 :: ${varlist("blk_sizes")}$, &
                                                            ${varlist("nd_dist")}$
      TYPE(dbt_distribution_type)                        :: dist
      TYPE(mp_cart_type)                                 :: comm_2d_prv
      INTEGER                                            :: handle, i
      INTEGER, DIMENSION(ndims_tensor(tensor_in))        :: pdims, myploc
      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'dbt_remap'
      LOGICAL                                            :: nodata_prv
      TYPE(dbt_pgrid_type)                               :: comm_nd

      CALL timeset(routineN, handle)

      IF (PRESENT(name)) THEN
         name_tmp = name
      ELSE
         name_tmp = tensor_in%name
      END IF
      ! explicit distribution vectors are only accepted together with explicit grid dimensions
      IF (PRESENT(dist1)) THEN
         CPASSERT(PRESENT(mp_dims_1))
      END IF

      IF (PRESENT(dist2)) THEN
         CPASSERT(PRESENT(mp_dims_2))
      END IF

      IF (PRESENT(comm_2d)) THEN
         comm_2d_prv = comm_2d
      ELSE
         comm_2d_prv = tensor_in%pgrid%mp_comm_2d
      END IF

      ! nd process grid matching the new index mapping; pdims are its dimensions per tensor index
      comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
      CALL mp_environ_pgrid(comm_nd, pdims, myploc)

      ! block sizes are taken over unchanged from tensor_in
#:for ndim in ndims
      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
         CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
      END IF
#:endfor

      ! per-dimension distribution: explicit (dist1/dist2) if given, otherwise a default distribution
#:for ndim in ndims
      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
#:for idim in range(1, ndim+1)
         IF (PRESENT(dist1)) THEN
            IF (ANY(map1_2d == ${idim}$)) THEN
               i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
               CALL get_ith_array(dist1, i, nd_dist_${idim}$)
            END IF
         END IF

         IF (PRESENT(dist2)) THEN
            IF (ANY(map2_2d == ${idim}$)) THEN
               i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
               CALL get_ith_array(dist2, i, nd_dist_${idim}$)
            END IF
         END IF

         IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
            ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
            CALL dbt_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)
         END IF
#:endfor
         CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
                                          ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
      END IF
#:endfor

#:for ndim in ndims
      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
         CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
                         ${varlist("blk_sizes", nmax=ndim)}$)
      END IF
#:endfor

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      END IF

      IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
      ! the distribution object is no longer needed here once tensor_out has been created from it
      CALL dbt_distribution_destroy(dist)

      CALL timestop(handle)
   END SUBROUTINE
1671 :
! **************************************************************************************************
!> \brief Align index with data: create a shallow copy of tensor_in (via dbt_permute_index) whose
!>        index is permuted by dbt_inverse_order of the concatenated row/column block-index mapping
!> \param tensor_in input tensor
!> \param tensor_out tensor with permuted index
!> \param order permutation resulting from alignment
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor_in
      TYPE(dbt_type), INTENT(OUT)                        :: tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(OUT), OPTIONAL                           :: order

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_align_index'
      INTEGER                                            :: handle
      INTEGER, DIMENSION(ndims_matrix_row(tensor_in))    :: row_inds
      INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: col_inds
      INTEGER, DIMENSION(ndims_tensor(tensor_in))        :: perm

      CALL timeset(routineN, handle)

      CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=row_inds, map2_2d=col_inds)
      perm = dbt_inverse_order([row_inds, col_inds])
      CALL dbt_permute_index(tensor_in, tensor_out, order=perm)
      IF (PRESENT(order)) order(:) = perm

      CALL timestop(handle)
   END SUBROUTINE
1698 :
! **************************************************************************************************
!> \brief Create new tensor by reordering index, data is copied exactly (shallow copy)
!> \param tensor_in tensor whose index is reordered; its matrix representation and reference
!>                  counter are shared with tensor_out
!> \param tensor_out new tensor with permuted index; it does not own the matrix representation
!> \param order index permutation, e.g. tensor_out%nblks_local(order(i)) = tensor_in%nblks_local(i)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor_in
      TYPE(dbt_type), INTENT(OUT)                        :: tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN)                                      :: order

      TYPE(nd_to_2d_mapping)                             :: nd_index_blk_rs, nd_index_rs
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_permute_index'
      INTEGER                                            :: handle
      INTEGER                                            :: ndims

      CALL timeset(routineN, handle)

      ndims = ndims_tensor(tensor_in)

      ! permuted element index, block index and process grid index mappings
      CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
      CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
      CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)

      ! share the matrix representation; tensor_out is marked as not owning it
      tensor_out%matrix_rep => tensor_in%matrix_rep
      tensor_out%owns_matrix = .FALSE.

      tensor_out%nd_index = nd_index_rs
      tensor_out%nd_index_blk = nd_index_blk_rs
      tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
      IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
         ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
      END IF
      ! share the reference counter of tensor_in and register the new reference via dbt_hold
      tensor_out%refcount => tensor_in%refcount
      CALL dbt_hold(tensor_out)

      ! per-dimension quantities are reordered according to order
      CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
      CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
      CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
      CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
      ALLOCATE (tensor_out%nblks_local(ndims))
      ALLOCATE (tensor_out%nfull_local(ndims))
      tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
      tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
      tensor_out%name = tensor_in%name
      tensor_out%valid = .TRUE.

      ! copy contraction storage, but with batch ranges reordered like the tensor index
      IF (ALLOCATED(tensor_in%contraction_storage)) THEN
         ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
         CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
         CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
      END IF

      CALL timestop(handle)
   END SUBROUTINE
1753 :
! **************************************************************************************************
!> \brief Map contraction bounds to bounds referring to tensor indices
!>        see dbt_contract for docu of dummy arguments
!> \param bounds_t1 bounds mapped to tensor_1 (full index range where no bounds are given)
!> \param bounds_t2 bounds mapped to tensor_2 (full index range where no bounds are given)
!> \param do_crop_1 whether tensor 1 should be cropped (bounds_1 or bounds_2 present)
!> \param do_crop_2 whether tensor 2 should be cropped (bounds_1 or bounds_3 present)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
                                        contract_1, notcontract_1, &
                                        contract_2, notcontract_2, &
                                        bounds_t1, bounds_t2, &
                                        bounds_1, bounds_2, bounds_3, &
                                        do_crop_1, do_crop_2)

      TYPE(dbt_type), INTENT(IN)                         :: tensor_1, tensor_2
      INTEGER, DIMENSION(:), INTENT(IN)                  :: contract_1, contract_2, &
                                                            notcontract_1, notcontract_2
      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
         INTENT(OUT)                                     :: bounds_t1
      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
         INTENT(OUT)                                     :: bounds_t2
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL                            :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL                            :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL                            :: bounds_3
      LOGICAL, INTENT(OUT), OPTIONAL                     :: do_crop_1, do_crop_2
      LOGICAL                                            :: crop_t1, crop_t2

      ! start from the full index range of both tensors
      bounds_t1(1, :) = 1
      CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
      bounds_t2(1, :) = 1
      CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))

      ! contracted indices are common to both tensors, so bounds_1 restricts both
      IF (PRESENT(bounds_1)) THEN
         bounds_t1(:, contract_1) = bounds_1
         bounds_t2(:, contract_2) = bounds_1
      END IF
      IF (PRESENT(bounds_2)) bounds_t1(:, notcontract_1) = bounds_2
      IF (PRESENT(bounds_3)) bounds_t2(:, notcontract_2) = bounds_3

      crop_t1 = PRESENT(bounds_1) .OR. PRESENT(bounds_2)
      crop_t2 = PRESENT(bounds_1) .OR. PRESENT(bounds_3)

      IF (PRESENT(do_crop_1)) do_crop_1 = crop_t1
      IF (PRESENT(do_crop_2)) do_crop_2 = crop_t2

   END SUBROUTINE
1815 :
! **************************************************************************************************
!> \brief print tensor contraction indices in a human readable way
!> \param tensor_1 first input tensor
!> \param indchar1 characters printed for index of tensor 1
!> \param tensor_2 second input tensor
!> \param indchar2 characters printed for index of tensor 2
!> \param tensor_3 result tensor
!> \param indchar3 characters printed for index of tensor 3
!> \param unit_nr output unit (processed by prep_output_unit; output is written only if > 0)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
      TYPE(dbt_type), INTENT(IN)                         :: tensor_1, tensor_2, tensor_3
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
      INTEGER, INTENT(IN)                                :: unit_nr
      INTEGER, DIMENSION(ndims_matrix_row(tensor_1))     :: map11
      INTEGER, DIMENSION(ndims_matrix_column(tensor_1))  :: map12
      INTEGER, DIMENSION(ndims_matrix_row(tensor_2))     :: map21
      INTEGER, DIMENSION(ndims_matrix_column(tensor_2))  :: map22
      INTEGER, DIMENSION(ndims_matrix_row(tensor_3))     :: map31
      INTEGER, DIMENSION(ndims_matrix_column(tensor_3))  :: map32
      INTEGER                                            :: iw

      iw = prep_output_unit(unit_nr)

      IF (iw /= 0) THEN
         CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
         CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
         CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
      END IF

      IF (iw <= 0) RETURN

      WRITE (iw, '(T2,A)') "INDEX INFO"

      ! tensor index: (t1) x (t2) = (t3)
      WRITE (iw, '(T15,A)', advance='no') "tensor index: ("
      CALL put_chars(indchar1)
      WRITE (iw, '(A)', advance='no') ") x ("
      CALL put_chars(indchar2)
      WRITE (iw, '(A)', advance='no') ") = ("
      CALL put_chars(indchar3)
      WRITE (iw, '(A)') ")"

      ! matrix index: (rows|columns) of each tensor's block-index mapping
      WRITE (iw, '(T15,A)', advance='no') "matrix index: ("
      CALL put_chars(indchar1(map11))
      WRITE (iw, '(A1)', advance='no') "|"
      CALL put_chars(indchar1(map12))
      WRITE (iw, '(A)', advance='no') ") x ("
      CALL put_chars(indchar2(map21))
      WRITE (iw, '(A1)', advance='no') "|"
      CALL put_chars(indchar2(map22))
      WRITE (iw, '(A)', advance='no') ") = ("
      CALL put_chars(indchar3(map31))
      WRITE (iw, '(A1)', advance='no') "|"
      CALL put_chars(indchar3(map32))
      WRITE (iw, '(A)') ")"

   CONTAINS

! **************************************************************************************************
!> \brief Write characters one by one (non-advancing) to unit iw
!> \param chars characters to write
! **************************************************************************************************
      SUBROUTINE put_chars(chars)
         CHARACTER(LEN=1), DIMENSION(:), INTENT(IN)      :: chars
         INTEGER                                         :: k

         DO k = 1, SIZE(chars)
            WRITE (iw, '(A1)', advance='no') chars(k)
         END DO
      END SUBROUTINE put_chars

   END SUBROUTINE
1890 :
! **************************************************************************************************
!> \brief Initialize batched contraction for this tensor.
!>
!>        Explanation: A batched contraction is a contraction performed in several consecutive steps
!>        by specification of bounds in dbt_contract. This can be used to reduce memory by
!>        a large factor. The routines dbt_batched_contract_init and
!>        dbt_batched_contract_finalize should be called to define the scope of a batched
!>        contraction as this enables important optimizations (adapting communication scheme to
!>        batches and adapting process grid to multiplication algorithm). The routines
!>        dbt_batched_contract_init and dbt_batched_contract_finalize must be
!>        called before the first and after the last contraction step on all 3 tensors.
!>
!>        Requirements:
!>        - the tensors are in a compatible matrix layout (see documentation of
!>          `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
!>          disabled and a warning is issued.
!>        - within the scope of a batched contraction, it is not allowed to access or change tensor
!>          data except by calling the routines dbt_contract & dbt_copy.
!>        - the bounds affecting indices of the smallest tensor must not change in the course of a
!>          batched contraction (todo: get rid of this requirement).
!>
!>        Side effects:
!>        - the parallel layout (process grid and distribution) of all tensors may change. In order
!>          to disable the process grid optimization including this side effect, call this routine
!>          only on the smallest of the 3 tensors.
!>
!> \note
!>        Note 1: for an example of batched contraction see `examples/dbt_example.F`.
!>        (todo: the example is outdated and should be updated).
!>
!>        Note 2: it is also meaningful to use this feature if a contraction consists of only one
!>        batch, provided that multiple contractions involving the same 3 tensors are performed
!>        (batched_contract_init and batched_contract_finalize must then be called before/after each
!>        contraction call). The process grid is then optimized after the first contraction
!>        and future contractions may profit from this optimization.
!>
!> \param tensor tensor for which the contraction storage is allocated
!> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
!>                      a new range. The size should be the number of ranges plus one, the last
!>                      element being the block index plus one of the last block in the last range.
!>                      For internal load balancing optimizations, optionally specify the index
!>                      ranges of batched contraction.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_batched_contract_init(tensor, ${varlist("batch_range")}$)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
      INTEGER, DIMENSION(ndims_tensor(tensor))           :: tdims
      INTEGER, DIMENSION(:), ALLOCATABLE                 :: ${varlist("batch_range_prv")}$
      LOGICAL                                            :: static_range

      CALL dbt_get_info(tensor, nblks_total=tdims)

      ! static_range stays .TRUE. only if no batch range is given for any dimension
      static_range = .TRUE.
#:for idim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) >= ${idim}$) THEN
         IF (PRESENT(batch_range_${idim}$)) THEN
            ALLOCATE (batch_range_prv_${idim}$, source=batch_range_${idim}$)
            static_range = .FALSE.
         ELSE
            ! no ranges given: a single range covering all blocks, i.e. [1, nblks + 1)
            ALLOCATE (batch_range_prv_${idim}$ (2))
            batch_range_prv_${idim}$ (1) = 1
            batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
         END IF
      END IF
#:endfor

      ALLOCATE (tensor%contraction_storage)
      tensor%contraction_storage%static = static_range
      ! batched TAS multiplication is only initialized when no batch ranges were specified
      IF (static_range) THEN
         CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
      END IF
      tensor%contraction_storage%nsplit_avg = 0.0_dp
      tensor%contraction_storage%ibatch = 0

#:for ndim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) == ${ndim}$) THEN
         CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
                                ${varlist("batch_range_prv", nmax=ndim)}$)
      END IF
#:endfor

   END SUBROUTINE
1973 :
! **************************************************************************************************
!> \brief finalize batched contraction. This performs all communication that has been postponed in
!>        the contraction calls.
!> \param tensor tensor whose contraction storage is released on return
!> \param unit_nr output unit (processed by prep_output_unit)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      INTEGER                                            :: handle, iw
      LOGICAL                                            :: print_info

      CALL tensor%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)
      iw = prep_output_unit(unit_nr)

      print_info = .FALSE.
      IF (tensor%contraction_storage%static) THEN
         ! query the batched state before dbt_tas_batched_mm_finalize is called
         IF (tensor%matrix_rep%do_batched > 0) print_info = tensor%matrix_rep%mm_storage%batched_out
         CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
      END IF

      IF (print_info .AND. iw /= 0) THEN
         IF (iw > 0) WRITE (iw, "(T2,A)") "FINALIZING BATCHED PROCESSING OF MATMUL"
         CALL dbt_write_tensor_info(tensor, iw)
         CALL dbt_write_tensor_dist(tensor, iw)
      END IF

      CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
      DEALLOCATE (tensor%contraction_storage)
      CALL tensor%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)

   END SUBROUTINE
2013 :
! **************************************************************************************************
!> \brief change the process grid of a tensor
!> \param tensor tensor to redistribute; it is replaced by a copy distributed on pgrid
!> \param pgrid new process grid
!> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
!>                      a new range. The size should be the number of ranges plus one, the last
!>                      element being the block index plus one of the last block in the last range.
!>                      For internal load balancing optimizations, optionally specify the index
!>                      ranges of batched contraction.
!> \param nodata optionally don't copy the tensor data (then tensor is empty on return)
!> \param pgrid_changed .FALSE. if nothing was done because pgrid has the same dimensions and the
!>                      same TAS split factor as the current grid, .TRUE. otherwise
!> \param unit_nr output unit (grid info is printed if present and > 0)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
                               nodata, pgrid_changed, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)                      :: tensor
      TYPE(dbt_pgrid_type), INTENT(IN)                   :: pgrid
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
      LOGICAL, INTENT(OUT), OPTIONAL                     :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_change_pgrid'
      CHARACTER(default_string_length)                   :: name
      INTEGER                                            :: handle
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ${varlist("bs")}$, &
                                                            ${varlist("dist")}$
      INTEGER, DIMENSION(ndims_tensor(tensor))           :: pcoord, pcoord_ref, pdims, pdims_ref, &
                                                            tdims
      TYPE(dbt_type)                                     :: t_tmp
      TYPE(dbt_distribution_type)                        :: dist
      INTEGER, DIMENSION(ndims_matrix_row(tensor))       :: map1
      INTEGER, &
         DIMENSION(ndims_matrix_column(tensor))          :: map2
      LOGICAL, DIMENSION(ndims_tensor(tensor))           :: mem_aware
      INTEGER, DIMENSION(ndims_tensor(tensor))           :: nbatch
      INTEGER                                            :: ind1, ind2, batch_size, ibatch

      IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
      CALL mp_environ_pgrid(pgrid, pdims, pcoord)
      CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)

      ! nothing to do if grid dimensions and TAS split factor are unchanged
      IF (ALL(pdims == pdims_ref)) THEN
         IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
            IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
               RETURN
            END IF
         END IF
      END IF

      CALL timeset(routineN, handle)

#:for idim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) >= ${idim}$) THEN
         mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
         IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
      END IF
#:endfor

      CALL dbt_get_info(tensor, nblks_total=tdims, name=name)

#:for idim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) >= ${idim}$) THEN
         ALLOCATE (bs_${idim}$ (dbt_nblks_total(tensor, ${idim}$)))
         CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
         ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
         dist_${idim}$ = 0
         IF (mem_aware(${idim}$)) THEN
            ! distribute each batch range separately over this process grid dimension
            DO ibatch = 1, nbatch(${idim}$)
               ind1 = batch_range_${idim}$ (ibatch)
               ind2 = batch_range_${idim}$ (ibatch + 1) - 1
               batch_size = ind2 - ind1 + 1
               CALL dbt_default_distvec(batch_size, pdims(${idim}$), &
                                        bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
            END DO
         ELSE
            CALL dbt_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
         END IF
      END IF
#:endfor

      CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
#:for ndim in ndims
      IF (ndims_tensor(tensor) == ${ndim}$) THEN
         CALL dbt_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
         CALL dbt_create(t_tmp, name, dist, map1, map2, ${varlist("bs", nmax=ndim)}$)
      END IF
#:endfor
      CALL dbt_distribution_destroy(dist)

      IF (PRESENT(nodata)) THEN
         IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      ELSE
         CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      END IF

      CALL dbt_copy_contraction_storage(tensor, t_tmp)

      ! replace tensor by the redistributed copy; t_tmp is intentionally not destroyed
      CALL dbt_destroy(tensor)
      tensor = t_tmp

      IF (PRESENT(unit_nr)) THEN
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
            ! *(I6): one field per tensor dimension; a fixed 3I6 would trigger format reversion
            ! (integer fed to the A descriptor) for 4-d tensors
            WRITE (unit_nr, "(T4,A,1X,*(I6))") "process grid dimensions:", pdims
            CALL dbt_write_split_info(pgrid, unit_nr)
         END IF
      END IF

      IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.

      CALL timestop(handle)
   END SUBROUTINE
2124 :
2125 : ! **************************************************************************************************
2126 : !> \brief map tensor to a new 2d process grid for the matrix representation.
2127 : !> \author Patrick Seewald
2128 : ! **************************************************************************************************
2129 776 : SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
2130 : TYPE(dbt_type), INTENT(INOUT) :: tensor
2131 : TYPE(mp_cart_type), INTENT(IN) :: mp_comm
2132 : INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
2133 : LOGICAL, INTENT(IN), OPTIONAL :: nodata
2134 : INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
2135 : LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2136 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2137 1552 : INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2138 1552 : INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
2139 1552 : INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
2140 2328 : TYPE(dbt_pgrid_type) :: pgrid
2141 776 : INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
2142 776 : INTEGER, DIMENSION(:), ALLOCATABLE :: array
2143 : INTEGER :: idim
2144 :
2145 776 : CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
2146 776 : CALL blk_dims_tensor(tensor, dims)
2147 :
2148 776 : IF (ALLOCATED(tensor%contraction_storage)) THEN
2149 : ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
2150 3104 : nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
2151 : ! for good load balancing the process grid dimensions should be chosen adapted to the
2152 : ! tensor dimenions. For batched contraction the tensor dimensions should be divided by
2153 : ! the number of batches (number of index ranges).
2154 3880 : DO idim = 1, ndims_tensor(tensor)
2155 2328 : CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
2156 2328 : dims(idim) = array(nbatches(idim) + 1) - array(1)
2157 2328 : DEALLOCATE (array)
2158 2328 : dims(idim) = dims(idim)/nbatches(idim)
2159 5432 : IF (dims(idim) <= 0) dims(idim) = 1
2160 : END DO
2161 : END ASSOCIATE
2162 : END IF
2163 :
2164 776 : pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
2165 776 : IF (ALLOCATED(tensor%contraction_storage)) THEN
2166 : #:for ndim in range(1, maxdim+1)
2167 1552 : IF (ndims_tensor(tensor) == ${ndim}$) THEN
2168 776 : CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
2169 : CALL dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
2170 776 : nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2171 : END IF
2172 : #:endfor
2173 : ELSE
2174 0 : CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2175 : END IF
2176 776 : CALL dbt_pgrid_destroy(pgrid)
2177 :
2178 776 : END SUBROUTINE
2179 :
2180 128466 : END MODULE
|