LCOV - code coverage report
Current view: top level - src/dbt/tas - dbt_tas_unittest.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 99 100 99.0 %
Date: 2024-11-21 06:45:46 Functions: 2 2 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       ! **************************************************************************************************
!> \brief Unit testing for tall-and-skinny matrices
!> \details Generates random block-sparse matrices A (m x k), B (n x m) and C (k x n), plus separately
!>          generated matrices of the transposed shapes At (k x m), Bt (m x n) and Ct (n x k).
!>          It then prints their split info and calls dbt_tas_test_mm on the products B*A, A*C and C*B
!>          for all eight combinations of the three transpose flags.
!> \author Patrick Seewald
! **************************************************************************************************
PROGRAM dbt_tas_unittest
   USE cp_dbcsr_api,                    ONLY: dbcsr_finalize_lib,&
                                              dbcsr_init_lib
   USE dbm_api,                         ONLY: dbm_get_name,&
                                              dbm_library_finalize,&
                                              dbm_library_init,&
                                              dbm_library_print_stats
   USE dbt_tas_base,                    ONLY: dbt_tas_create,&
                                              dbt_tas_destroy,&
                                              dbt_tas_info,&
                                              dbt_tas_nblkcols_total,&
                                              dbt_tas_nblkrows_total
   USE dbt_tas_io,                      ONLY: dbt_tas_write_split_info
   USE dbt_tas_test,                    ONLY: dbt_tas_random_bsizes,&
                                              dbt_tas_reset_randmat_seed,&
                                              dbt_tas_setup_test_matrix,&
                                              dbt_tas_test_mm
   USE dbt_tas_types,                   ONLY: dbt_tas_type
   USE kinds,                           ONLY: dp,&
                                              int_8
   USE machine,                         ONLY: default_output_unit
   USE message_passing,                 ONLY: mp_cart_type,&
                                              mp_comm_type,&
                                              mp_world_finalize,&
                                              mp_world_init
   USE offload_api,                     ONLY: offload_get_device_count,&
                                              offload_set_chosen_device
#include "../../base/base_uses.f90"

   IMPLICIT NONE

   INTEGER(KIND=int_8), PARAMETER :: m = 100, k = 20, n = 10
   ! The _dp suffix matters: an unsuffixed literal such as 0.1 is a default (single-precision) real
   ! and would be rounded to single precision before being stored in these REAL(dp) constants.
   REAL(KIND=dp), PARAMETER       :: sparsity = 0.1_dp, &
                                     filter_eps = 1.0E-08_dp
   TYPE(dbt_tas_type)             :: A, B, C, At, Bt, Ct, A_out, B_out, C_out, At_out, Bt_out, Ct_out
   INTEGER, DIMENSION(m)          :: bsize_m
   INTEGER, DIMENSION(n)          :: bsize_n
   INTEGER, DIMENSION(k)          :: bsize_k
   INTEGER                        :: mynode, io_unit
   TYPE(mp_comm_type)             :: mp_comm
   TYPE(mp_cart_type)             :: mp_comm_A, mp_comm_At, mp_comm_B, mp_comm_Bt, mp_comm_C, mp_comm_Ct

   CALL mp_world_init(mp_comm)

   mynode = mp_comm%mepos

   ! Select active offload device when available.
   IF (offload_get_device_count() > 0) THEN
      CALL offload_set_chosen_device(MOD(mynode, offload_get_device_count()))
   END IF

   ! Only rank 0 gets a real output unit; every other rank passes -1.
   io_unit = -1
   IF (mynode .EQ. 0) io_unit = default_output_unit

   CALL dbcsr_init_lib(mp_comm%get_handle(), io_unit) ! Needed for DBM_VALIDATE_AGAINST_DBCSR.
   CALL dbm_library_init()

   CALL dbt_tas_reset_randmat_seed()

   CALL dbt_tas_random_bsizes([13, 8, 5, 25, 12], 2, bsize_m)
   CALL dbt_tas_random_bsizes([3, 78, 33, 12, 3, 15], 1, bsize_n)
   CALL dbt_tas_random_bsizes([9, 64, 23, 2], 3, bsize_k)

   CALL dbt_tas_setup_test_matrix(A, mp_comm_A, mp_comm, m, k, bsize_m, bsize_k, [5, 1], "A", sparsity)
   CALL dbt_tas_setup_test_matrix(At, mp_comm_At, mp_comm, k, m, bsize_k, bsize_m, [3, 8], "A^t", sparsity)
   CALL dbt_tas_setup_test_matrix(B, mp_comm_B, mp_comm, n, m, bsize_n, bsize_m, [3, 2], "B", sparsity)
   CALL dbt_tas_setup_test_matrix(Bt, mp_comm_Bt, mp_comm, m, n, bsize_m, bsize_n, [1, 3], "B^t", sparsity)
   CALL dbt_tas_setup_test_matrix(C, mp_comm_C, mp_comm, k, n, bsize_k, bsize_n, [5, 7], "C", sparsity)
   CALL dbt_tas_setup_test_matrix(Ct, mp_comm_Ct, mp_comm, n, k, bsize_n, bsize_k, [1, 1], "C^t", sparsity)

   ! Result matrices, each created via dbt_tas_create from the input of the same name.
   CALL dbt_tas_create(A, A_out)
   CALL dbt_tas_create(At, At_out)
   CALL dbt_tas_create(B, B_out)
   CALL dbt_tas_create(Bt, Bt_out)
   CALL dbt_tas_create(C, C_out)
   CALL dbt_tas_create(Ct, Ct_out)

   IF (mynode == 0) WRITE (io_unit, '(A)') "DBM TALL-AND-SKINNY MATRICES"
   CALL write_matrix_info(A, "A")
   CALL write_matrix_info(At, "At")
   CALL write_matrix_info(B, "B")
   CALL write_matrix_info(Bt, "Bt")
   CALL write_matrix_info(C, "C")
   CALL write_matrix_info(Ct, "Ct")

   ! Each product is tested with the un-transposed result (n x k, m x n, k x m) for transpose_c = .FALSE.
   ! and with the matrix of the transposed shape for transpose_c = .TRUE.
   CALL test_mm_all_transpositions(B, Bt, A, At, Ct_out, C_out) ! B*A: (n x m)(m x k) -> n x k
   CALL test_mm_all_transpositions(A, At, C, Ct, Bt_out, B_out) ! A*C: (m x k)(k x n) -> m x n
   CALL test_mm_all_transpositions(C, Ct, B, Bt, At_out, A_out) ! C*B: (k x n)(n x m) -> k x m

   CALL dbt_tas_destroy(A)
   CALL dbt_tas_destroy(At)
   CALL dbt_tas_destroy(B)
   CALL dbt_tas_destroy(Bt)
   CALL dbt_tas_destroy(C)
   CALL dbt_tas_destroy(Ct)
   CALL dbt_tas_destroy(A_out)
   CALL dbt_tas_destroy(At_out)
   CALL dbt_tas_destroy(B_out)
   CALL dbt_tas_destroy(Bt_out)
   CALL dbt_tas_destroy(C_out)
   CALL dbt_tas_destroy(Ct_out)

   CALL mp_comm_A%free()
   CALL mp_comm_At%free()
   CALL mp_comm_B%free()
   CALL mp_comm_Bt%free()
   CALL mp_comm_C%free()
   CALL mp_comm_Ct%free()

   CALL dbm_library_print_stats(mp_comm, io_unit)
   CALL dbm_library_finalize()
   CALL dbcsr_finalize_lib() ! Needed for DBM_VALIDATE_AGAINST_DBCSR.
   CALL mp_world_finalize()

