Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Unit testing for tall-and-skinny matrices
10 : !> \author Patrick Seewald
11 : ! **************************************************************************************************
PROGRAM dbt_tas_unittest
   USE cp_dbcsr_api, ONLY: dbcsr_finalize_lib, &
                           dbcsr_init_lib
   USE dbm_api, ONLY: dbm_get_name, &
                      dbm_library_finalize, &
                      dbm_library_init, &
                      dbm_library_print_stats
   USE dbt_tas_base, ONLY: dbt_tas_create, &
                           dbt_tas_destroy, &
                           dbt_tas_info, &
                           dbt_tas_nblkcols_total, &
                           dbt_tas_nblkrows_total
   USE dbt_tas_io, ONLY: dbt_tas_write_split_info
   USE dbt_tas_test, ONLY: dbt_tas_random_bsizes, &
                           dbt_tas_reset_randmat_seed, &
                           dbt_tas_setup_test_matrix, &
                           dbt_tas_test_mm
   USE dbt_tas_types, ONLY: dbt_tas_type
   USE kinds, ONLY: dp, &
                    int_8
   USE machine, ONLY: default_output_unit
   USE message_passing, ONLY: mp_cart_type, &
                              mp_comm_type, &
                              mp_world_finalize, &
                              mp_world_init
   USE offload_api, ONLY: offload_get_device_count, &
                          offload_set_chosen_device
