LCOV - code coverage report
Current view: top level - src/dbt - dbt_methods.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 845 895 94.4 %
Date: 2024-12-21 06:28:57 Functions: 22 23 95.7 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief DBT tensor framework for block-sparse tensor contraction.
      10             : !>        Representation of n-rank tensors as DBT tall-and-skinny matrices.
      11             : !>        Support for arbitrary redistribution between different representations.
      12             : !>        Support for arbitrary tensor contractions
      13             : !> \todo implement checks and error messages
      14             : !> \author Patrick Seewald
      15             : ! **************************************************************************************************
      16             : MODULE dbt_methods
      17             :    #:include "dbt_macros.fypp"
      18             :    #:set maxdim = maxrank
      19             :    #:set ndims = range(2,maxdim+1)
      20             : 
      21             :    USE cp_dbcsr_api, ONLY: &
      22             :       dbcsr_type, dbcsr_release, &
      23             :       dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
      24             :       dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
      25             :    USE dbt_allocate_wrap, ONLY: &
      26             :       allocate_any
      27             :    USE dbt_array_list_methods, ONLY: &
      28             :       get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, &
      29             :       create_array_list, destroy_array_list, sizes_of_arrays
      30             :    USE dbm_api, ONLY: &
      31             :       dbm_clear
      32             :    USE dbt_tas_types, ONLY: &
      33             :       dbt_tas_split_info
      34             :    USE dbt_tas_base, ONLY: &
      35             :       dbt_tas_copy, dbt_tas_finalize, dbt_tas_get_info, dbt_tas_info
      36             :    USE dbt_tas_mm, ONLY: &
      37             :       dbt_tas_multiply, dbt_tas_batched_mm_init, dbt_tas_batched_mm_finalize, &
      38             :       dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
      39             :    USE dbt_block, ONLY: &
      40             :       dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
      41             :       dbt_iterator_blocks_left, dbt_iterator_stop, dbt_iterator_next_block, &
      42             :       ndims_iterator, dbt_reserve_blocks, block_nd, destroy_block, checker_tr
      43             :    USE dbt_index, ONLY: &
      44             :       dbt_get_mapping_info, nd_to_2d_mapping, dbt_inverse_order, permute_index, get_nd_indices_tensor, &
      45             :       ndims_mapping_row, ndims_mapping_column, ndims_mapping
      46             :    USE dbt_types, ONLY: &
      47             :       dbt_create, dbt_type, ndims_tensor, dims_tensor, &
      48             :       dbt_distribution_type, dbt_distribution, dbt_nd_mp_comm, dbt_destroy, &
      49             :       dbt_distribution_destroy, dbt_distribution_new_expert, dbt_get_stored_coordinates, &
      50             :       blk_dims_tensor, dbt_hold, dbt_pgrid_type, mp_environ_pgrid, dbt_filter, &
      51             :       dbt_clear, dbt_finalize, dbt_get_num_blocks, dbt_scale, &
      52             :       dbt_get_num_blocks_total, dbt_get_info, ndims_matrix_row, ndims_matrix_column, &
      53             :       dbt_max_nblks_local, dbt_default_distvec, dbt_contraction_storage, dbt_nblks_total, &
      54             :       dbt_distribution_new, dbt_copy_contraction_storage, dbt_pgrid_destroy
      55             :    USE kinds, ONLY: &
      56             :       dp, default_string_length, int_8, dp
      57             :    USE message_passing, ONLY: &
      58             :       mp_cart_type
      59             :    USE util, ONLY: &
      60             :       sort
      61             :    USE dbt_reshape_ops, ONLY: &
      62             :       dbt_reshape
      63             :    USE dbt_tas_split, ONLY: &
      64             :       dbt_tas_mp_comm, rowsplit, colsplit, dbt_tas_info_hold, dbt_tas_release_info, default_nsplit_accept_ratio, &
      65             :       default_pdims_accept_ratio, dbt_tas_create_split
      66             :    USE dbt_split, ONLY: &
      67             :       dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
      68             :    USE dbt_io, ONLY: &
      69             :       dbt_write_tensor_info, dbt_write_tensor_dist, prep_output_unit, dbt_write_split_info
      70             :    USE message_passing, ONLY: mp_comm_type
      71             : 
      72             : #include "../base/base_uses.f90"
      73             : 
      74             :    IMPLICIT NONE
      75             :    PRIVATE
      76             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'
      77             : 
      78             :    PUBLIC :: &
      79             :       dbt_contract, &
      80             :       dbt_copy, &
      81             :       dbt_get_block, &
      82             :       dbt_get_stored_coordinates, &
      83             :       dbt_inverse_order, &
      84             :       dbt_iterator_blocks_left, &
      85             :       dbt_iterator_next_block, &
      86             :       dbt_iterator_start, &
      87             :       dbt_iterator_stop, &
      88             :       dbt_iterator_type, &
      89             :       dbt_put_block, &
      90             :       dbt_reserve_blocks, &
      91             :       dbt_copy_matrix_to_tensor, &
      92             :       dbt_copy_tensor_to_matrix, &
      93             :       dbt_batched_contract_init, &
      94             :       dbt_batched_contract_finalize
      95             : 
      96             : CONTAINS
      97             : 
      98             : ! **************************************************************************************************
      99             : !> \brief Copy tensor data.
     100             : !>        Redistributes tensor data according to distributions of target and source tensor.
     101             : !>        Permutes tensor index according to `order` argument (if present).
     102             : !>        Source and target tensor formats are arbitrary as long as the following requirements are met:
     103             : !>        * source and target tensors have the same rank and the same sizes in each dimension in terms
     104             : !>          of tensor elements (block sizes don't need to be the same).
     105             : !>          If `order` argument is present, sizes must match after index permutation.
     106             : !>        OR
     107             : !>        * target tensor is not yet created, in this case an exact copy of source tensor is returned.
     108             : !> \param tensor_in Source
     109             : !> \param tensor_out Target
     110             : !> \param order Permutation of target tensor index.
     111             : !>              Exact same convention as order argument of RESHAPE intrinsic.
     112             : !> \param bounds crop tensor data: start and end index for each tensor dimension
     113             : !> \author Patrick Seewald
     114             : ! **************************************************************************************************
     115      776824 :    SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
     116             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
     117             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
     118             :          INTENT(IN), OPTIONAL                        :: order
     119             :       LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
     120             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
     121             :          INTENT(IN), OPTIONAL                        :: bounds
     122             :       INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
     123             :       INTEGER :: handle
     124             : 
     125      388412 :       CALL tensor_in%pgrid%mp_comm_2d%sync()
     126      388412 :       CALL timeset("dbt_total", handle)
     127             : 
     128             :       ! make sure that it is safe to use dbt_copy during a batched contraction
     129      388412 :       CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
     130      388412 :       CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
     131             : 
     132      388412 :       CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
     133      388412 :       CALL tensor_in%pgrid%mp_comm_2d%sync()
     134      388412 :       CALL timestop(handle)
     135      498986 :    END SUBROUTINE
     136             : 
     137             : ! **************************************************************************************************
     138             : !> \brief expert routine for copying a tensor. For internal use only.
     139             : !> \author Patrick Seewald
     140             : ! **************************************************************************************************
     141      421725 :    SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
     142             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
     143             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
     144             :          INTENT(IN), OPTIONAL                        :: order
     145             :       LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
     146             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
     147             :          INTENT(IN), OPTIONAL                        :: bounds
     148             :       INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
     149             : 
     150             :       TYPE(dbt_type), POINTER                    :: in_tmp_1, in_tmp_2, &
     151             :                                                     in_tmp_3, out_tmp_1
     152             :       INTEGER                                        :: handle, unit_nr_prv
     153      421725 :       INTEGER, DIMENSION(:), ALLOCATABLE             :: map1_in_1, map1_in_2, map2_in_1, map2_in_2
     154             : 
     155             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
     156             :       LOGICAL                                        :: dist_compatible_tas, dist_compatible_tensor, &
     157             :                                                         summation_prv, new_in_1, new_in_2, &
     158             :                                                         new_in_3, new_out_1, block_compatible, &
     159             :                                                         move_prv
     160      421725 :       TYPE(array_list)                               :: blk_sizes_in
     161             : 
     162      421725 :       CALL timeset(routineN, handle)
     163             : 
     164      421725 :       CPASSERT(tensor_out%valid)
     165             : 
     166      421725 :       unit_nr_prv = prep_output_unit(unit_nr)
     167             : 
     168      421725 :       IF (PRESENT(move_data)) THEN
     169      312191 :          move_prv = move_data
     170             :       ELSE
     171      109534 :          move_prv = .FALSE.
     172             :       END IF
     173             : 
     174      421725 :       dist_compatible_tas = .FALSE.
     175      421725 :       dist_compatible_tensor = .FALSE.
     176      421725 :       block_compatible = .FALSE.
     177      421725 :       new_in_1 = .FALSE.
     178      421725 :       new_in_2 = .FALSE.
     179      421725 :       new_in_3 = .FALSE.
     180      421725 :       new_out_1 = .FALSE.
     181             : 
     182      421725 :       IF (PRESENT(summation)) THEN
     183      111235 :          summation_prv = summation
     184             :       ELSE
     185             :          summation_prv = .FALSE.
     186             :       END IF
     187             : 
     188      421725 :       IF (PRESENT(bounds)) THEN
     189       39424 :          ALLOCATE (in_tmp_1)
     190        5632 :          CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
     191        5632 :          new_in_1 = .TRUE.
     192        5632 :          move_prv = .TRUE.
     193             :       ELSE
     194             :          in_tmp_1 => tensor_in
     195             :       END IF
     196             : 
     197      421725 :       IF (PRESENT(order)) THEN
     198      110574 :          CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
     199      110574 :          block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
     200             :       ELSE
     201      311151 :          block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
     202             :       END IF
     203             : 
     204      421725 :       IF (.NOT. block_compatible) THEN
     205      923273 :          ALLOCATE (in_tmp_2, out_tmp_1)
     206             :          CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
     207       71021 :                                          nodata2=.NOT. summation_prv, move_data=move_prv)
     208       71021 :          new_in_2 = .TRUE.; new_out_1 = .TRUE.
     209       71021 :          move_prv = .TRUE.
     210             :       ELSE
     211             :          in_tmp_2 => in_tmp_1
     212             :          out_tmp_1 => tensor_out
     213             :       END IF
     214             : 
     215      421725 :       IF (PRESENT(order)) THEN
     216      774018 :          ALLOCATE (in_tmp_3)
     217      110574 :          CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
     218      110574 :          new_in_3 = .TRUE.
     219             :       ELSE
     220             :          in_tmp_3 => in_tmp_2
     221             :       END IF
     222             : 
     223     1265175 :       ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
     224     1265175 :       ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
     225      421725 :       CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)
     226             : 
     227     1265175 :       ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
     228     1265175 :       ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
     229      421725 :       CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)
     230             : 
     231      421725 :       IF (.NOT. PRESENT(order)) THEN
     232      311151 :          IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
     233      263410 :             dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
     234      620585 :          ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
     235       22490 :             dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
     236             :          END IF
     237             :       END IF
     238             : 
     239      263410 :       IF (dist_compatible_tas) THEN
     240      215954 :          CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
     241      215954 :          IF (move_prv) CALL dbt_clear(in_tmp_3)
     242      205771 :       ELSEIF (dist_compatible_tensor) THEN
     243       14878 :          CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
     244       14878 :          IF (move_prv) CALL dbt_clear(in_tmp_3)
     245             :       ELSE
     246      190893 :          CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
     247             :       END IF
     248             : 
     249      421725 :       IF (new_in_1) THEN
     250        5632 :          CALL dbt_destroy(in_tmp_1)
     251        5632 :          DEALLOCATE (in_tmp_1)
     252             :       END IF
     253             : 
     254      421725 :       IF (new_in_2) THEN
     255       71021 :          CALL dbt_destroy(in_tmp_2)
     256       71021 :          DEALLOCATE (in_tmp_2)
     257             :       END IF
     258             : 
     259      421725 :       IF (new_in_3) THEN
     260      110574 :          CALL dbt_destroy(in_tmp_3)
     261      110574 :          DEALLOCATE (in_tmp_3)
     262             :       END IF
     263             : 
     264      421725 :       IF (new_out_1) THEN
     265       71021 :          IF (unit_nr_prv /= 0) THEN
     266           0 :             CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
     267             :          END IF
     268       71021 :          CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
     269       71021 :          CALL dbt_destroy(out_tmp_1)
     270       71021 :          DEALLOCATE (out_tmp_1)
     271             :       END IF
     272             : 
     273      421725 :       CALL timestop(handle)
     274             : 
     275      843450 :    END SUBROUTINE
     276             : 
     277             : ! **************************************************************************************************
     278             : !> \brief copy without communication, requires that both tensors have same process grid and distribution
     279             : !> \param summation Whether to sum matrices b = a + b
     280             : !> \author Patrick Seewald
     281             : ! **************************************************************************************************
     282       14878 :    SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
     283             :       TYPE(dbt_type), INTENT(INOUT) :: tensor_in
     284             :       TYPE(dbt_type), INTENT(INOUT) :: tensor_out
     285             :       LOGICAL, INTENT(IN), OPTIONAL                      :: summation
     286             :       TYPE(dbt_iterator_type) :: iter
     287       14878 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))  :: ind_nd
     288       14878 :       TYPE(block_nd) :: blk_data
     289             :       LOGICAL :: found
     290             : 
     291             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
     292             :       INTEGER :: handle
     293             : 
     294       14878 :       CALL timeset(routineN, handle)
     295       14878 :       CPASSERT(tensor_out%valid)
     296             : 
     297       14878 :       IF (PRESENT(summation)) THEN
     298        7356 :          IF (.NOT. summation) CALL dbt_clear(tensor_out)
     299             :       ELSE
     300        7522 :          CALL dbt_clear(tensor_out)
     301             :       END IF
     302             : 
     303       14878 :       CALL dbt_reserve_blocks(tensor_in, tensor_out)
     304             : 
     305             : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
     306       14878 : !$OMP PRIVATE(iter,ind_nd,blk_data,found)
     307             :       CALL dbt_iterator_start(iter, tensor_in)
     308             :       DO WHILE (dbt_iterator_blocks_left(iter))
     309             :          CALL dbt_iterator_next_block(iter, ind_nd)
     310             :          CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
     311             :          CPASSERT(found)
     312             :          CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
     313             :          CALL destroy_block(blk_data)
     314             :       END DO
     315             :       CALL dbt_iterator_stop(iter)
     316             : !$OMP END PARALLEL
     317             : 
     318       14878 :       CALL timestop(handle)
     319       29756 :    END SUBROUTINE
     320             : 
     321             : ! **************************************************************************************************
     322             : !> \brief copy matrix to tensor.
     323             : !> \param summation tensor_out = tensor_out + matrix_in
     324             : !> \author Patrick Seewald
     325             : ! **************************************************************************************************
     326       64638 :    SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
     327             :       TYPE(dbcsr_type), TARGET, INTENT(IN)               :: matrix_in
     328             :       TYPE(dbt_type), INTENT(INOUT)             :: tensor_out
     329             :       LOGICAL, INTENT(IN), OPTIONAL                      :: summation
     330             :       TYPE(dbcsr_type), POINTER                          :: matrix_in_desym
     331             : 
     332             :       INTEGER, DIMENSION(2)                              :: ind_2d
     333       64638 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)    :: block_arr
     334       64638 :       REAL(KIND=dp), DIMENSION(:, :), POINTER        :: block
     335             :       TYPE(dbcsr_iterator_type)                          :: iter
     336             : 
     337             :       INTEGER                                            :: handle
     338             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'
     339             : 
     340       64638 :       CALL timeset(routineN, handle)
     341       64638 :       CPASSERT(tensor_out%valid)
     342             : 
     343       64638 :       NULLIFY (block)
     344             : 
     345       64638 :       IF (dbcsr_has_symmetry(matrix_in)) THEN
     346        4522 :          ALLOCATE (matrix_in_desym)
     347        4522 :          CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
     348             :       ELSE
     349             :          matrix_in_desym => matrix_in
     350             :       END IF
     351             : 
     352       64638 :       IF (PRESENT(summation)) THEN
     353           0 :          IF (.NOT. summation) CALL dbt_clear(tensor_out)
     354             :       ELSE
     355       64638 :          CALL dbt_clear(tensor_out)
     356             :       END IF
     357             : 
     358       64638 :       CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)
     359             : 
     360             : !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
     361       64638 : !$OMP PRIVATE(iter,ind_2d,block,block_arr)
     362             :       CALL dbcsr_iterator_start(iter, matrix_in_desym)
     363             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     364             :          CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block)
     365             :          CALL allocate_any(block_arr, source=block)
     366             :          CALL dbt_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
     367             :          DEALLOCATE (block_arr)
     368             :       END DO
     369             :       CALL dbcsr_iterator_stop(iter)
     370             : !$OMP END PARALLEL
     371             : 
     372       64638 :       IF (dbcsr_has_symmetry(matrix_in)) THEN
     373        4522 :          CALL dbcsr_release(matrix_in_desym)
     374        4522 :          DEALLOCATE (matrix_in_desym)
     375             :       END IF
     376             : 
     377       64638 :       CALL timestop(handle)
     378             : 
     379      129276 :    END SUBROUTINE
     380             : 
     381             : ! **************************************************************************************************
     382             : !> \brief copy tensor to matrix
     383             : !> \param summation matrix_out = matrix_out + tensor_in
     384             : !> \author Patrick Seewald
     385             : ! **************************************************************************************************
     386       42180 :    SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
     387             :       TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
     388             :       TYPE(dbcsr_type), INTENT(INOUT)             :: matrix_out
     389             :       LOGICAL, INTENT(IN), OPTIONAL          :: summation
     390             :       TYPE(dbt_iterator_type)            :: iter
     391             :       INTEGER                                :: handle
     392             :       INTEGER, DIMENSION(2)                  :: ind_2d
     393       42180 :       REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE :: block
     394             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
     395             :       LOGICAL :: found
     396             : 
     397       42180 :       CALL timeset(routineN, handle)
     398             : 
     399       42180 :       IF (PRESENT(summation)) THEN
     400        5812 :          IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
     401             :       ELSE
     402       36368 :          CALL dbcsr_clear(matrix_out)
     403             :       END IF
     404             : 
     405       42180 :       CALL dbt_reserve_blocks(tensor_in, matrix_out)
     406             : 
     407             : !$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
     408       42180 : !$OMP PRIVATE(iter,ind_2d,block,found)
     409             :       CALL dbt_iterator_start(iter, tensor_in)
     410             :       DO WHILE (dbt_iterator_blocks_left(iter))
     411             :          CALL dbt_iterator_next_block(iter, ind_2d)
     412             :          IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE
     413             : 
     414             :          CALL dbt_get_block(tensor_in, ind_2d, block, found)
     415             :          CPASSERT(found)
     416             : 
     417             :          IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
     418             :             CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
     419             :          ELSE
     420             :             CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
     421             :          END IF
     422             :          DEALLOCATE (block)
     423             :       END DO
     424             :       CALL dbt_iterator_stop(iter)
     425             : !$OMP END PARALLEL
     426             : 
     427       42180 :       CALL timestop(handle)
     428             : 
     429       84360 :    END SUBROUTINE
     430             : 
     431             : ! **************************************************************************************************
     432             : !> \brief Contract tensors by multiplying matrix representations.
     433             : !>        tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
     434             : !>        * tensor_2(contract_2, notcontract_2)
     435             : !>        + beta * tensor_3(map_1, map_2)
     436             : !>
     437             : !> \note
     438             : !>      note 1: block sizes of the corresponding indices need to be the same in all tensors.
     439             : !>
     440             : !>      note 2: for best performance the tensors should have been created in matrix layouts
     441             : !>      compatible with the contraction, e.g. tensor_1 should have been created with either
     442             : !>      map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
     443             : !>      map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
     444             : !>      tensor_3 and map_1 / map_2).
     445             : !>      Furthermore the two largest tensors involved in the contraction should map both to either
     446             : !>      tall or short matrices: the largest matrix dimension should be "on the same side"
     447             : !>      and should have identical distribution (which is always the case if the distributions were
     448             : !>      obtained with dbt_default_distvec).
     449             : !>
     450             : !>      note 3: if the same tensor occurs in multiple contractions, a different tensor object should
     451             : !>      be created for each contraction and the data should be copied between the tensors by use of
     452             : !>      dbt_copy. If the same tensor object is used in multiple contractions,
     453             : !>       matrix layouts are not compatible for all contractions (see note 2).
     454             : !>
     455             : !>      note 4: automatic optimizations are enabled by using the feature of batched contraction, see
     456             : !>      dbt_batched_contract_init, dbt_batched_contract_finalize.
     457             : !>      The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
     458             : !>
     459             : !> \param tensor_1 first tensor (in)
     460             : !> \param tensor_2 second tensor (in)
     461             : !> \param contract_1 indices of tensor_1 to contract
     462             : !> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
     463             : !> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
     464             : !> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
     465             : !> \param notcontract_1 indices of tensor_1 not to contract
     466             : !> \param notcontract_2 indices of tensor_2 not to contract
     467             : !> \param tensor_3 contracted tensor (out)
     468             : !> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
     469             : !>                 start and end index of an index range over which to contract.
     470             : !>                 For use in batched contraction.
     471             : !> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
     472             : !>                 For use in batched contraction.
     473             : !> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
     474             : !>                 For use in batched contraction.
     475             : !> \param optimize_dist Whether distribution should be optimized internally. In the current
     476             : !>                      implementation this guarantees optimal parameters only for dense matrices.
     477             : !> \param pgrid_opt_1 Optionally return optimal process grid for tensor_1.
     478             : !>                    This can be used to choose optimal process grids for subsequent tensor
     479             : !>                    contractions with tensors of similar shape and sparsity. Under some conditions,
     480             : !>                    pgrid_opt_1 can not be returned, in this case the pointer is not associated.
     481             : !> \param pgrid_opt_2 Optionally return optimal process grid for tensor_2.
     482             : !> \param pgrid_opt_3 Optionally return optimal process grid for tensor_3.
     483             : !> \param filter_eps As in DBM mm
     484             : !> \param flop As in DBM mm
     485             : !> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
     486             : !> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
     487             : !> \param unit_nr output unit for logging
     488             : !>                       set it to -1 on ranks that should not write (and any valid unit number on
     489             : !>                       ranks that should write output) if 0 on ALL ranks, no output is written
     490             : !> \param log_verbose verbose logging (for testing only)
     491             : !> \author Patrick Seewald
     492             : ! **************************************************************************************************
     493      338988 :    SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
     494      169494 :                            contract_1, notcontract_1, &
     495      169494 :                            contract_2, notcontract_2, &
     496      169494 :                            map_1, map_2, &
     497      115240 :                            bounds_1, bounds_2, bounds_3, &
     498             :                            optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
     499             :                            filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
     500             :       REAL(dp), INTENT(IN)            :: alpha
     501             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_1
     502             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_2
     503             :       REAL(dp), INTENT(IN)            :: beta
     504             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
     505             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
     506             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
     507             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
     508             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
     509             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
     510             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_3
     511             :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
     512             :          INTENT(IN), OPTIONAL                        :: bounds_1
     513             :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
     514             :          INTENT(IN), OPTIONAL                        :: bounds_2
     515             :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
     516             :          INTENT(IN), OPTIONAL                        :: bounds_3
     517             :       LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
     518             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     519             :          POINTER, OPTIONAL                           :: pgrid_opt_1
     520             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     521             :          POINTER, OPTIONAL                           :: pgrid_opt_2
     522             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     523             :          POINTER, OPTIONAL                           :: pgrid_opt_3
     524             :       REAL(KIND=dp), INTENT(IN), OPTIONAL        :: filter_eps
     525             :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
     526             :       LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
     527             :       LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
     528             :       INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
     529             :       LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
     530             : 
     531             :       INTEGER                     :: handle
     532             : 
     533      169494 :       CALL tensor_1%pgrid%mp_comm_2d%sync()
     534      169494 :       CALL timeset("dbt_total", handle)
     535             :       CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
     536             :                                contract_1, notcontract_1, &
     537             :                                contract_2, notcontract_2, &
     538             :                                map_1, map_2, &
     539             :                                bounds_1=bounds_1, &
     540             :                                bounds_2=bounds_2, &
     541             :                                bounds_3=bounds_3, &
     542             :                                optimize_dist=optimize_dist, &
     543             :                                pgrid_opt_1=pgrid_opt_1, &
     544             :                                pgrid_opt_2=pgrid_opt_2, &
     545             :                                pgrid_opt_3=pgrid_opt_3, &
     546             :                                filter_eps=filter_eps, &
     547             :                                flop=flop, &
     548             :                                move_data=move_data, &
     549             :                                retain_sparsity=retain_sparsity, &
     550             :                                unit_nr=unit_nr, &
     551      169494 :                                log_verbose=log_verbose)
     552      169494 :       CALL tensor_1%pgrid%mp_comm_2d%sync()
     553      169494 :       CALL timestop(handle)
     554             : 
     555      239968 :    END SUBROUTINE
     556             : 
     557             : ! **************************************************************************************************
     558             : !> \brief expert routine for tensor contraction. For internal use only.
     559             : !> \param nblks_local number of local blocks on this MPI rank
     560             : !> \author Patrick Seewald
     561             : ! **************************************************************************************************
     562      169494 :    SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
     563      169494 :                                   contract_1, notcontract_1, &
     564      169494 :                                   contract_2, notcontract_2, &
     565      169494 :                                   map_1, map_2, &
     566      169494 :                                   bounds_1, bounds_2, bounds_3, &
     567             :                                   optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
     568             :                                   filter_eps, flop, move_data, retain_sparsity, &
     569             :                                   nblks_local, unit_nr, log_verbose)
     570             :       REAL(dp), INTENT(IN)            :: alpha
     571             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_1
     572             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_2
     573             :       REAL(dp), INTENT(IN)            :: beta
     574             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
     575             :       INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
     576             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
     577             :       INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
     578             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
     579             :       INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
     580             :       TYPE(dbt_type), INTENT(INOUT), TARGET      :: tensor_3
     581             :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
     582             :          INTENT(IN), OPTIONAL                        :: bounds_1
     583             :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
     584             :          INTENT(IN), OPTIONAL                        :: bounds_2
     585             :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
     586             :          INTENT(IN), OPTIONAL                        :: bounds_3
     587             :       LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
     588             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     589             :          POINTER, OPTIONAL                           :: pgrid_opt_1
     590             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     591             :          POINTER, OPTIONAL                           :: pgrid_opt_2
     592             :       TYPE(dbt_pgrid_type), INTENT(OUT), &
     593             :          POINTER, OPTIONAL                           :: pgrid_opt_3
     594             :       REAL(KIND=dp), INTENT(IN), OPTIONAL        :: filter_eps
     595             :       INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
     596             :       LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
     597             :       LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
     598             :       INTEGER, INTENT(OUT), OPTIONAL                 :: nblks_local
     599             :       INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
     600             :       LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
     601             : 
     602             :       TYPE(dbt_type), POINTER                    :: tensor_contr_1, tensor_contr_2, tensor_contr_3
     603     3220386 :       TYPE(dbt_type), TARGET                     :: tensor_algn_1, tensor_algn_2, tensor_algn_3
     604             :       TYPE(dbt_type), POINTER                    :: tensor_crop_1, tensor_crop_2
     605             :       TYPE(dbt_type), POINTER                    :: tensor_small, tensor_large
     606             : 
     607             :       LOGICAL                                        :: assert_stmt, tensors_remapped
     608             :       INTEGER                                        :: max_mm_dim, max_tensor, &
     609             :                                                         unit_nr_prv, ref_tensor, handle
     610      169494 :       TYPE(mp_cart_type) :: mp_comm_opt
     611      338988 :       INTEGER, DIMENSION(SIZE(contract_1))           :: contract_1_mod
     612      338988 :       INTEGER, DIMENSION(SIZE(notcontract_1))        :: notcontract_1_mod
     613      338988 :       INTEGER, DIMENSION(SIZE(contract_2))           :: contract_2_mod
     614      338988 :       INTEGER, DIMENSION(SIZE(notcontract_2))        :: notcontract_2_mod
     615      338988 :       INTEGER, DIMENSION(SIZE(map_1))                :: map_1_mod
     616      338988 :       INTEGER, DIMENSION(SIZE(map_2))                :: map_2_mod
     617             :       LOGICAL                                        :: trans_1, trans_2, trans_3
     618             :       LOGICAL                                        :: new_1, new_2, new_3, move_data_1, move_data_2
     619             :       INTEGER                                        :: ndims1, ndims2, ndims3
     620             :       INTEGER                                        :: occ_1, occ_2
     621      169494 :       INTEGER, DIMENSION(:), ALLOCATABLE             :: dims1, dims2, dims3
     622             : 
     623             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
     624      169494 :       CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE    :: indchar1, indchar2, indchar3, indchar1_mod, &
     625      169494 :                                                         indchar2_mod, indchar3_mod
     626             :       CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
     627             :                                                ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
     628      338988 :       INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
     629      338988 :       INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
     630             :       LOGICAL                                        :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
     631             :                                                         pgrid_changed_any, do_change_pgrid(2)
     632     2203422 :       TYPE(dbt_tas_split_info)                     :: split_opt, split, split_opt_avg
     633             :       INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
     634             :       REAL(dp) :: pdim_ratio, pdim_ratio_opt
     635             : 
     636      169494 :       NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
     637      169494 :                tensor_small)
     638             : 
     639      169494 :       CALL timeset(routineN, handle)
     640             : 
     641      169494 :       CPASSERT(tensor_1%valid)
     642      169494 :       CPASSERT(tensor_2%valid)
     643      169494 :       CPASSERT(tensor_3%valid)
     644             : 
     645      169494 :       assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
     646      169494 :       CPASSERT(assert_stmt)
     647             : 
     648      169494 :       assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
     649      169494 :       CPASSERT(assert_stmt)
     650             : 
     651      169494 :       assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
     652      169494 :       CPASSERT(assert_stmt)
     653             : 
     654      169494 :       assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
     655      169494 :       CPASSERT(assert_stmt)
     656             : 
     657      169494 :       assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
     658      169494 :       CPASSERT(assert_stmt)
     659             : 
     660      169494 :       assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
     661      169494 :       CPASSERT(assert_stmt)
     662             : 
     663      169494 :       unit_nr_prv = prep_output_unit(unit_nr)
     664             : 
     665      169494 :       IF (PRESENT(flop)) flop = 0
     666      169494 :       IF (PRESENT(nblks_local)) nblks_local = 0
     667             : 
     668      169494 :       IF (PRESENT(move_data)) THEN
     669       34847 :          move_data_1 = move_data
     670       34847 :          move_data_2 = move_data
     671             :       ELSE
     672      134647 :          move_data_1 = .FALSE.
     673      134647 :          move_data_2 = .FALSE.
     674             :       END IF
     675             : 
     676      169494 :       nodata_3 = .TRUE.
     677      169494 :       IF (PRESENT(retain_sparsity)) THEN
     678        4762 :          IF (retain_sparsity) nodata_3 = .FALSE.
     679             :       END IF
     680             : 
     681             :       CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
     682             :                                      contract_1, notcontract_1, &
     683             :                                      contract_2, notcontract_2, &
     684             :                                      bounds_t1, bounds_t2, &
     685             :                                      bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
     686      169494 :                                      do_crop_1=do_crop_1, do_crop_2=do_crop_2)
     687             : 
     688      169494 :       IF (do_crop_1) THEN
     689      476574 :          ALLOCATE (tensor_crop_1)
     690       68082 :          CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
     691       68082 :          move_data_1 = .TRUE.
     692             :       ELSE
     693             :          tensor_crop_1 => tensor_1
     694             :       END IF
     695             : 
     696      169494 :       IF (do_crop_2) THEN
     697      461230 :          ALLOCATE (tensor_crop_2)
     698       65890 :          CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
     699       65890 :          move_data_2 = .TRUE.
     700             :       ELSE
     701             :          tensor_crop_2 => tensor_2
     702             :       END IF
     703             : 
     704             :       ! shortcut for empty tensors
     705             :       ! this is needed to avoid unnecessary work in case user contracts different portions of a
     706             :       ! tensor consecutively to save memory
     707             :       ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
     708      169494 :          occ_1 = dbt_get_num_blocks(tensor_crop_1)
     709      169494 :          CALL mp_comm%max(occ_1)
     710      169494 :          occ_2 = dbt_get_num_blocks(tensor_crop_2)
     711      169494 :          CALL mp_comm%max(occ_2)
     712             :       END ASSOCIATE
     713             : 
     714      169494 :       IF (occ_1 == 0 .OR. occ_2 == 0) THEN
     715       27984 :          CALL dbt_scale(tensor_3, beta)
     716       27984 :          IF (do_crop_1) THEN
     717        2746 :             CALL dbt_destroy(tensor_crop_1)
     718        2746 :             DEALLOCATE (tensor_crop_1)
     719             :          END IF
     720       27984 :          IF (do_crop_2) THEN
     721        2760 :             CALL dbt_destroy(tensor_crop_2)
     722        2760 :             DEALLOCATE (tensor_crop_2)
     723             :          END IF
     724             : 
     725       27984 :          CALL timestop(handle)
     726       27984 :          RETURN
     727             :       END IF
     728             : 
     729      141510 :       IF (unit_nr_prv /= 0) THEN
     730       45682 :          IF (unit_nr_prv > 0) THEN
     731          10 :             WRITE (unit_nr_prv, '(A)') repeat("-", 80)
     732          10 :             WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
     733          20 :                TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name)
     734          10 :             WRITE (unit_nr_prv, '(A)') repeat("-", 80)
     735             :          END IF
     736       45682 :          CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
     737       45682 :          CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
     738       45682 :          CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
     739       45682 :          CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
     740             :       END IF
     741             : 
     742             :       ! align tensor index with data, tensor data is not modified
     743      141510 :       ndims1 = ndims_tensor(tensor_crop_1)
     744      141510 :       ndims2 = ndims_tensor(tensor_crop_2)
     745      141510 :       ndims3 = ndims_tensor(tensor_3)
     746      566040 :       ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
     747      566040 :       ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
     748      566040 :       ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
     749             : 
     750             :       ! labeling tensor index with letters
     751             : 
     752     1217112 :       indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
     753      340892 :       indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
     754      330220 :       indchar2(contract_2) = indchar1(contract_1)
     755      311334 :       indchar3(map_1) = indchar1(notcontract_1)
     756      340892 :       indchar3(map_2) = indchar2(notcontract_2)
     757             : 
     758      141510 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
     759             :                                                              tensor_crop_2, indchar2, &
     760       45682 :                                                              tensor_3, indchar3, unit_nr_prv)
     761      141510 :       IF (unit_nr_prv > 0) THEN
     762          10 :          WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
     763             :       END IF
     764             : 
     765             :       CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
     766      141510 :                         tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
     767             : 
     768             :       CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
     769      141510 :                         tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
     770             : 
     771             :       CALL align_tensor(tensor_3, map_1, map_2, &
     772      141510 :                         tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
     773             : 
     774      141510 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
     775             :                                                              tensor_algn_2, indchar2_mod, &
     776       45682 :                                                              tensor_algn_3, indchar3_mod, unit_nr_prv)
     777             : 
     778      424530 :       ALLOCATE (dims1(ndims1))
     779      424530 :       ALLOCATE (dims2(ndims2))
     780      424530 :       ALLOCATE (dims3(ndims3))
     781             : 
     782             :       ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most
     783             :       ! cases and is more elegant. Note that we can not easily consider occupancy since it is unknown for result tensor
     784      141510 :       CALL blk_dims_tensor(tensor_crop_1, dims1)
     785      141510 :       CALL blk_dims_tensor(tensor_crop_2, dims2)
     786      141510 :       CALL blk_dims_tensor(tensor_3, dims3)
     787             : 
     788             :       max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), &
     789             :                            PRODUCT(INT(dims1(contract_1), int_8)), &
     790     1123956 :                            PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1)
     791     1681872 :       max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1)
     792       36508 :       SELECT CASE (max_mm_dim)
     793             :       CASE (1)
     794       36508 :          IF (unit_nr_prv > 0) THEN
     795           3 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
     796           3 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     797             :          END IF
     798       36508 :          CALL index_linked_sort(contract_1_mod, contract_2_mod)
     799       36508 :          CALL index_linked_sort(map_2_mod, notcontract_2_mod)
     800       36080 :          SELECT CASE (max_tensor)
     801             :          CASE (1)
     802       36080 :             CALL index_linked_sort(notcontract_1_mod, map_1_mod)
     803             :          CASE (3)
     804         428 :             CALL index_linked_sort(map_1_mod, notcontract_1_mod)
     805             :          CASE DEFAULT
     806       36508 :             CPABORT("should not happen")
     807             :          END SELECT
     808             : 
     809             :          CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
     810             :                                     contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
     811             :                                     trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
     812       36508 :                                     move_data_1=move_data_1, unit_nr=unit_nr_prv)
     813             : 
     814             :          CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
     815       36508 :                                new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
     816             : 
     817       36080 :          SELECT CASE (ref_tensor)
     818             :          CASE (1)
     819       36080 :             tensor_large => tensor_contr_1
     820             :          CASE (2)
     821       36508 :             tensor_large => tensor_contr_3
     822             :          END SELECT
     823       36508 :          tensor_small => tensor_contr_2
     824             : 
     825             :       CASE (2)
     826       47672 :          IF (unit_nr_prv > 0) THEN
     827           5 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
     828           5 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     829             :          END IF
     830             : 
     831       47672 :          CALL index_linked_sort(notcontract_1_mod, map_1_mod)
     832       47672 :          CALL index_linked_sort(notcontract_2_mod, map_2_mod)
     833       46852 :          SELECT CASE (max_tensor)
     834             :          CASE (1)
     835       46852 :             CALL index_linked_sort(contract_1_mod, contract_2_mod)
     836             :          CASE (2)
     837         820 :             CALL index_linked_sort(contract_2_mod, contract_1_mod)
     838             :          CASE DEFAULT
     839       47672 :             CPABORT("should not happen")
     840             :          END SELECT
     841             : 
     842             :          CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
     843             :                                     notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
     844             :                                     trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
     845       47672 :                                     move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
     846       47672 :          trans_1 = .NOT. trans_1
     847             : 
     848             :          CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
     849       47672 :                                new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
     850             : 
     851       46852 :          SELECT CASE (ref_tensor)
     852             :          CASE (1)
     853       46852 :             tensor_large => tensor_contr_1
     854             :          CASE (2)
     855       47672 :             tensor_large => tensor_contr_2
     856             :          END SELECT
     857       47672 :          tensor_small => tensor_contr_3
     858             : 
     859             :       CASE (3)
     860       57330 :          IF (unit_nr_prv > 0) THEN
     861           2 :             WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
     862           2 :             WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
     863             :          END IF
     864       57330 :          CALL index_linked_sort(map_1_mod, notcontract_1_mod)
     865       57330 :          CALL index_linked_sort(contract_2_mod, contract_1_mod)
     866       56948 :          SELECT CASE (max_tensor)
     867             :          CASE (2)
     868       56948 :             CALL index_linked_sort(notcontract_2_mod, map_2_mod)
     869             :          CASE (3)
     870         382 :             CALL index_linked_sort(map_2_mod, notcontract_2_mod)
     871             :          CASE DEFAULT
     872       57330 :             CPABORT("should not happen")
     873             :          END SELECT
     874             : 
     875             :          CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
     876             :                                     contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
     877             :                                     trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
     878       57330 :                                     move_data_1=move_data_2, unit_nr=unit_nr_prv)
     879             : 
     880       57330 :          trans_2 = .NOT. trans_2
     881       57330 :          trans_3 = .NOT. trans_3
     882             : 
     883             :          CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
     884       57330 :                                trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
     885             : 
     886       56948 :          SELECT CASE (ref_tensor)
     887             :          CASE (1)
     888       56948 :             tensor_large => tensor_contr_2
     889             :          CASE (2)
     890       57330 :             tensor_large => tensor_contr_3
     891             :          END SELECT
     892      198840 :          tensor_small => tensor_contr_1
     893             : 
     894             :       END SELECT
     895             : 
     896      141510 :       IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
     897             :                                                              tensor_contr_2, indchar2_mod, &
     898       45682 :                                                              tensor_contr_3, indchar3_mod, unit_nr_prv)
     899      141510 :       IF (unit_nr_prv /= 0) THEN
     900       45682 :          IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
     901       45682 :          IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
     902       45682 :          IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
     903       45682 :          IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
     904             :       END IF
     905             : 
     906             :       CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
     907             :                             tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
     908             :                             beta, &
     909             :                             tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
     910             :                             unit_nr=unit_nr_prv, log_verbose=log_verbose, &
     911             :                             split_opt=split_opt, &
     912      141510 :                             move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
     913             : 
     914      141510 :       IF (PRESENT(pgrid_opt_1)) THEN
     915           0 :          IF (.NOT. new_1) THEN
     916           0 :             ALLOCATE (pgrid_opt_1)
     917           0 :             pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
     918             :          END IF
     919             :       END IF
     920             : 
     921      141510 :       IF (PRESENT(pgrid_opt_2)) THEN
     922           0 :          IF (.NOT. new_2) THEN
     923           0 :             ALLOCATE (pgrid_opt_2)
     924           0 :             pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
     925             :          END IF
     926             :       END IF
     927             : 
     928      141510 :       IF (PRESENT(pgrid_opt_3)) THEN
     929           0 :          IF (.NOT. new_3) THEN
     930           0 :             ALLOCATE (pgrid_opt_3)
     931           0 :             pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
     932             :          END IF
     933             :       END IF
     934             : 
     935      141510 :       do_batched = tensor_small%matrix_rep%do_batched > 0
     936             : 
     937      141510 :       tensors_remapped = .FALSE.
     938      141510 :       IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE.
     939             : 
     940      141510 :       IF (tensors_remapped .AND. do_batched) THEN
     941             :          CALL cp_warn(__LOCATION__, &
     942           0 :                       "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
     943             :       END IF
     944             : 
     945             :       ! optimize process grid during batched contraction
     946      141510 :       do_change_pgrid(:) = .FALSE.
     947      141510 :       IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
     948             :          ASSOCIATE (storage => tensor_small%contraction_storage)
     949           0 :             CPASSERT(storage%static)
     950       79027 :             split = dbt_tas_info(tensor_large%matrix_rep)
     951             :             do_change_pgrid(:) = &
     952       79027 :                update_contraction_storage(storage, split_opt, split)
     953             : 
     954      314176 :             IF (ANY(do_change_pgrid)) THEN
     955         966 :                mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
     956             :                CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
     957         966 :                                          NINT(storage%nsplit_avg), own_comm=.TRUE.)
     958        2898 :                pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
     959             :             END IF
     960             : 
     961             :          END ASSOCIATE
     962             : 
     963       79027 :          IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
     964             :             ! check if new grid has better subgrid, if not there is no need to change process grid
     965        2898 :             pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
     966        2898 :             pdims_sub = split%mp_comm_group%num_pe_cart
     967             : 
     968        4830 :             pdim_ratio = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
     969        4830 :             pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, dp))/MINVAL(pdims_sub_opt)
     970         966 :             IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
     971           0 :                do_change_pgrid(1) = .FALSE.
     972           0 :                CALL dbt_tas_release_info(split_opt_avg)
     973             :             END IF
     974             :          END IF
     975             :       END IF
     976             : 
     977      141510 :       IF (unit_nr_prv /= 0) THEN
     978       45682 :          do_write_3 = .TRUE.
     979       45682 :          IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
     980       20164 :             IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE.
     981             :          END IF
     982             :          IF (do_write_3) THEN
     983       25556 :             CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
     984       25556 :             CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
     985             :          END IF
     986             :       END IF
     987             : 
     988      141510 :       IF (new_3) THEN
     989             :          ! need redistribute if we created new tensor for tensor 3
     990       14934 :          CALL dbt_scale(tensor_algn_3, beta)
     991       14934 :          CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.)
     992       14934 :          IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
     993             :          ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
     994             :          ! pointer to data of tensor_3
     995             :       END IF
     996             : 
     997             :       ! transfer contraction storage
     998      141510 :       CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
     999      141510 :       CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
    1000      141510 :       CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
    1001             : 
    1002      141510 :       IF (unit_nr_prv /= 0) THEN
    1003       45682 :          IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
    1004       45682 :          IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
    1005             :       END IF
    1006             : 
    1007      141510 :       CALL dbt_destroy(tensor_algn_1)
    1008      141510 :       CALL dbt_destroy(tensor_algn_2)
    1009      141510 :       CALL dbt_destroy(tensor_algn_3)
    1010             : 
    1011      141510 :       IF (do_crop_1) THEN
    1012       65336 :          CALL dbt_destroy(tensor_crop_1)
    1013       65336 :          DEALLOCATE (tensor_crop_1)
    1014             :       END IF
    1015             : 
    1016      141510 :       IF (do_crop_2) THEN
    1017       63130 :          CALL dbt_destroy(tensor_crop_2)
    1018       63130 :          DEALLOCATE (tensor_crop_2)
    1019             :       END IF
    1020             : 
    1021      141510 :       IF (new_1) THEN
    1022       15030 :          CALL dbt_destroy(tensor_contr_1)
    1023       15030 :          DEALLOCATE (tensor_contr_1)
    1024             :       END IF
    1025      141510 :       IF (new_2) THEN
    1026        2581 :          CALL dbt_destroy(tensor_contr_2)
    1027        2581 :          DEALLOCATE (tensor_contr_2)
    1028             :       END IF
    1029      141510 :       IF (new_3) THEN
    1030       14934 :          CALL dbt_destroy(tensor_contr_3)
    1031       14934 :          DEALLOCATE (tensor_contr_3)
    1032             :       END IF
    1033             : 
    1034      141510 :       IF (PRESENT(move_data)) THEN
    1035       31063 :          IF (move_data) THEN
    1036       27225 :             CALL dbt_clear(tensor_1)
    1037       27225 :             CALL dbt_clear(tensor_2)
    1038             :          END IF
    1039             :       END IF
    1040             : 
    1041      141510 :       IF (unit_nr_prv > 0) THEN
    1042          10 :          WRITE (unit_nr_prv, '(A)') repeat("-", 80)
    1043          10 :          WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
    1044          10 :          WRITE (unit_nr_prv, '(A)') repeat("-", 80)
    1045             :       END IF
    1046             : 
    1047      422598 :       IF (ANY(do_change_pgrid)) THEN
    1048         966 :          pgrid_changed_any = .FALSE.
    1049         264 :          SELECT CASE (max_mm_dim)
    1050             :          CASE (1)
    1051         264 :             IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
    1052             :                CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1053             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1054             :                                         pgrid_changed=pgrid_changed, &
    1055           0 :                                         unit_nr=unit_nr_prv)
    1056           0 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1057             :                CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1058             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1059             :                                         pgrid_changed=pgrid_changed, &
    1060           0 :                                         unit_nr=unit_nr_prv)
    1061           0 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1062             :             END IF
    1063           0 :             IF (pgrid_changed_any) THEN
    1064           0 :                IF (tensor_2%matrix_rep%do_batched == 3) THEN
    1065             :                   ! set flag that process grid has been optimized to make sure that no grid optimizations are done
    1066             :                   ! in TAS multiply algorithm
    1067           0 :                   CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
    1068             :                END IF
    1069             :             END IF
    1070             :          CASE (2)
    1071         174 :             IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
    1072             :                CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1073             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1074             :                                         pgrid_changed=pgrid_changed, &
    1075         174 :                                         unit_nr=unit_nr_prv)
    1076         174 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1077             :                CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1078             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1079             :                                         pgrid_changed=pgrid_changed, &
    1080         174 :                                         unit_nr=unit_nr_prv)
    1081         174 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1082             :             END IF
    1083           8 :             IF (pgrid_changed_any) THEN
    1084         174 :                IF (tensor_3%matrix_rep%do_batched == 3) THEN
    1085         162 :                   CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
    1086             :                END IF
    1087             :             END IF
    1088             :          CASE (3)
    1089         528 :             IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
    1090             :                CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1091             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1092             :                                         pgrid_changed=pgrid_changed, &
    1093         214 :                                         unit_nr=unit_nr_prv)
    1094         214 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1095             :                CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
    1096             :                                         nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
    1097             :                                         pgrid_changed=pgrid_changed, &
    1098         214 :                                         unit_nr=unit_nr_prv)
    1099         214 :                IF (pgrid_changed) pgrid_changed_any = .TRUE.
    1100             :             END IF
    1101         966 :             IF (pgrid_changed_any) THEN
    1102         214 :                IF (tensor_1%matrix_rep%do_batched == 3) THEN
    1103         214 :                   CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
    1104             :                END IF
    1105             :             END IF
    1106             :          END SELECT
    1107         966 :          CALL dbt_tas_release_info(split_opt_avg)
    1108             :       END IF
    1109             : 
    1110      141510 :       IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
    1111             :          ! freeze TAS process grids if tensor grids were optimized
    1112       79027 :          CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
    1113       79027 :          CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
    1114       79027 :          CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
    1115             :       END IF
    1116             : 
    1117      141510 :       CALL dbt_tas_release_info(split_opt)
    1118             : 
    1119      141510 :       CALL timestop(handle)
    1120             : 
    1121      480498 :    END SUBROUTINE
    1122             : 
    1123             : ! **************************************************************************************************
    1124             : !> \brief align tensor index with data
    1125             : !> \author Patrick Seewald
    1126             : ! **************************************************************************************************
    1127     3820770 :    SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
    1128      424530 :                            tensor_out, contract_out, notcontract_out, indp_in, indp_out)
    1129             :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
    1130             :       INTEGER, DIMENSION(:), INTENT(IN)            :: contract_in, notcontract_in
    1131             :       TYPE(dbt_type), INTENT(OUT)              :: tensor_out
    1132             :       INTEGER, DIMENSION(SIZE(contract_in)), &
    1133             :          INTENT(OUT)                               :: contract_out
    1134             :       INTEGER, DIMENSION(SIZE(notcontract_in)), &
    1135             :          INTENT(OUT)                               :: notcontract_out
    1136             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
    1137             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
    1138      424530 :       INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align
    1139             : 
    1140      424530 :       CALL dbt_align_index(tensor_in, tensor_out, order=align)
    1141      971774 :       contract_out = align(contract_in)
    1142      993118 :       notcontract_out = align(notcontract_in)
    1143     1540362 :       indp_out(align) = indp_in
    1144             : 
    1145      424530 :    END SUBROUTINE
    1146             : 
    1147             : ! **************************************************************************************************
    1148             : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
    1149             : !>        matrix multiplication. This routine reshapes the two largest of the three tensors.
    1150             : !>        Redistribution is avoided if tensors already in a consistent layout.
    1151             : !> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
    1152             : !> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
    1153             : !>                    1:1 correspondence with ind1_linked
    1154             : !> \param trans1 transpose flag of matrix rep. tensor 1
    1155             : !> \param trans2 transpose flag of matrix rep. tensor 2
    1156             : !> \param new1 whether a new tensor 1 was created
    1157             : !> \param new2 whether a new tensor 2 was created
    1158             : !> \param nodata1 don't copy data of tensor 1
    1159             : !> \param nodata2 don't copy data of tensor 2
    1160             : !> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
    1161             : !> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
    1162             : !> \param optimize_dist experimental: optimize distribution
    1163             : !> \param unit_nr output unit
    1164             : !> \author Patrick Seewald
    1165             : ! **************************************************************************************************
    1166      141510 :    SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
    1167      141510 :                                     ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
    1168             :                                     nodata1, nodata2, move_data_1, &
    1169             :                                     move_data_2, optimize_dist, unit_nr)
    1170             :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor1
    1171             :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor2
    1172             :       TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor1_out, tensor2_out
    1173             :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_free, ind2_free
    1174             :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_linked, ind2_linked
    1175             :       LOGICAL, INTENT(OUT)                        :: trans1, trans2
    1176             :       LOGICAL, INTENT(OUT)                        :: new1, new2
    1177             :       INTEGER, INTENT(OUT) :: ref_tensor
    1178             :       LOGICAL, INTENT(IN), OPTIONAL               :: nodata1, nodata2
    1179             :       LOGICAL, INTENT(INOUT), OPTIONAL            :: move_data_1, move_data_2
    1180             :       LOGICAL, INTENT(IN), OPTIONAL               :: optimize_dist
    1181             :       INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
    1182             :       INTEGER                                     :: compat1, compat1_old, compat2, compat2_old, &
    1183             :                                                      unit_nr_prv
    1184      141510 :       TYPE(mp_cart_type)                          :: comm_2d
    1185      141510 :       TYPE(array_list)                            :: dist_list
    1186      141510 :       INTEGER, DIMENSION(:), ALLOCATABLE          :: mp_dims
    1187      990570 :       TYPE(dbt_distribution_type)             :: dist_in
    1188             :       INTEGER(KIND=int_8)                         :: nblkrows, nblkcols
    1189             :       LOGICAL                                     :: optimize_dist_prv
    1190      283020 :       INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
    1191      141510 :       INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2
    1192             : 
    1193      141510 :       NULLIFY (tensor1_out, tensor2_out)
    1194             : 
    1195      141510 :       unit_nr_prv = prep_output_unit(unit_nr)
    1196             : 
    1197      141510 :       CALL blk_dims_tensor(tensor1, dims1)
    1198      141510 :       CALL blk_dims_tensor(tensor2, dims2)
    1199             : 
    1200      972780 :       IF (PRODUCT(int(dims1, int_8)) .GE. PRODUCT(int(dims2, int_8))) THEN
    1201      139880 :          ref_tensor = 1
    1202             :       ELSE
    1203        1630 :          ref_tensor = 2
    1204             :       END IF
    1205             : 
    1206      141510 :       IF (PRESENT(optimize_dist)) THEN
    1207         346 :          optimize_dist_prv = optimize_dist
    1208             :       ELSE
    1209             :          optimize_dist_prv = .FALSE.
    1210             :       END IF
    1211             : 
    1212      141510 :       compat1 = compat_map(tensor1%nd_index, ind1_linked)
    1213      141510 :       compat2 = compat_map(tensor2%nd_index, ind2_linked)
    1214      141510 :       compat1_old = compat1
    1215      141510 :       compat2_old = compat2
    1216             : 
    1217      141510 :       IF (unit_nr_prv > 0) THEN
    1218          10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":"
    1219           6 :          SELECT CASE (compat1)
    1220             :          CASE (0)
    1221           6 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1222             :          CASE (1)
    1223           3 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1224             :          CASE (2)
    1225          10 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1226             :          END SELECT
    1227          10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":"
    1228           5 :          SELECT CASE (compat2)
    1229             :          CASE (0)
    1230           5 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1231             :          CASE (1)
    1232           4 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1233             :          CASE (2)
    1234          10 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1235             :          END SELECT
    1236             :       END IF
    1237             : 
    1238      141510 :       new1 = .FALSE.
    1239      141510 :       new2 = .FALSE.
    1240             : 
    1241      141510 :       IF (compat1 == 0 .OR. optimize_dist_prv) THEN
    1242       17441 :          new1 = .TRUE.
    1243             :       END IF
    1244             : 
    1245      141510 :       IF (compat2 == 0 .OR. optimize_dist_prv) THEN
    1246       15082 :          new2 = .TRUE.
    1247             :       END IF
    1248             : 
    1249      141510 :       IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
    1250      139880 :          IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
    1251       17279 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name)
    1252       51837 :             nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8))
    1253       34560 :             nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8))
    1254       17279 :             comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
    1255      120953 :             ALLOCATE (tensor1_out)
    1256             :             CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
    1257       17279 :                            nodata=nodata1, move_data=move_data_1)
    1258       17279 :             CALL comm_2d%free()
    1259       17279 :             compat1 = 1
    1260             :          ELSE
    1261      122601 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
    1262      122601 :             tensor1_out => tensor1
    1263             :          END IF
    1264      139880 :          IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
    1265       14920 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
    1266           8 :                TRIM(tensor2%name), "compatible with", TRIM(tensor1%name)
    1267       14916 :             dist_in = dbt_distribution(tensor1_out)
    1268       14916 :             dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
    1269       14916 :             IF (compat1 == 1) THEN ! linked index is first 2d dimension
    1270             :                ! get distribution of linked index, tensor 2 must adopt this distribution
    1271             :                ! get grid dimensions of linked index
    1272       22722 :                ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
    1273        7574 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
    1274       53018 :                ALLOCATE (tensor2_out)
    1275             :                CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1276        7574 :                               dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
    1277        7342 :             ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
    1278             :                ! get distribution of linked index, tensor 2 must adopt this distribution
    1279             :                ! get grid dimensions of linked index
    1280       22026 :                ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
    1281        7342 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
    1282       51394 :                ALLOCATE (tensor2_out)
    1283             :                CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1284        7342 :                               dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
    1285             :             ELSE
    1286           0 :                CPABORT("should not happen")
    1287             :             END IF
    1288       14916 :             compat2 = compat1
    1289             :          ELSE
    1290      124964 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
    1291      124964 :             tensor2_out => tensor2
    1292             :          END IF
    1293             :       ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
    1294        1630 :          IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
    1295         166 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name)
    1296         342 :             nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8))
    1297         334 :             nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8))
    1298         166 :             comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
    1299        1162 :             ALLOCATE (tensor2_out)
    1300         166 :             CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2)
    1301         166 :             CALL comm_2d%free()
    1302         166 :             compat2 = 1
    1303             :          ELSE
    1304        1464 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
    1305        1464 :             tensor2_out => tensor2
    1306             :          END IF
    1307        1630 :          IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
    1308         165 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), &
    1309           6 :                "compatible with", TRIM(tensor2%name)
    1310         162 :             dist_in = dbt_distribution(tensor2_out)
    1311         162 :             dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
    1312         162 :             IF (compat2 == 1) THEN
    1313         480 :                ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
    1314         160 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
    1315        1120 :                ALLOCATE (tensor1_out)
    1316             :                CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1317         160 :                               dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
    1318           2 :             ELSEIF (compat2 == 2) THEN
    1319           6 :                ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
    1320           2 :                CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
    1321          14 :                ALLOCATE (tensor1_out)
    1322             :                CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
    1323           2 :                               dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
    1324             :             ELSE
    1325           0 :                CPABORT("should not happen")
    1326             :             END IF
    1327         162 :             compat1 = compat2
    1328             :          ELSE
    1329        1468 :             IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
    1330        1468 :             tensor1_out => tensor1
    1331             :          END IF
    1332             :       END IF
    1333             : 
    1334       95771 :       SELECT CASE (compat1)
    1335             :       CASE (1)
    1336       95771 :          trans1 = .FALSE.
    1337             :       CASE (2)
    1338       45739 :          trans1 = .TRUE.
    1339             :       CASE DEFAULT
    1340      141510 :          CPABORT("should not happen")
    1341             :       END SELECT
    1342             : 
    1343       96324 :       SELECT CASE (compat2)
    1344             :       CASE (1)
    1345       96324 :          trans2 = .FALSE.
    1346             :       CASE (2)
    1347       45186 :          trans2 = .TRUE.
    1348             :       CASE DEFAULT
    1349      141510 :          CPABORT("should not happen")
    1350             :       END SELECT
    1351             : 
    1352      141510 :       IF (unit_nr_prv > 0) THEN
    1353          10 :          IF (compat1 .NE. compat1_old) THEN
    1354           6 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":"
    1355           0 :             SELECT CASE (compat1)
    1356             :             CASE (0)
    1357           0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1358             :             CASE (1)
    1359           5 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1360             :             CASE (2)
    1361           6 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1362             :             END SELECT
    1363             :          END IF
    1364          10 :          IF (compat2 .NE. compat2_old) THEN
    1365           5 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":"
    1366           0 :             SELECT CASE (compat2)
    1367             :             CASE (0)
    1368           0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1369             :             CASE (1)
    1370           4 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1371             :             CASE (2)
    1372           5 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1373             :             END SELECT
    1374             :          END IF
    1375             :       END IF
    1376             : 
    1377      141510 :       IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE.
    1378      141510 :       IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE.
    1379             : 
    1380      141510 :    END SUBROUTINE
    1381             : 
    1382             : ! **************************************************************************************************
    1383             : !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
    1384             : !>        matrix multiplication. This routine reshapes the smallest of the three tensors.
    1385             : !> \param ind1 index that should be mapped to first matrix dimension
    1386             : !> \param ind2 index that should be mapped to second matrix dimension
    1387             : !> \param trans transpose flag of matrix rep.
    1388             : !> \param new whether a new tensor was created for tensor_out
    1389             : !> \param nodata don't copy tensor data
    1390             : !> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
    1391             : !> \param unit_nr output unit
    1392             : !> \author Patrick Seewald
    1393             : ! **************************************************************************************************
    1394      141510 :    SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
    1395             :       TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor_in
    1396             :       INTEGER, DIMENSION(:), INTENT(IN)           :: ind1, ind2
    1397             :       TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor_out
    1398             :       LOGICAL, INTENT(OUT)                        :: trans
    1399             :       LOGICAL, INTENT(OUT)                        :: new
    1400             :       LOGICAL, INTENT(IN), OPTIONAL               :: nodata, move_data
    1401             :       INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
    1402             :       INTEGER                                     :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
    1403             :       LOGICAL                                     :: nodata_prv
    1404             : 
    1405      141510 :       NULLIFY (tensor_out)
    1406             :       IF (PRESENT(nodata)) THEN
    1407      141510 :          nodata_prv = nodata
    1408             :       ELSE
    1409      141510 :          nodata_prv = .FALSE.
    1410             :       END IF
    1411             : 
    1412      141510 :       unit_nr_prv = prep_output_unit(unit_nr)
    1413             : 
    1414      141510 :       new = .FALSE.
    1415      141510 :       compat1 = compat_map(tensor_in%nd_index, ind1)
    1416      141510 :       compat2 = compat_map(tensor_in%nd_index, ind2)
    1417      141510 :       compat1_old = compat1; compat2_old = compat2
    1418      141510 :       IF (unit_nr_prv > 0) THEN
    1419          10 :          WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_in%name), ":"
    1420          10 :          IF (compat1 == 1 .AND. compat2 == 2) THEN
    1421           4 :             WRITE (unit_nr_prv, '(A)') "Normal"
    1422           6 :          ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1423           2 :             WRITE (unit_nr_prv, '(A)') "Transposed"
    1424             :          ELSE
    1425           4 :             WRITE (unit_nr_prv, '(A)') "Not compatible"
    1426             :          END IF
    1427             :       END IF
    1428      141510 :       IF (compat1 == 0 .or. compat2 == 0) THEN ! index mapping not compatible with contract index
    1429             : 
    1430          22 :          IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name)
    1431             : 
    1432         154 :          ALLOCATE (tensor_out)
    1433          22 :          CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
    1434          22 :          CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
    1435          22 :          compat1 = 1
    1436          22 :          compat2 = 2
    1437          22 :          new = .TRUE.
    1438             :       ELSE
    1439      141488 :          IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
    1440      141488 :          tensor_out => tensor_in
    1441             :       END IF
    1442             : 
    1443      141510 :       IF (compat1 == 1 .AND. compat2 == 2) THEN
    1444      104876 :          trans = .FALSE.
    1445       36634 :       ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1446       36634 :          trans = .TRUE.
    1447             :       ELSE
    1448           0 :          CPABORT("this should not happen")
    1449             :       END IF
    1450             : 
    1451      141510 :       IF (unit_nr_prv > 0) THEN
    1452          10 :          IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
    1453           4 :             WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_out%name), ":"
    1454           4 :             IF (compat1 == 1 .AND. compat2 == 2) THEN
    1455           4 :                WRITE (unit_nr_prv, '(A)') "Normal"
    1456           0 :             ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
    1457           0 :                WRITE (unit_nr_prv, '(A)') "Transposed"
    1458             :             ELSE
    1459           0 :                WRITE (unit_nr_prv, '(A)') "Not compatible"
    1460             :             END IF
    1461             :          END IF
    1462             :       END IF
    1463             : 
    1464      141510 :    END SUBROUTINE
    1465             : 
    1466             : ! **************************************************************************************************
    1467             : !> \brief update contraction storage that keeps track of process grids during a batched contraction
    1468             : !>        and decide if tensor process grid needs to be optimized
    1469             : !> \param split_opt optimized TAS process grid
    1470             : !> \param split current TAS process grid
    1471             : !> \author Patrick Seewald
    1472             : ! **************************************************************************************************
    1473       79027 :    FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
    1474             :       TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
    1475             :       TYPE(dbt_tas_split_info), INTENT(IN)           :: split_opt
    1476             :       TYPE(dbt_tas_split_info), INTENT(IN)           :: split
    1477             :       INTEGER, DIMENSION(2) :: pdims, pdims_sub
    1478             :       LOGICAL, DIMENSION(2) :: do_change_pgrid
    1479             :       REAL(kind=dp) :: change_criterion, pdims_ratio
    1480             :       INTEGER :: nsplit_opt, nsplit
    1481             : 
    1482       79027 :       CPASSERT(ALLOCATED(split_opt%ngroup_opt))
    1483       79027 :       nsplit_opt = split_opt%ngroup_opt
    1484       79027 :       nsplit = split%ngroup
    1485             : 
    1486      237081 :       pdims = split%mp_comm%num_pe_cart
    1487             : 
    1488       79027 :       storage%ibatch = storage%ibatch + 1
    1489             : 
    1490             :       storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, dp) + REAL(nsplit_opt, dp)) &
    1491       79027 :                            /REAL(storage%ibatch, dp)
    1492             : 
    1493       79027 :       SELECT CASE (split_opt%split_rowcol)
    1494             :       CASE (rowsplit)
    1495       79027 :          pdims_ratio = REAL(pdims(1), dp)/pdims(2)
    1496             :       CASE (colsplit)
    1497       79027 :          pdims_ratio = REAL(pdims(2), dp)/pdims(1)
    1498             :       END SELECT
    1499             : 
    1500      237081 :       do_change_pgrid(:) = .FALSE.
    1501             : 
    1502             :       ! check for process grid dimensions
    1503      237081 :       pdims_sub = split%mp_comm_group%num_pe_cart
    1504      474162 :       change_criterion = MAXVAL(REAL(pdims_sub, dp))/MINVAL(pdims_sub)
    1505       79027 :       IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.
    1506             : 
    1507             :       ! check for split factor
    1508       79027 :       change_criterion = MAX(REAL(nsplit, dp)/storage%nsplit_avg, REAL(storage%nsplit_avg, dp)/nsplit)
    1509       79027 :       IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.
    1510             : 
    1511       79027 :    END FUNCTION
    1512             : 
    1513             : ! **************************************************************************************************
    1514             : !> \brief Check if 2d index is compatible with tensor index
    1515             : !> \author Patrick Seewald
    1516             : ! **************************************************************************************************
    1517      566040 :    FUNCTION compat_map(nd_index, compat_ind)
    1518             :       TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
    1519             :       INTEGER, DIMENSION(:), INTENT(IN)  :: compat_ind
    1520     1132080 :       INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
    1521     1132080 :       INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
    1522             :       INTEGER                            :: compat_map
    1523             : 
    1524      566040 :       CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
    1525             : 
    1526      566040 :       compat_map = 0
    1527      566040 :       IF (array_eq_i(map1, compat_ind)) THEN
    1528             :          compat_map = 1
    1529      257478 :       ELSEIF (array_eq_i(map2, compat_ind)) THEN
    1530      225613 :          compat_map = 2
    1531             :       END IF
    1532             : 
    1533      566040 :    END FUNCTION
    1534             : 
    1535             : ! **************************************************************************************************
    1536             : !> \brief
    1537             : !> \author Patrick Seewald
    1538             : ! **************************************************************************************************
    1539      424530 :    SUBROUTINE index_linked_sort(ind_ref, ind_dep)
    1540             :       INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
    1541      849060 :       INTEGER, DIMENSION(SIZE(ind_ref))    :: sort_indices
    1542             : 
    1543      424530 :       CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
    1544     1964892 :       ind_dep(:) = ind_dep(sort_indices)
    1545             : 
    1546      424530 :    END SUBROUTINE
    1547             : 
    1548             : ! **************************************************************************************************
    1549             : !> \brief
    1550             : !> \author Patrick Seewald
    1551             : ! **************************************************************************************************
    1552           0 :    FUNCTION opt_pgrid(tensor, tas_split_info)
    1553             :       TYPE(dbt_type), INTENT(IN) :: tensor
    1554             :       TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
    1555           0 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    1556           0 :       INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
    1557             :       TYPE(dbt_pgrid_type) :: opt_pgrid
    1558           0 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
    1559             : 
    1560           0 :       CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
    1561           0 :       CALL blk_dims_tensor(tensor, dims)
    1562           0 :       opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
    1563             : 
    1564           0 :       ALLOCATE (opt_pgrid%tas_split_info, SOURCE=tas_split_info)
    1565           0 :       CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
    1566           0 :    END FUNCTION
    1567             : 
    1568             : ! **************************************************************************************************
    1569             : !> \brief Copy tensor to tensor with modified index mapping
    1570             : !> \param map1_2d new index mapping
    1571             : !> \param map2_2d new index mapping
    1572             : !> \author Patrick Seewald
    1573             : ! **************************************************************************************************
    1574      292905 :    SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
    1575       32545 :                         mp_dims_1, mp_dims_2, name, nodata, move_data)
    1576             :       TYPE(dbt_type), INTENT(INOUT)      :: tensor_in
    1577             :       INTEGER, DIMENSION(:), INTENT(IN)      :: map1_2d, map2_2d
    1578             :       TYPE(dbt_type), INTENT(OUT)        :: tensor_out
    1579             :       CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
    1580             :       LOGICAL, INTENT(IN), OPTIONAL          :: nodata, move_data
    1581             :       CLASS(mp_comm_type), INTENT(IN), OPTIONAL          :: comm_2d
    1582             :       TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
    1583             :       INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
    1584             :       INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
    1585             :       CHARACTER(len=default_string_length)   :: name_tmp
    1586       32545 :       INTEGER, DIMENSION(:), ALLOCATABLE     :: ${varlist("blk_sizes")}$, &
    1587       32545 :                                                 ${varlist("nd_dist")}$
    1588      227815 :       TYPE(dbt_distribution_type)        :: dist
    1589       32545 :       TYPE(mp_cart_type) :: comm_2d_prv
    1590             :       INTEGER                                :: handle, i
    1591       32545 :       INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
    1592             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_remap'
    1593             :       LOGICAL                               :: nodata_prv
    1594       97635 :       TYPE(dbt_pgrid_type)              :: comm_nd
    1595             : 
    1596       32545 :       CALL timeset(routineN, handle)
    1597             : 
    1598       32545 :       IF (PRESENT(name)) THEN
    1599           0 :          name_tmp = name
    1600             :       ELSE
    1601       32545 :          name_tmp = tensor_in%name
    1602             :       END IF
    1603       32545 :       IF (PRESENT(dist1)) THEN
    1604        7734 :          CPASSERT(PRESENT(mp_dims_1))
    1605             :       END IF
    1606             : 
    1607       32545 :       IF (PRESENT(dist2)) THEN
    1608        7344 :          CPASSERT(PRESENT(mp_dims_2))
    1609             :       END IF
    1610             : 
    1611       32545 :       IF (PRESENT(comm_2d)) THEN
    1612       32357 :          comm_2d_prv = comm_2d
    1613             :       ELSE
    1614         188 :          comm_2d_prv = tensor_in%pgrid%mp_comm_2d
    1615             :       END IF
    1616             : 
    1617       32545 :       comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
    1618       32545 :       CALL mp_environ_pgrid(comm_nd, pdims, myploc)
    1619             : 
    1620             :       #:for ndim in ndims
    1621       64934 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1622       32389 :             CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
    1623             :          END IF
    1624             :       #:endfor
    1625             : 
    1626             :       #:for ndim in ndims
    1627       65086 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1628             :             #:for idim in range(1, ndim+1)
    1629       97483 :                IF (PRESENT(dist1)) THEN
    1630       53984 :                   IF (ANY(map1_2d == ${idim}$)) THEN
    1631       45780 :                      i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
    1632        7736 :                      CALL get_ith_array(dist1, i, nd_dist_${idim}$)
    1633             :                   END IF
    1634             :                END IF
    1635             : 
    1636       97483 :                IF (PRESENT(dist2)) THEN
    1637       58736 :                   IF (ANY(map2_2d == ${idim}$)) THEN
    1638       44032 :                      i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
    1639       14680 :                      CALL get_ith_array(dist2, i, nd_dist_${idim}$)
    1640             :                   END IF
    1641             :                END IF
    1642             : 
    1643       97483 :                IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
    1644      202473 :                   ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
    1645       67491 :                   CALL dbt_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)
    1646             :                END IF
    1647             :             #:endfor
    1648             :             CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
    1649       32545 :                                              ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
    1650             :          END IF
    1651             :       #:endfor
    1652             : 
    1653             :       #:for ndim in ndims
    1654       65086 :          IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
    1655             :             CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
    1656       32545 :                             ${varlist("blk_sizes", nmax=ndim)}$)
    1657             :          END IF
    1658             :       #:endfor
    1659             : 
    1660       32545 :       IF (PRESENT(nodata)) THEN
    1661       14934 :          nodata_prv = nodata
    1662             :       ELSE
    1663             :          nodata_prv = .FALSE.
    1664             :       END IF
    1665             : 
    1666       32545 :       IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
    1667       32545 :       CALL dbt_distribution_destroy(dist)
    1668             : 
    1669       32545 :       CALL timestop(handle)
    1670       97635 :    END SUBROUTINE
    1671             : 
    1672             : ! **************************************************************************************************
    1673             : !> \brief Align index with data
    1674             : !> \param order permutation resulting from alignment
    1675             : !> \author Patrick Seewald
    1676             : ! **************************************************************************************************
    1677     3396240 :    SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
    1678             :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
    1679             :       TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
    1680      849060 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
    1681      849060 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
    1682             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
    1683             :          INTENT(OUT), OPTIONAL                        :: order
    1684      424530 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))     :: order_prv
    1685             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_align_index'
    1686             :       INTEGER                                         :: handle
    1687             : 
    1688      424530 :       CALL timeset(routineN, handle)
    1689             : 
    1690      424530 :       CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
    1691     2656194 :       order_prv = dbt_inverse_order([map1_2d, map2_2d])
    1692      424530 :       CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
    1693             : 
    1694     1540362 :       IF (PRESENT(order)) order = order_prv
    1695             : 
    1696      424530 :       CALL timestop(handle)
    1697      424530 :    END SUBROUTINE
    1698             : 
    1699             : ! **************************************************************************************************
    1700             : !> \brief Create new tensor by reordering index, data is copied exactly (shallow copy)
    1701             : !> \author Patrick Seewald
    1702             : ! **************************************************************************************************
    1703     4815936 :    SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
    1704             :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor_in
    1705             :       TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
    1706             :       INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
    1707             :          INTENT(IN)                                   :: order
    1708             : 
    1709     2675520 :       TYPE(nd_to_2d_mapping)                          :: nd_index_blk_rs, nd_index_rs
    1710             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_permute_index'
    1711             :       INTEGER                                         :: handle
    1712             :       INTEGER                                         :: ndims
    1713             : 
    1714      535104 :       CALL timeset(routineN, handle)
    1715             : 
    1716      535104 :       ndims = ndims_tensor(tensor_in)
    1717             : 
    1718      535104 :       CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
    1719      535104 :       CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
    1720      535104 :       CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
    1721             : 
    1722      535104 :       tensor_out%matrix_rep => tensor_in%matrix_rep
    1723      535104 :       tensor_out%owns_matrix = .FALSE.
    1724             : 
    1725      535104 :       tensor_out%nd_index = nd_index_rs
    1726      535104 :       tensor_out%nd_index_blk = nd_index_blk_rs
    1727      535104 :       tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
    1728      535104 :       IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
    1729      535104 :          ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
    1730             :       END IF
    1731      535104 :       tensor_out%refcount => tensor_in%refcount
    1732      535104 :       CALL dbt_hold(tensor_out)
    1733             : 
    1734      535104 :       CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
    1735      535104 :       CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
    1736      535104 :       CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
    1737      535104 :       CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
    1738     1605312 :       ALLOCATE (tensor_out%nblks_local(ndims))
    1739     1070208 :       ALLOCATE (tensor_out%nfull_local(ndims))
    1740     1964468 :       tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
    1741     1964468 :       tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
    1742      535104 :       tensor_out%name = tensor_in%name
    1743      535104 :       tensor_out%valid = .TRUE.
    1744             : 
    1745      535104 :       IF (ALLOCATED(tensor_in%contraction_storage)) THEN
    1746      279430 :          ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
    1747      279430 :          CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
    1748      279430 :          CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
    1749             :       END IF
    1750             : 
    1751      535104 :       CALL timestop(handle)
    1752     1070208 :    END SUBROUTINE
    1753             : 
    1754             : ! **************************************************************************************************
    1755             : !> \brief Map contraction bounds to bounds referring to tensor indices
    1756             : !>        see dbt_contract for docu of dummy arguments
    1757             : !> \param bounds_t1 bounds mapped to tensor_1
    1758             : !> \param bounds_t2 bounds mapped to tensor_2
    1759             : !> \param do_crop_1 whether tensor 1 should be cropped
    1760             : !> \param do_crop_2 whether tensor 2 should be cropped
    1761             : !> \author Patrick Seewald
    1762             : ! **************************************************************************************************
    1763      169494 :    SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
    1764      169494 :                                         contract_1, notcontract_1, &
    1765      338988 :                                         contract_2, notcontract_2, &
    1766      169494 :                                         bounds_t1, bounds_t2, &
    1767      115240 :                                         bounds_1, bounds_2, bounds_3, &
    1768             :                                         do_crop_1, do_crop_2)
    1769             : 
    1770             :       TYPE(dbt_type), INTENT(IN)      :: tensor_1, tensor_2
    1771             :       INTEGER, DIMENSION(:), INTENT(IN)   :: contract_1, contract_2, &
    1772             :                                              notcontract_1, notcontract_2
    1773             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
    1774             :          INTENT(OUT)                                 :: bounds_t1
    1775             :       INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
    1776             :          INTENT(OUT)                                 :: bounds_t2
    1777             :       INTEGER, DIMENSION(2, SIZE(contract_1)), &
    1778             :          INTENT(IN), OPTIONAL                        :: bounds_1
    1779             :       INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
    1780             :          INTENT(IN), OPTIONAL                        :: bounds_2
    1781             :       INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
    1782             :          INTENT(IN), OPTIONAL                        :: bounds_3
    1783             :       LOGICAL, INTENT(OUT), OPTIONAL                 :: do_crop_1, do_crop_2
    1784             :       LOGICAL, DIMENSION(2)                          :: do_crop
    1785             : 
    1786      169494 :       do_crop = .FALSE.
    1787             : 
    1788      600796 :       bounds_t1(1, :) = 1
    1789      600796 :       CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
    1790             : 
    1791      633434 :       bounds_t2(1, :) = 1
    1792      633434 :       CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
    1793             : 
    1794      169494 :       IF (PRESENT(bounds_1)) THEN
    1795      168242 :          bounds_t1(:, contract_1) = bounds_1
    1796       72762 :          do_crop(1) = .TRUE.
    1797      168242 :          bounds_t2(:, contract_2) = bounds_1
    1798      169494 :          do_crop(2) = .TRUE.
    1799             :       END IF
    1800             : 
    1801      169494 :       IF (PRESENT(bounds_2)) THEN
    1802      229244 :          bounds_t1(:, notcontract_1) = bounds_2
    1803      169494 :          do_crop(1) = .TRUE.
    1804             :       END IF
    1805             : 
    1806      169494 :       IF (PRESENT(bounds_3)) THEN
    1807      252270 :          bounds_t2(:, notcontract_2) = bounds_3
    1808      169494 :          do_crop(2) = .TRUE.
    1809             :       END IF
    1810             : 
    1811      169494 :       IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
    1812      169494 :       IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
    1813             : 
    1814      384440 :    END SUBROUTINE
    1815             : 
    1816             : ! **************************************************************************************************
    1817             : !> \brief print tensor contraction indices in a human readable way
    1818             : !> \param indchar1 characters printed for index of tensor 1
    1819             : !> \param indchar2 characters printed for index of tensor 2
    1820             : !> \param indchar3 characters printed for index of tensor 3
    1821             : !> \param unit_nr output unit
    1822             : !> \author Patrick Seewald
    1823             : ! **************************************************************************************************
    1824      137046 :    SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
    1825             :       TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
    1826             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
    1827             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
    1828             :       CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
    1829             :       INTEGER, INTENT(IN) :: unit_nr
    1830      274092 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
    1831      274092 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
    1832      274092 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
    1833      274092 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
    1834      274092 :       INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
    1835      274092 :       INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
    1836             :       INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
    1837             : 
    1838      137046 :       unit_nr_prv = prep_output_unit(unit_nr)
    1839             : 
    1840      137046 :       IF (unit_nr_prv /= 0) THEN
    1841      137046 :          CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
    1842      137046 :          CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
    1843      137046 :          CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
    1844             :       END IF
    1845             : 
    1846      137046 :       IF (unit_nr_prv > 0) THEN
    1847          30 :          WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
    1848          30 :          WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
    1849         123 :          DO ichar1 = 1, SIZE(indchar1)
    1850         123 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
    1851             :          END DO
    1852          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
    1853         120 :          DO ichar2 = 1, SIZE(indchar2)
    1854         120 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
    1855             :          END DO
    1856          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
    1857         123 :          DO ichar3 = 1, SIZE(indchar3)
    1858         123 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
    1859             :          END DO
    1860          30 :          WRITE (unit_nr_prv, '(A)') ")"
    1861             : 
    1862          30 :          WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
    1863          82 :          DO ichar1 = 1, SIZE(map11)
    1864          82 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
    1865             :          END DO
    1866          30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1867          71 :          DO ichar1 = 1, SIZE(map12)
    1868          71 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
    1869             :          END DO
    1870          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
    1871          76 :          DO ichar2 = 1, SIZE(map21)
    1872          76 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
    1873             :          END DO
    1874          30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1875          74 :          DO ichar2 = 1, SIZE(map22)
    1876          74 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
    1877             :          END DO
    1878          30 :          WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
    1879          79 :          DO ichar3 = 1, SIZE(map31)
    1880          79 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
    1881             :          END DO
    1882          30 :          WRITE (unit_nr_prv, '(A1)', advance='no') "|"
    1883          74 :          DO ichar3 = 1, SIZE(map32)
    1884          74 :             WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
    1885             :          END DO
    1886          30 :          WRITE (unit_nr_prv, '(A)') ")"
    1887             :       END IF
    1888             : 
    1889      137046 :    END SUBROUTINE
    1890             : 
    1891             : ! **************************************************************************************************
    1892             : !> \brief Initialize batched contraction for this tensor.
    1893             : !>
    1894             : !>        Explanation: A batched contraction is a contraction performed in several consecutive steps
    1895             : !>        by specification of bounds in dbt_contract. This can be used to reduce memory by
    1896             : !>        a large factor. The routines dbt_batched_contract_init and
    1897             : !>        dbt_batched_contract_finalize should be called to define the scope of a batched
    1898             : !>        contraction as this enables important optimizations (adapting communication scheme to
    1899             : !>        batches and adapting process grid to multiplication algorithm). The routines
    1900             : !>        dbt_batched_contract_init and dbt_batched_contract_finalize must be
    1901             : !>        called before the first and after the last contraction step on all 3 tensors.
    1902             : !>
    1903             : !>        Requirements:
    1904             : !>        - the tensors are in a compatible matrix layout (see documentation of
    1905             : !>          `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
    1906             : !>          disabled and a warning is issued.
    1907             : !>        - within the scope of a batched contraction, it is not allowed to access or change tensor
    1908             : !>          data except by calling the routines dbt_contract & dbt_copy.
    1909             : !>        - the bounds affecting indices of the smallest tensor must not change in the course of a
    1910             : !>          batched contraction (todo: get rid of this requirement).
    1911             : !>
    1912             : !>        Side effects:
    1913             : !>        - the parallel layout (process grid and distribution) of all tensors may change. In order
    1914             : !>          to disable the process grid optimization including this side effect, call this routine
    1915             : !>          only on the smallest of the 3 tensors.
    1916             : !>
    1917             : !> \note
    1918             : !>        Note 1: for an example of batched contraction see `examples/dbt_example.F`.
    1919             : !>        (todo: the example is outdated and should be updated).
    1920             : !>
    1921             : !>        Note 2: it is meaningful to use this feature if the contraction consists of one batch only
    1922             : !>        but if multiple contractions involving the same 3 tensors are performed
    1923             : !>        (batched_contract_init and batched_contract_finalize must then be called before/after each
    1924             : !>        contraction call). The process grid is then optimized after the first contraction
    1925             : !>        and future contraction may profit from this optimization.
    1926             : !>
    1927             : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
    1928             : !>                      a new range. The size should be the number of ranges plus one, the last
    1929             : !>                      element being the block index plus one of the last block in the last range.
    1930             : !>                      For internal load balancing optimizations, optionally specify the index
    1931             : !>                      ranges of batched contraction.
    1932             : !> \author Patrick Seewald
    1933             : ! **************************************************************************************************
    1934       98011 :    SUBROUTINE dbt_batched_contract_init(tensor, ${varlist("batch_range")}$)
    1935             :       TYPE(dbt_type), INTENT(INOUT) :: tensor
    1936             :       INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
    1937      196022 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
    1938       98011 :       INTEGER, DIMENSION(:), ALLOCATABLE                 :: ${varlist("batch_range_prv")}$
    1939             :       LOGICAL :: static_range
    1940             : 
    1941       98011 :       CALL dbt_get_info(tensor, nblks_total=tdims)
    1942             : 
    1943       98011 :       static_range = .TRUE.
    1944             :       #:for idim in range(1, maxdim+1)
    1945       98011 :          IF (ndims_tensor(tensor) >= ${idim}$) THEN
    1946      230064 :             IF (PRESENT(batch_range_${idim}$)) THEN
    1947      368064 :                ALLOCATE (batch_range_prv_${idim}$, source=batch_range_${idim}$)
    1948      230064 :                static_range = .FALSE.
    1949             :             ELSE
    1950      173614 :                ALLOCATE (batch_range_prv_${idim}$ (2))
    1951      173614 :                batch_range_prv_${idim}$ (1) = 1
    1952      173614 :                batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
    1953             :             END IF
    1954             :          END IF
    1955             :       #:endfor
    1956             : 
    1957       98011 :       ALLOCATE (tensor%contraction_storage)
    1958       98011 :       tensor%contraction_storage%static = static_range
    1959       98011 :       IF (static_range) THEN
    1960       66559 :          CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
    1961             :       END IF
    1962       98011 :       tensor%contraction_storage%nsplit_avg = 0.0_dp
    1963       98011 :       tensor%contraction_storage%ibatch = 0
    1964             : 
    1965             :       #:for ndim in range(1, maxdim+1)
    1966      196022 :          IF (ndims_tensor(tensor) == ${ndim}$) THEN
    1967             :             CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
    1968       98011 :                                    ${varlist("batch_range_prv", nmax=ndim)}$)
    1969             :          END IF
    1970             :       #:endfor
    1971             : 
    1972       98011 :    END SUBROUTINE
    1973             : 
    1974             : ! **************************************************************************************************
    1975             : !> \brief finalize batched contraction. This performs all communication that has been postponed in
    1976             : !>         the contraction calls.
    1977             : !> \author Patrick Seewald
    1978             : ! **************************************************************************************************
    1979      196022 :    SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
    1980             :       TYPE(dbt_type), INTENT(INOUT) :: tensor
    1981             :       INTEGER, INTENT(IN), OPTIONAL :: unit_nr
    1982             :       LOGICAL :: do_write
    1983             :       INTEGER :: unit_nr_prv, handle
    1984             : 
    1985       98011 :       CALL tensor%pgrid%mp_comm_2d%sync()
    1986       98011 :       CALL timeset("dbt_total", handle)
    1987       98011 :       unit_nr_prv = prep_output_unit(unit_nr)
    1988             : 
    1989       98011 :       do_write = .FALSE.
    1990             : 
    1991       98011 :       IF (tensor%contraction_storage%static) THEN
    1992       66559 :          IF (tensor%matrix_rep%do_batched > 0) THEN
    1993       66559 :             IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .TRUE.
    1994             :          END IF
    1995       66559 :          CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
    1996             :       END IF
    1997             : 
    1998       98011 :       IF (do_write .AND. unit_nr_prv /= 0) THEN
    1999       15406 :          IF (unit_nr_prv > 0) THEN
    2000             :             WRITE (unit_nr_prv, "(T2,A)") &
    2001           0 :                "FINALIZING BATCHED PROCESSING OF MATMUL"
    2002             :          END IF
    2003       15406 :          CALL dbt_write_tensor_info(tensor, unit_nr_prv)
    2004       15406 :          CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
    2005             :       END IF
    2006             : 
    2007       98011 :       CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
    2008       98011 :       DEALLOCATE (tensor%contraction_storage)
    2009       98011 :       CALL tensor%pgrid%mp_comm_2d%sync()
    2010       98011 :       CALL timestop(handle)
    2011             : 
    2012       98011 :    END SUBROUTINE
    2013             : 
    2014             : ! **************************************************************************************************
    2015             : !> \brief change the process grid of a tensor
    2016             : !> \param nodata optionally don't copy the tensor data (then tensor is empty on returned)
    2017             : !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
    2018             : !>                      a new range. The size should be the number of ranges plus one, the last
    2019             : !>                      element being the block index plus one of the last block in the last range.
    2020             : !>                      For internal load balancing optimizations, optionally specify the index
    2021             : !>                      ranges of batched contraction.
    2022             : !> \author Patrick Seewald
    2023             : ! **************************************************************************************************
    2024         776 :    SUBROUTINE dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
    2025             :                                nodata, pgrid_changed, unit_nr)
    2026             :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor
    2027             :       TYPE(dbt_pgrid_type), INTENT(IN)               :: pgrid
    2028             :       INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
    2029             :       !!
    2030             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    2031             :       LOGICAL, INTENT(OUT), OPTIONAL                     :: pgrid_changed
    2032             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    2033             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_change_pgrid'
    2034             :       CHARACTER(default_string_length)                   :: name
    2035             :       INTEGER                                            :: handle
    2036         776 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ${varlist("bs")}$, &
    2037         776 :                                                             ${varlist("dist")}$
    2038        1552 :       INTEGER, DIMENSION(ndims_tensor(tensor))           :: pcoord, pcoord_ref, pdims, pdims_ref, &
    2039        1552 :                                                             tdims
    2040        5432 :       TYPE(dbt_type)                                 :: t_tmp
    2041        5432 :       TYPE(dbt_distribution_type)                    :: dist
    2042        1552 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    2043             :       INTEGER, &
    2044        1552 :          DIMENSION(ndims_matrix_column(tensor))    :: map2
    2045        1552 :       LOGICAL, DIMENSION(ndims_tensor(tensor))             :: mem_aware
    2046         776 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
    2047             :       INTEGER :: ind1, ind2, batch_size, ibatch
    2048             : 
    2049         776 :       IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
    2050         776 :       CALL mp_environ_pgrid(pgrid, pdims, pcoord)
    2051         776 :       CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)
    2052             : 
    2053         800 :       IF (ALL(pdims == pdims_ref)) THEN
    2054           8 :          IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
    2055           8 :             IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
    2056             :                RETURN
    2057             :             END IF
    2058             :          END IF
    2059             :       END IF
    2060             : 
    2061         768 :       CALL timeset(routineN, handle)
    2062             : 
    2063             :       #:for idim in range(1, maxdim+1)
    2064        3072 :          IF (ndims_tensor(tensor) >= ${idim}$) THEN
    2065        2304 :             mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
    2066        2304 :             IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
    2067             :          END IF
    2068             :       #:endfor
    2069             : 
    2070         768 :       CALL dbt_get_info(tensor, nblks_total=tdims, name=name)
    2071             : 
    2072             :       #:for idim in range(1, maxdim+1)
    2073        3072 :          IF (ndims_tensor(tensor) >= ${idim}$) THEN
    2074        6912 :             ALLOCATE (bs_${idim}$ (dbt_nblks_total(tensor, ${idim}$)))
    2075        2304 :             CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
    2076        6912 :             ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
    2077       16860 :             dist_${idim}$ = 0
    2078        2304 :             IF (mem_aware(${idim}$)) THEN
    2079        6300 :                DO ibatch = 1, nbatch(${idim}$)
    2080        3996 :                   ind1 = batch_range_${idim}$ (ibatch)
    2081        3996 :                   ind2 = batch_range_${idim}$ (ibatch + 1) - 1
    2082        3996 :                   batch_size = ind2 - ind1 + 1
    2083             :                   CALL dbt_default_distvec(batch_size, pdims(${idim}$), &
    2084        6300 :                                            bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
    2085             :                END DO
    2086             :             ELSE
    2087           0 :                CALL dbt_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
    2088             :             END IF
    2089             :          END IF
    2090             :       #:endfor
    2091             : 
    2092         768 :       CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
    2093             :       #:for ndim in ndims
    2094        1536 :          IF (ndims_tensor(tensor) == ${ndim}$) THEN
    2095         768 :             CALL dbt_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
    2096         768 :             CALL dbt_create(t_tmp, name, dist, map1, map2, ${varlist("bs", nmax=ndim)}$)
    2097             :          END IF
    2098             :       #:endfor
    2099         768 :       CALL dbt_distribution_destroy(dist)
    2100             : 
    2101         768 :       IF (PRESENT(nodata)) THEN
    2102           0 :          IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
    2103             :       ELSE
    2104         768 :          CALL dbt_copy_expert(tensor, t_tmp, move_data=.TRUE.)
    2105             :       END IF
    2106             : 
    2107         768 :       CALL dbt_copy_contraction_storage(tensor, t_tmp)
    2108             : 
    2109         768 :       CALL dbt_destroy(tensor)
    2110         768 :       tensor = t_tmp
    2111             : 
    2112         768 :       IF (PRESENT(unit_nr)) THEN
    2113         768 :          IF (unit_nr > 0) THEN
    2114           0 :             WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
    2115           0 :             WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
    2116           0 :             CALL dbt_write_split_info(pgrid, unit_nr)
    2117             :          END IF
    2118             :       END IF
    2119             : 
    2120         768 :       IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.
    2121             : 
    2122         768 :       CALL timestop(handle)
    2123         776 :    END SUBROUTINE
    2124             : 
    2125             : ! **************************************************************************************************
    2126             : !> \brief map tensor to a new 2d process grid for the matrix representation.
    2127             : !> \author Patrick Seewald
    2128             : ! **************************************************************************************************
    2129         776 :    SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
    2130             :       TYPE(dbt_type), INTENT(INOUT)                  :: tensor
    2131             :       TYPE(mp_cart_type), INTENT(IN)               :: mp_comm
    2132             :       INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
    2133             :       LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
    2134             :       INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
    2135             :       LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
    2136             :       INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
    2137        1552 :       INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
    2138        1552 :       INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
    2139        1552 :       INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
    2140        2328 :       TYPE(dbt_pgrid_type) :: pgrid
    2141         776 :       INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
    2142         776 :       INTEGER, DIMENSION(:), ALLOCATABLE :: array
    2143             :       INTEGER :: idim
    2144             : 
    2145         776 :       CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
    2146         776 :       CALL blk_dims_tensor(tensor, dims)
    2147             : 
    2148         776 :       IF (ALLOCATED(tensor%contraction_storage)) THEN
    2149             :          ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
    2150        3104 :             nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
    2151             :             ! for good load balancing the process grid dimensions should be chosen adapted to the
    2152             :             ! tensor dimenions. For batched contraction the tensor dimensions should be divided by
    2153             :             ! the number of batches (number of index ranges).
    2154        3880 :             DO idim = 1, ndims_tensor(tensor)
    2155        2328 :                CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
    2156        2328 :                dims(idim) = array(nbatches(idim) + 1) - array(1)
    2157        2328 :                DEALLOCATE (array)
    2158        2328 :                dims(idim) = dims(idim)/nbatches(idim)
    2159        5432 :                IF (dims(idim) <= 0) dims(idim) = 1
    2160             :             END DO
    2161             :          END ASSOCIATE
    2162             :       END IF
    2163             : 
    2164         776 :       pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
    2165         776 :       IF (ALLOCATED(tensor%contraction_storage)) THEN
    2166             :          #:for ndim in range(1, maxdim+1)
    2167        1552 :             IF (ndims_tensor(tensor) == ${ndim}$) THEN
    2168         776 :                CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
    2169             :                CALL dbt_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
    2170         776 :                                      nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
    2171             :             END IF
    2172             :          #:endfor
    2173             :       ELSE
    2174           0 :          CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
    2175             :       END IF
    2176         776 :       CALL dbt_pgrid_destroy(pgrid)
    2177             : 
    2178         776 :    END SUBROUTINE
    2179             : 
    2180      127426 : END MODULE

Generated by: LCOV version 1.15