LCOV - code coverage report
Current view: top level - src/dbt/tas - dbt_tas_test.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 135 211 64.0 %
Date: 2024-11-21 06:45:46 Functions: 5 6 83.3 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief testing infrastructure for tall-and-skinny matrices
      10             : !> \author Patrick Seewald
      11             : ! **************************************************************************************************
      12             : MODULE dbt_tas_test
      13             :    USE dbm_api,                         ONLY: &
      14             :         dbm_add, dbm_checksum, dbm_create, dbm_distribution_new, dbm_distribution_obj, &
      15             :         dbm_distribution_release, dbm_finalize, dbm_get_col_block_sizes, dbm_get_name, &
      16             :         dbm_get_row_block_sizes, dbm_maxabs, dbm_multiply, dbm_redistribute, dbm_release, &
      17             :         dbm_scale, dbm_type
      18             :    USE dbm_tests,                       ONLY: generate_larnv_seed
      19             :    USE dbt_tas_base,                    ONLY: &
      20             :         dbt_tas_convert_to_dbm, dbt_tas_create, dbt_tas_distribution_new, dbt_tas_finalize, &
      21             :         dbt_tas_get_stored_coordinates, dbt_tas_info, dbt_tas_nblkcols_total, &
      22             :         dbt_tas_nblkrows_total, dbt_tas_put_block
      23             :    USE dbt_tas_global,                  ONLY: dbt_tas_blk_size_arb,&
      24             :                                               dbt_tas_default_distvec,&
      25             :                                               dbt_tas_dist_cyclic
      26             :    USE dbt_tas_mm,                      ONLY: dbt_tas_multiply
      27             :    USE dbt_tas_split,                   ONLY: dbt_tas_get_split_info,&
      28             :                                               dbt_tas_mp_comm
      29             :    USE dbt_tas_types,                   ONLY: dbt_tas_distribution_type,&
      30             :                                               dbt_tas_type
      31             :    USE kinds,                           ONLY: dp,&
      32             :                                               int_8
      33             :    USE message_passing,                 ONLY: mp_cart_type,&
      34             :                                               mp_comm_type
      35             : #include "../../base/base_uses.f90"
      36             : 
      37             :    IMPLICIT NONE
      38             :    PRIVATE
      39             : 
      40             :    PUBLIC :: &
      41             :       dbt_tas_benchmark_mm, &
      42             :       dbt_tas_checksum, &
      43             :       dbt_tas_random_bsizes, &
      44             :       dbt_tas_setup_test_matrix, &
      45             :       dbt_tas_test_mm, &
      46             :       dbt_tas_reset_randmat_seed
      47             : 
      48             :    INTEGER, SAVE :: randmat_counter = 0
      49             :    INTEGER, PARAMETER, PRIVATE :: rand_seed_init = 12341313
      50             : 
      51             : CONTAINS
      52             : 
      53             : ! **************************************************************************************************
      54             : !> \brief Setup tall-and-skinny matrix for testing
      55             : !> \param matrix ...
      56             : !> \param mp_comm_out ...
      57             : !> \param mp_comm ...
      58             : !> \param nrows ...
      59             : !> \param ncols ...
      60             : !> \param rbsizes ...
      61             : !> \param cbsizes ...
      62             : !> \param dist_splitsize ...
      63             : !> \param name ...
      64             : !> \param sparsity ...
      65             : !> \param reuse_comm ...
      66             : !> \author Patrick Seewald
      67             : ! **************************************************************************************************
      68          72 :    SUBROUTINE dbt_tas_setup_test_matrix(matrix, mp_comm_out, mp_comm, nrows, ncols, rbsizes, &
      69          12 :                                         cbsizes, dist_splitsize, name, sparsity, reuse_comm)
      70             :       TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix
      71             :       TYPE(mp_cart_type), INTENT(OUT)                    :: mp_comm_out
      72             : 
      73             :       CLASS(mp_comm_type), INTENT(IN)                     :: mp_comm
      74             :       INTEGER(KIND=int_8), INTENT(IN)                    :: nrows, ncols
      75             :       INTEGER, DIMENSION(nrows), INTENT(IN)              :: rbsizes
      76             :       INTEGER, DIMENSION(ncols), INTENT(IN)              :: cbsizes
      77             :       INTEGER, DIMENSION(2), INTENT(IN)                  :: dist_splitsize
      78             :       CHARACTER(len=*), INTENT(IN)                       :: name
      79             :       REAL(KIND=dp), INTENT(IN)                          :: sparsity
      80             :       LOGICAL, INTENT(IN), OPTIONAL                      :: reuse_comm
      81             : 
      82             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_setup_test_matrix'
      83             : 
      84             :       INTEGER                                            :: col_size, handle, max_col_size, max_nze, &
      85             :                                                             max_row_size, mynode, node_holds_blk, &
      86             :                                                             nze, row_size
      87             :       INTEGER(KIND=int_8)                                :: col, col_s, ncol, nrow, row, row_s
      88             :       INTEGER, DIMENSION(2)                              :: pdims
      89             :       INTEGER, DIMENSION(4)                              :: iseed, jseed
      90             :       LOGICAL                                            :: reuse_comm_prv, tr
      91          12 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: values
      92             :       REAL(KIND=dp), DIMENSION(1)                        :: rn
      93          24 :       TYPE(dbt_tas_blk_size_arb)                         :: cbsize_obj, rbsize_obj
      94             :       TYPE(dbt_tas_dist_cyclic)                          :: col_dist_obj, row_dist_obj
      95          60 :       TYPE(dbt_tas_distribution_type)                    :: dist
      96             : 
      97             :       ! we don't reserve blocks prior to putting them, so this time is meaningless and should not
      98             :       ! be considered in benchmark!
      99          12 :       CALL timeset(routineN, handle)
     100             : 
     101             :       ! Check that the counter was initialised (or has not overflowed)
     102          12 :       CPASSERT(randmat_counter .NE. 0)
     103             :       ! the counter goes into the seed. Every new call gives a new random matrix
     104          12 :       randmat_counter = randmat_counter + 1
     105             : 
     106          12 :       IF (PRESENT(reuse_comm)) THEN
     107           0 :          reuse_comm_prv = reuse_comm
     108             :       ELSE
     109             :          reuse_comm_prv = .FALSE.
     110             :       END IF
     111             : 
     112           0 :       IF (reuse_comm_prv) THEN
     113           0 :          mp_comm_out = mp_comm
     114             :       ELSE
     115          12 :          mp_comm_out = dbt_tas_mp_comm(mp_comm, nrows, ncols)
     116             :       END IF
     117             : 
     118          12 :       mynode = mp_comm_out%mepos
     119          36 :       pdims = mp_comm_out%num_pe_cart
     120             : 
     121          12 :       row_dist_obj = dbt_tas_dist_cyclic(dist_splitsize(1), pdims(1), nrows)
     122          12 :       col_dist_obj = dbt_tas_dist_cyclic(dist_splitsize(2), pdims(2), ncols)
     123             : 
     124          12 :       rbsize_obj = dbt_tas_blk_size_arb(rbsizes)
     125          12 :       cbsize_obj = dbt_tas_blk_size_arb(cbsizes)
     126             : 
     127          12 :       CALL dbt_tas_distribution_new(dist, mp_comm_out, row_dist_obj, col_dist_obj)
     128             :       CALL dbt_tas_create(matrix, name=TRIM(name), dist=dist, &
     129          12 :                           row_blk_size=rbsize_obj, col_blk_size=cbsize_obj, own_dist=.TRUE.)
     130             : 
     131         532 :       max_row_size = MAXVAL(rbsizes)
     132         532 :       max_col_size = MAXVAL(cbsizes)
     133          12 :       max_nze = max_row_size*max_col_size
     134             : 
     135          12 :       nrow = dbt_tas_nblkrows_total(matrix)
     136          12 :       ncol = dbt_tas_nblkcols_total(matrix)
     137             : 
     138          48 :       ALLOCATE (values(max_row_size, max_col_size))
     139             : 
     140          12 :       jseed = generate_larnv_seed(7, 42, 3, 42, randmat_counter)
     141             : 
     142         532 :       DO row = 1, dbt_tas_nblkrows_total(matrix)
     143       13332 :          DO col = 1, dbt_tas_nblkcols_total(matrix)
     144       12800 :             CALL dlarnv(1, jseed, 1, rn)
     145       13320 :             IF (rn(1) .LT. sparsity) THEN
     146        1306 :                tr = .FALSE.
     147        1306 :                row_s = row; col_s = col
     148        1306 :                CALL dbt_tas_get_stored_coordinates(matrix, row_s, col_s, node_holds_blk)
     149             : 
     150        1306 :                IF (node_holds_blk .EQ. mynode) THEN
     151         653 :                   row_size = rbsize_obj%data(row_s)
     152         653 :                   col_size = cbsize_obj%data(col_s)
     153         653 :                   nze = row_size*col_size
     154         653 :                   iseed = generate_larnv_seed(INT(row_s), INT(nrow), INT(col_s), INT(ncol), randmat_counter)
     155         653 :                   CALL dlarnv(1, iseed, max_nze, values)
     156         653 :                   CALL dbt_tas_put_block(matrix, row_s, col_s, values(1:row_size, 1:col_size))
     157             :                END IF
     158             :             END IF
     159             :          END DO
     160             :       END DO
     161             : 
     162          12 :       CALL dbt_tas_finalize(matrix)
     163             : 
     164          12 :       CALL timestop(handle)
     165             : 
     166          48 :    END SUBROUTINE
     167             : 
     168             : ! **************************************************************************************************
     169             : !> \brief Benchmark routine. Due to random sparsity (as opposed to structured sparsity pattern),
     170             : !>        this may not be representative for actual applications.
     171             : !> \param transa ...
     172             : !> \param transb ...
     173             : !> \param transc ...
     174             : !> \param matrix_a ...
     175             : !> \param matrix_b ...
     176             : !> \param matrix_c ...
     177             : !> \param compare_dbm ...
     178             : !> \param filter_eps ...
     179             : !> \param io_unit ...
     180             : !> \author Patrick Seewald
     181             : ! **************************************************************************************************
     182           0 :    SUBROUTINE dbt_tas_benchmark_mm(transa, transb, transc, matrix_a, matrix_b, matrix_c, compare_dbm, filter_eps, io_unit)
     183             : 
     184             :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
     185             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_a, matrix_b, matrix_c
     186             :       LOGICAL, INTENT(IN)                                :: compare_dbm
     187             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
     188             :       INTEGER, INTENT(IN), OPTIONAL                      :: io_unit
     189             : 
     190             :       INTEGER                                            :: handle1, handle2
     191           0 :       INTEGER, CONTIGUOUS, DIMENSION(:), POINTER :: cd_a, cd_b, cd_c, col_block_sizes_a, &
     192           0 :          col_block_sizes_b, col_block_sizes_c, rd_a, rd_b, rd_c, row_block_sizes_a, &
     193           0 :          row_block_sizes_b, row_block_sizes_c
     194             :       INTEGER, DIMENSION(2)                              :: npdims
     195             :       TYPE(dbm_distribution_obj)                         :: dist_a, dist_b, dist_c
     196             :       TYPE(dbm_type)                                     :: dbm_a, dbm_a_mm, dbm_b, dbm_b_mm, dbm_c, &
     197             :                                                             dbm_c_mm
     198           0 :       TYPE(mp_cart_type)                                 :: comm_dbm, mp_comm
     199             : 
     200             : !
     201             : ! TODO: Dedup with code in dbt_tas_test_mm.
     202             : !
     203           0 :       IF (PRESENT(io_unit)) THEN
     204           0 :       IF (io_unit > 0) THEN
     205           0 :          WRITE (io_unit, "(A)") "starting tall-and-skinny benchmark"
     206             :       END IF
     207             :       END IF
     208           0 :       CALL timeset("benchmark_tas_mm", handle1)
     209             :       CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a, matrix_b, &
     210             :                             0.0_dp, matrix_c, &
     211           0 :                             filter_eps=filter_eps, unit_nr=io_unit)
     212           0 :       CALL timestop(handle1)
     213           0 :       IF (PRESENT(io_unit)) THEN
     214           0 :       IF (io_unit > 0) THEN
     215           0 :          WRITE (io_unit, "(A)") "tall-and-skinny benchmark completed"
     216             :       END IF
     217             :       END IF
     218             : 
     219           0 :       IF (compare_dbm) THEN
     220           0 :          CALL dbt_tas_convert_to_dbm(matrix_a, dbm_a)
     221           0 :          CALL dbt_tas_convert_to_dbm(matrix_b, dbm_b)
     222           0 :          CALL dbt_tas_convert_to_dbm(matrix_c, dbm_c)
     223             : 
     224           0 :          CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
     225           0 :          npdims(:) = 0
     226           0 :          CALL comm_dbm%create(mp_comm, 2, npdims)
     227             : 
     228           0 :          ALLOCATE (rd_a(SIZE(dbm_get_row_block_sizes(dbm_a))))
     229           0 :          ALLOCATE (rd_b(SIZE(dbm_get_row_block_sizes(dbm_b))))
     230           0 :          ALLOCATE (rd_c(SIZE(dbm_get_row_block_sizes(dbm_c))))
     231           0 :          ALLOCATE (cd_a(SIZE(dbm_get_col_block_sizes(dbm_a))))
     232           0 :          ALLOCATE (cd_b(SIZE(dbm_get_col_block_sizes(dbm_b))))
     233           0 :          ALLOCATE (cd_c(SIZE(dbm_get_col_block_sizes(dbm_c))))
     234             : 
     235             :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_a))), &
     236           0 :                                       npdims(1), dbm_get_row_block_sizes(dbm_a), rd_a)
     237             :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_a))), &
     238           0 :                                       npdims(2), dbm_get_col_block_sizes(dbm_a), cd_a)
     239             :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_b))), &
     240           0 :                                       npdims(1), dbm_get_row_block_sizes(dbm_b), rd_b)
     241             :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_b))), &
     242           0 :                                       npdims(2), dbm_get_col_block_sizes(dbm_b), cd_b)
     243             :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_c))), &
     244           0 :                                       npdims(1), dbm_get_row_block_sizes(dbm_c), rd_c)
     245             :          CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_c))), &
     246           0 :                                       npdims(2), dbm_get_col_block_sizes(dbm_c), cd_c)
     247             : 
     248           0 :          CALL dbm_distribution_new(dist_a, comm_dbm, rd_a, cd_a)
     249           0 :          CALL dbm_distribution_new(dist_b, comm_dbm, rd_b, cd_b)
     250           0 :          CALL dbm_distribution_new(dist_c, comm_dbm, rd_c, cd_c)
     251           0 :          DEALLOCATE (rd_a, rd_b, rd_c, cd_a, cd_b, cd_c)
     252             : 
     253             :          ! Store pointers in intermediate variables to workaround a CCE error.
     254           0 :          row_block_sizes_a => dbm_get_row_block_sizes(dbm_a)
     255           0 :          col_block_sizes_a => dbm_get_col_block_sizes(dbm_a)
     256           0 :          row_block_sizes_b => dbm_get_row_block_sizes(dbm_b)
     257           0 :          col_block_sizes_b => dbm_get_col_block_sizes(dbm_b)
     258           0 :          row_block_sizes_c => dbm_get_row_block_sizes(dbm_c)
     259           0 :          col_block_sizes_c => dbm_get_col_block_sizes(dbm_c)
     260             : 
     261             :          CALL dbm_create(matrix=dbm_a_mm, name=dbm_get_name(dbm_a), dist=dist_a, &
     262           0 :                          row_block_sizes=row_block_sizes_a, col_block_sizes=col_block_sizes_a)
     263             : 
     264             :          CALL dbm_create(matrix=dbm_b_mm, name=dbm_get_name(dbm_b), dist=dist_b, &
     265           0 :                          row_block_sizes=row_block_sizes_b, col_block_sizes=col_block_sizes_b)
     266             : 
     267             :          CALL dbm_create(matrix=dbm_c_mm, name=dbm_get_name(dbm_c), dist=dist_c, &
     268           0 :                          row_block_sizes=row_block_sizes_c, col_block_sizes=col_block_sizes_c)
     269             : 
     270           0 :          CALL dbm_finalize(dbm_a_mm)
     271           0 :          CALL dbm_finalize(dbm_b_mm)
     272           0 :          CALL dbm_finalize(dbm_c_mm)
     273             : 
     274           0 :          CALL dbm_redistribute(dbm_a, dbm_a_mm)
     275           0 :          CALL dbm_redistribute(dbm_b, dbm_b_mm)
     276           0 :          IF (PRESENT(io_unit)) THEN
     277           0 :          IF (io_unit > 0) THEN
     278           0 :             WRITE (io_unit, "(A)") "starting dbm benchmark"
     279             :          END IF
     280             :          END IF
     281           0 :          CALL timeset("benchmark_block_mm", handle2)
     282             :          CALL dbm_multiply(transa, transb, 1.0_dp, dbm_a_mm, dbm_b_mm, &
     283           0 :                            0.0_dp, dbm_c_mm, filter_eps=filter_eps)
     284           0 :          CALL timestop(handle2)
     285           0 :          IF (PRESENT(io_unit)) THEN
     286           0 :          IF (io_unit > 0) THEN
     287           0 :             WRITE (io_unit, "(A)") "dbm benchmark completed"
     288             :          END IF
     289             :          END IF
     290             : 
     291           0 :          CALL dbm_release(dbm_a)
     292           0 :          CALL dbm_release(dbm_b)
     293           0 :          CALL dbm_release(dbm_c)
     294           0 :          CALL dbm_release(dbm_a_mm)
     295           0 :          CALL dbm_release(dbm_b_mm)
     296           0 :          CALL dbm_release(dbm_c_mm)
     297           0 :          CALL dbm_distribution_release(dist_a)
     298           0 :          CALL dbm_distribution_release(dist_b)
     299           0 :          CALL dbm_distribution_release(dist_c)
     300             : 
     301           0 :          CALL comm_dbm%free()
     302             :       END IF
     303             : 
     304           0 :    END SUBROUTINE
     305             : 
     306             : ! **************************************************************************************************
     307             : !> \brief Test tall-and-skinny matrix multiplication for accuracy
     308             : !> \param transa ...
     309             : !> \param transb ...
     310             : !> \param transc ...
     311             : !> \param matrix_a ...
     312             : !> \param matrix_b ...
     313             : !> \param matrix_c ...
     314             : !> \param filter_eps ...
     315             : !> \param unit_nr ...
     316             : !> \param log_verbose ...
     317             : !> \author Patrick Seewald
     318             : ! **************************************************************************************************
     319          48 :    SUBROUTINE dbt_tas_test_mm(transa, transb, transc, matrix_a, matrix_b, matrix_c, filter_eps, unit_nr, log_verbose)
     320             :       LOGICAL, INTENT(IN)                                :: transa, transb, transc
     321             :       TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_a, matrix_b, matrix_c
     322             :       REAL(KIND=dp), INTENT(IN), OPTIONAL                :: filter_eps
     323             :       INTEGER, INTENT(IN)                                :: unit_nr
     324             :       LOGICAL, INTENT(IN), OPTIONAL                      :: log_verbose
     325             : 
     326             :       REAL(KIND=dp), PARAMETER                           :: test_tol = 1.0E-10_dp
     327             : 
     328             :       CHARACTER(LEN=8)                                   :: status_str
     329             :       INTEGER                                            :: io_unit, mynode
     330          48 :       INTEGER, CONTIGUOUS, DIMENSION(:), POINTER :: cd_a, cd_b, cd_c, col_block_sizes_a, &
     331          48 :          col_block_sizes_b, col_block_sizes_c, rd_a, rd_b, rd_c, row_block_sizes_a, &
     332          48 :          row_block_sizes_b, row_block_sizes_c
     333             :       INTEGER, DIMENSION(2)                              :: npdims
     334             :       LOGICAL                                            :: abort, transa_prv, transb_prv
     335             :       REAL(KIND=dp)                                      :: norm, rc_cs, sq_cs
     336             :       TYPE(dbm_distribution_obj)                         :: dist_a, dist_b, dist_c
     337             :       TYPE(dbm_type)                                     :: dbm_a, dbm_a_mm, dbm_b, dbm_b_mm, dbm_c, &
     338             :                                                             dbm_c_mm, dbm_c_mm_check
     339          48 :       TYPE(mp_cart_type)                                 :: comm_dbm, mp_comm
     340             : 
     341             : !
     342             : ! TODO: Dedup with code in dbt_tas_benchmark_mm.
     343             : !
     344             : 
     345          48 :       CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
     346          48 :       mynode = mp_comm%mepos
     347          48 :       abort = .FALSE.
     348          48 :       io_unit = -1
     349          48 :       IF (mynode .EQ. 0) io_unit = unit_nr
     350             : 
     351             :       CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a, matrix_b, &
     352             :                             0.0_dp, matrix_c, &
     353          48 :                             filter_eps=filter_eps, unit_nr=io_unit, log_verbose=log_verbose, optimize_dist=.TRUE.)
     354             : 
     355          48 :       CALL dbt_tas_convert_to_dbm(matrix_a, dbm_a)
     356          48 :       CALL dbt_tas_convert_to_dbm(matrix_b, dbm_b)
     357          48 :       CALL dbt_tas_convert_to_dbm(matrix_c, dbm_c)
     358             : 
     359          48 :       npdims(:) = 0
     360          48 :       CALL comm_dbm%create(mp_comm, 2, npdims)
     361             : 
     362         144 :       ALLOCATE (rd_a(SIZE(dbm_get_row_block_sizes(dbm_a))))
     363         144 :       ALLOCATE (rd_b(SIZE(dbm_get_row_block_sizes(dbm_b))))
     364         144 :       ALLOCATE (rd_c(SIZE(dbm_get_row_block_sizes(dbm_c))))
     365         144 :       ALLOCATE (cd_a(SIZE(dbm_get_col_block_sizes(dbm_a))))
     366         144 :       ALLOCATE (cd_b(SIZE(dbm_get_col_block_sizes(dbm_b))))
     367         144 :       ALLOCATE (cd_c(SIZE(dbm_get_col_block_sizes(dbm_c))))
     368             : 
     369             :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_a))), &
     370          48 :                                    npdims(1), dbm_get_row_block_sizes(dbm_a), rd_a)
     371             :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_a))), &
     372          48 :                                    npdims(2), dbm_get_col_block_sizes(dbm_a), cd_a)
     373             :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_b))), &
     374          48 :                                    npdims(1), dbm_get_row_block_sizes(dbm_b), rd_b)
     375             :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_b))), &
     376          48 :                                    npdims(2), dbm_get_col_block_sizes(dbm_b), cd_b)
     377             :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_row_block_sizes(dbm_c))), &
     378          48 :                                    npdims(1), dbm_get_row_block_sizes(dbm_c), rd_c)
     379             :       CALL dbt_tas_default_distvec(INT(SIZE(dbm_get_col_block_sizes(dbm_c))), &
     380          48 :                                    npdims(2), dbm_get_col_block_sizes(dbm_c), cd_c)
     381             : 
     382          48 :       CALL dbm_distribution_new(dist_a, comm_dbm, rd_a, cd_a)
     383          48 :       CALL dbm_distribution_new(dist_b, comm_dbm, rd_b, cd_b)
     384          48 :       CALL dbm_distribution_new(dist_c, comm_dbm, rd_c, cd_c)
     385          48 :       DEALLOCATE (rd_a, rd_b, rd_c, cd_a, cd_b, cd_c)
     386             : 
     387             :       ! Store pointers in intermediate variables to workaround a CCE error.
     388          48 :       row_block_sizes_a => dbm_get_row_block_sizes(dbm_a)
     389          48 :       col_block_sizes_a => dbm_get_col_block_sizes(dbm_a)
     390          48 :       row_block_sizes_b => dbm_get_row_block_sizes(dbm_b)
     391          48 :       col_block_sizes_b => dbm_get_col_block_sizes(dbm_b)
     392          48 :       row_block_sizes_c => dbm_get_row_block_sizes(dbm_c)
     393          48 :       col_block_sizes_c => dbm_get_col_block_sizes(dbm_c)
     394             : 
     395             :       CALL dbm_create(matrix=dbm_a_mm, name="matrix a", dist=dist_a, &
     396          48 :                       row_block_sizes=row_block_sizes_a, col_block_sizes=col_block_sizes_a)
     397             : 
     398             :       CALL dbm_create(matrix=dbm_b_mm, name="matrix b", dist=dist_b, &
     399          48 :                       row_block_sizes=row_block_sizes_b, col_block_sizes=col_block_sizes_b)
     400             : 
     401             :       CALL dbm_create(matrix=dbm_c_mm, name="matrix c", dist=dist_c, &
     402          48 :                       row_block_sizes=row_block_sizes_c, col_block_sizes=col_block_sizes_c)
     403             : 
     404             :       CALL dbm_create(matrix=dbm_c_mm_check, name="matrix c check", dist=dist_c, &
     405          48 :                       row_block_sizes=row_block_sizes_c, col_block_sizes=col_block_sizes_c)
     406             : 
     407          48 :       CALL dbm_finalize(dbm_a_mm)
     408          48 :       CALL dbm_finalize(dbm_b_mm)
     409          48 :       CALL dbm_finalize(dbm_c_mm)
     410          48 :       CALL dbm_finalize(dbm_c_mm_check)
     411             : 
     412          48 :       CALL dbm_redistribute(dbm_a, dbm_a_mm)
     413          48 :       CALL dbm_redistribute(dbm_b, dbm_b_mm)
     414          48 :       CALL dbm_redistribute(dbm_c, dbm_c_mm_check)
     415             : 
     416          48 :       transa_prv = transa; transb_prv = transb
     417             : 
     418          48 :       IF (.NOT. transc) THEN
     419             :          CALL dbm_multiply(transa_prv, transb_prv, 1.0_dp, &
     420             :                            dbm_a_mm, dbm_b_mm, &
     421          24 :                            0.0_dp, dbm_c_mm, filter_eps=filter_eps)
     422             :       ELSE
     423          24 :          transa_prv = .NOT. transa_prv
     424          24 :          transb_prv = .NOT. transb_prv
     425             :          CALL dbm_multiply(transb_prv, transa_prv, 1.0_dp, &
     426             :                            dbm_b_mm, dbm_a_mm, &
     427          24 :                            0.0_dp, dbm_c_mm, filter_eps=filter_eps)
     428             :       END IF
     429             : 
     430          48 :       sq_cs = dbm_checksum(dbm_c_mm)
     431          48 :       rc_cs = dbm_checksum(dbm_c_mm_check)
     432          48 :       CALL dbm_scale(dbm_c_mm_check, -1.0_dp)
     433          48 :       CALL dbm_add(dbm_c_mm_check, dbm_c_mm)
     434          48 :       norm = dbm_maxabs(dbm_c_mm_check)
     435             : 
     436          48 :       IF (io_unit > 0) THEN
     437          24 :          IF (ABS(norm) > test_tol) THEN
     438           0 :             status_str = " failed!"
     439           0 :             abort = .TRUE.
     440             :          ELSE
     441          24 :             status_str = " passed!"
     442          24 :             abort = .FALSE.
     443             :          END IF
     444             :          WRITE (io_unit, "(A)") &
     445             :             TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
     446          24 :             TRIM(dbm_get_name(matrix_b%matrix))//TRIM(status_str)
     447          24 :          WRITE (io_unit, "(A,1X,E9.2,1X,E9.2)") "checksums", sq_cs, rc_cs
     448          24 :          WRITE (io_unit, "(A,1X,E9.2)") "difference norm", norm
     449          24 :          IF (abort) CPABORT("DBT TAS test failed")
     450             :       END IF
     451             : 
     452          48 :       CALL dbm_release(dbm_a)
     453          48 :       CALL dbm_release(dbm_a_mm)
     454          48 :       CALL dbm_release(dbm_b)
     455          48 :       CALL dbm_release(dbm_b_mm)
     456          48 :       CALL dbm_release(dbm_c)
     457          48 :       CALL dbm_release(dbm_c_mm)
     458          48 :       CALL dbm_release(dbm_c_mm_check)
     459             : 
     460          48 :       CALL dbm_distribution_release(dist_a)
     461          48 :       CALL dbm_distribution_release(dist_b)
     462          48 :       CALL dbm_distribution_release(dist_c)
     463             : 
     464          48 :       CALL comm_dbm%free()
     465             : 
     466         432 :    END SUBROUTINE dbt_tas_test_mm
     467             : 
     468             : ! **************************************************************************************************
     469             : !> \brief Calculate checksum of tall-and-skinny matrix consistent with dbm_checksum
     470             : !> \param matrix ...
     471             : !> \return ...
     472             : !> \author Patrick Seewald
     473             : ! **************************************************************************************************
     474          80 :    FUNCTION dbt_tas_checksum(matrix)
     475             :       TYPE(dbt_tas_type), INTENT(IN)                     :: matrix
     476             :       REAL(KIND=dp)                                      :: dbt_tas_checksum
     477             : 
     478             :       TYPE(dbm_type)                                     :: dbm_m
     479             : 
     480          80 :       CALL dbt_tas_convert_to_dbm(matrix, dbm_m)
     481          80 :       dbt_tas_checksum = dbm_checksum(dbm_m)
     482          80 :       CALL dbm_release(dbm_m)
     483          80 :    END FUNCTION
     484             : 
     485             : ! **************************************************************************************************
     486             : !> \brief Create random block sizes
     487             : !> \param sizes ...
     488             : !> \param repeat ...
     489             : !> \param dbt_sizes ...
     490             : !> \author Patrick Seewald
     491             : ! **************************************************************************************************
     492           6 :    SUBROUTINE dbt_tas_random_bsizes(sizes, repeat, dbt_sizes)
     493             :       INTEGER, DIMENSION(:), INTENT(IN)                  :: sizes
     494             :       INTEGER, INTENT(IN)                                :: repeat
     495             :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: dbt_sizes
     496             : 
     497             :       INTEGER                                            :: d, size_i
     498             : 
     499         266 :       DO d = 1, SIZE(dbt_sizes)
     500         260 :          size_i = MOD((d - 1)/repeat, SIZE(sizes)) + 1
     501         266 :          dbt_sizes(d) = sizes(size_i)
     502             :       END DO
     503           6 :    END SUBROUTINE
     504             : 
     505             : ! **************************************************************************************************
     506             : !> \brief Reset the seed used for generating random matrices to default value
     507             : !> \author Patrick Seewald
     508             : ! **************************************************************************************************
     509           2 :    SUBROUTINE dbt_tas_reset_randmat_seed()
     510           2 :       randmat_counter = rand_seed_init
     511           2 :    END SUBROUTINE
     512             : 
     513             : END MODULE

Generated by: LCOV version 1.15