#include "../../base/base_uses.f90"

   IMPLICIT NONE

   ! Numbers of blocks along the three dimensions used by the test matrices
   ! (they also size the block-size arrays below).
   INTEGER(KIND=int_8), PARAMETER :: m = 100, k = 20, n = 10
   ! Test operands (A, B, C and the variants set up with swapped dimensions) and the
   ! corresponding *_out matrices that receive the products.
   TYPE(dbt_tas_type) :: A, B, C, At, Bt, Ct, A_out, B_out, C_out, At_out, Bt_out, Ct_out
   INTEGER, DIMENSION(m) :: bsize_m
   INTEGER, DIMENSION(n) :: bsize_n
   INTEGER, DIMENSION(k) :: bsize_k
   ! Literals carry the _dp kind so that the constants are exact to double precision
   ! (an unsuffixed 0.1 is a default-real literal and is only accurate to ~7 digits).
   REAL(KIND=dp), PARAMETER :: sparsity = 0.1_dp
   INTEGER :: mynode, io_unit
   TYPE(mp_comm_type) :: mp_comm
   TYPE(mp_cart_type) :: mp_comm_A, mp_comm_At, mp_comm_B, mp_comm_Bt, mp_comm_C, mp_comm_Ct
   REAL(KIND=dp), PARAMETER :: filter_eps = 1.0E-08_dp

   CALL mp_world_init(mp_comm)

   mynode = mp_comm%mepos

   ! Select active offload device when available.
   IF (offload_get_device_count() > 0) THEN
      CALL offload_set_chosen_device(MOD(mynode, offload_get_device_count()))
   END IF

   ! Only rank 0 gets a real output unit; all other ranks pass -1 downstream.
   io_unit = -1
   IF (mynode == 0) io_unit = default_output_unit

   CALL dbcsr_init_lib(mp_comm%get_handle(), io_unit) ! Needed for DBM_VALIDATE_AGAINST_DBCSR.
   CALL dbm_library_init()

   CALL dbt_tas_reset_randmat_seed()

   ! Fill the block-size arrays (presumably drawn from the listed sizes; see
   ! dbt_tas_random_bsizes for the meaning of the second argument).
   CALL dbt_tas_random_bsizes([13, 8, 5, 25, 12], 2, bsize_m)
   CALL dbt_tas_random_bsizes([3, 78, 33, 12, 3, 15], 1, bsize_n)
   CALL dbt_tas_random_bsizes([9, 64, 23, 2], 3, bsize_k)

   ! Each matrix gets its own cartesian communicator, which is freed explicitly below.
   CALL dbt_tas_setup_test_matrix(A, mp_comm_A, mp_comm, m, k, bsize_m, bsize_k, [5, 1], "A", sparsity)
   CALL dbt_tas_setup_test_matrix(At, mp_comm_At, mp_comm, k, m, bsize_k, bsize_m, [3, 8], "A^t", sparsity)
   CALL dbt_tas_setup_test_matrix(B, mp_comm_B, mp_comm, n, m, bsize_n, bsize_m, [3, 2], "B", sparsity)
   CALL dbt_tas_setup_test_matrix(Bt, mp_comm_Bt, mp_comm, m, n, bsize_m, bsize_n, [1, 3], "B^t", sparsity)
   CALL dbt_tas_setup_test_matrix(C, mp_comm_C, mp_comm, k, n, bsize_k, bsize_n, [5, 7], "C", sparsity)
   CALL dbt_tas_setup_test_matrix(Ct, mp_comm_Ct, mp_comm, n, k, bsize_n, bsize_k, [1, 1], "C^t", sparsity)

   ! Result matrices, created from the corresponding operand.
   CALL dbt_tas_create(A, A_out)
   CALL dbt_tas_create(At, At_out)
   CALL dbt_tas_create(B, B_out)
   CALL dbt_tas_create(Bt, Bt_out)
   CALL dbt_tas_create(C, C_out)
   CALL dbt_tas_create(Ct, Ct_out)

   IF (mynode == 0) WRITE (io_unit, '(A)') "DBM TALL-AND-SKINNY MATRICES"
   CALL print_split_info(A, "A")
   CALL print_split_info(At, "At")
   CALL print_split_info(B, "B")
   CALL print_split_info(Bt, "Bt")
   CALL print_split_info(C, "C")
   CALL print_split_info(Ct, "Ct")

   ! Products of the B- and A-type operands, for every combination of the three
   ! logical flags; results go to Ct_out / C_out.
   CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., B, A, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., Bt, A, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., B, At, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., Bt, At, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., B, A, C_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., Bt, A, C_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., B, At, C_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., Bt, At, C_out, unit_nr=io_unit, filter_eps=filter_eps)

   ! Products of the A- and C-type operands; results go to Bt_out / B_out.
   CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., A, C, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., At, C, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., A, Ct, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., At, Ct, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)

   CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., A, C, B_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., At, C, B_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., A, Ct, B_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., At, Ct, B_out, unit_nr=io_unit, filter_eps=filter_eps)

   ! Products of the C- and B-type operands; results go to At_out / A_out.
   CALL dbt_tas_test_mm(.FALSE., .FALSE., .FALSE., C, B, At_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .FALSE., .FALSE., Ct, B, At_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .TRUE., .FALSE., C, Bt, At_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .TRUE., .FALSE., Ct, Bt, At_out, unit_nr=io_unit, filter_eps=filter_eps)

   CALL dbt_tas_test_mm(.FALSE., .FALSE., .TRUE., C, B, A_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .FALSE., .TRUE., Ct, B, A_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.FALSE., .TRUE., .TRUE., C, Bt, A_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbt_tas_test_mm(.TRUE., .TRUE., .TRUE., Ct, Bt, A_out, unit_nr=io_unit, filter_eps=filter_eps)

   ! Release all matrices before freeing the communicators that were set up with them.
   CALL dbt_tas_destroy(A)
   CALL dbt_tas_destroy(At)
   CALL dbt_tas_destroy(B)
   CALL dbt_tas_destroy(Bt)
   CALL dbt_tas_destroy(C)
   CALL dbt_tas_destroy(Ct)
   CALL dbt_tas_destroy(A_out)
   CALL dbt_tas_destroy(At_out)
   CALL dbt_tas_destroy(B_out)
   CALL dbt_tas_destroy(Bt_out)
   CALL dbt_tas_destroy(C_out)
   CALL dbt_tas_destroy(Ct_out)

   CALL mp_comm_A%free()
   CALL mp_comm_At%free()
   CALL mp_comm_B%free()
   CALL mp_comm_Bt%free()
   CALL mp_comm_C%free()
   CALL mp_comm_Ct%free()

   ! Shut the libraries down in the reverse order of their initialisation.
   CALL dbm_library_print_stats(mp_comm, io_unit)
   CALL dbm_library_finalize()
   CALL dbcsr_finalize_lib() ! Needed for DBM_VALIDATE_AGAINST_DBCSR.
   CALL mp_world_finalize()

CONTAINS

! **************************************************************************************************
!> \brief Print the header line (matrix name and total block rows x block columns) on rank 0,
!>        then write the split info of the matrix.
!> \param tas_matrix tall-and-skinny matrix to report on
!> \param label name handed to dbt_tas_write_split_info
!> \note Uses mynode and io_unit of the host program. dbt_tas_write_split_info is called on
!>       every rank, only the header WRITE is restricted to rank 0.
! **************************************************************************************************
   SUBROUTINE print_split_info(tas_matrix, label)
      TYPE(dbt_tas_type), INTENT(IN) :: tas_matrix
      CHARACTER(LEN=*), INTENT(IN) :: label

      IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
         TRIM(dbm_get_name(tas_matrix%matrix)), &
         dbt_tas_nblkrows_total(tas_matrix), 'X', dbt_tas_nblkcols_total(tas_matrix)
      CALL dbt_tas_write_split_info(dbt_tas_info(tas_matrix), io_unit, name=label)
   END SUBROUTINE print_split_info

END PROGRAM dbt_tas_unittest
|