CONTAINS

! **************************************************************************************************
!> \brief Print the global block dimensions and the split info of one tall-and-skinny matrix.
!> \details On rank 0, writes the matrix name and its total number of block rows and block columns.
!>          It then calls dbt_tas_write_split_info on every rank with the program's io_unit.
!> \param matrix matrix to describe
!> \param label name passed to dbt_tas_write_split_info
! **************************************************************************************************
   SUBROUTINE write_matrix_info(matrix, label)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      CHARACTER(LEN=*), INTENT(IN)                       :: label

      IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
         TRIM(dbm_get_name(matrix%matrix)), &
         dbt_tas_nblkrows_total(matrix), 'X', dbt_tas_nblkcols_total(matrix)
      CALL dbt_tas_write_split_info(dbt_tas_info(matrix), io_unit, name=label)
   END SUBROUTINE write_matrix_info

! **************************************************************************************************
!> \brief Call dbt_tas_test_mm for all eight combinations of the transpose flags of one product.
!> \details The flag order is transpose_a, transpose_b, transpose_c. When a flag is .TRUE., the
!>          matrix of the transposed shape is passed for that operand. All four transpose_c = .FALSE.
!>          calls run first, then the four transpose_c = .TRUE. calls.
!> \param left first factor, passed when transpose_a is .FALSE.
!> \param left_t matrix of the transposed shape of left, passed when transpose_a is .TRUE.
!> \param right second factor, passed when transpose_b is .FALSE.
!> \param right_t matrix of the transposed shape of right, passed when transpose_b is .TRUE.
!> \param res result matrix passed when transpose_c is .FALSE.
!> \param res_t result matrix of the transposed shape, passed when transpose_c is .TRUE.
! **************************************************************************************************
   SUBROUTINE test_mm_all_transpositions(left, left_t, right, right_t, res, res_t)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: left, left_t, right, right_t, res, res_t

      CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., left, right, res, unit_nr=io_unit, filter_eps=filter_eps)
      CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., left_t, right, res, unit_nr=io_unit, filter_eps=filter_eps)
      CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., left, right_t, res, unit_nr=io_unit, filter_eps=filter_eps)
      CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., left_t, right_t, res, unit_nr=io_unit, filter_eps=filter_eps)

      CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., left, right, res_t, unit_nr=io_unit, filter_eps=filter_eps)
      CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., left_t, right, res_t, unit_nr=io_unit, filter_eps=filter_eps)
      CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., left, right_t, res_t, unit_nr=io_unit, filter_eps=filter_eps)
      CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., left_t, right_t, res_t, unit_nr=io_unit, filter_eps=filter_eps)
   END SUBROUTINE test_mm_all_transpositions

END PROGRAM

Generated by: LCOV version 1.15