LCOV - code coverage report
Current view: top level - src/common - parallel_rng_types_unittest.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:d1f8d1b) Lines: 103 111 92.8 %
Date: 2024-11-29 06:42:44 Functions: 4 4 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8           2 : PROGRAM parallel_rng_types_TEST
       9           2 :    USE message_passing, ONLY: mp_world_finalize, &
      10             :                               mp_world_init, &
      11             :                               mp_comm_type
      12             :    USE kinds, ONLY: dp
      13             :    USE machine, ONLY: m_walltime, &
      14             :                       default_output_unit
      15             :    USE parallel_rng_types, ONLY: GAUSSIAN, &
      16             :                                  UNIFORM, &
      17             :                                  check_rng, &
      18             :                                  rng_stream_type, &
      19             :                                  rng_stream_type_from_record, &
      20             :                                  rng_name_length, &
      21             :                                  rng_record_length
      22             : 
      23             :    IMPLICIT NONE
      24             : 
      25             :    INTEGER                          :: i, nsamples, nargs, stat
      26             :    LOGICAL                          :: ionode
      27             :    REAL(KIND=dp)                    :: t, tend, tmax, tmin, tstart, tsum, tsum2
      28             :    TYPE(mp_comm_type) :: mpi_comm
      29             :    TYPE(rng_stream_type)            :: rng_stream
      30             :    CHARACTER(len=32)                :: arg
      31             : 
      32           2 :    nsamples = 1000
      33           2 :    nargs = command_argument_count()
      34             : 
      35           2 :    IF (nargs .GT. 1) &
      36           0 :       ERROR STOP "Usage: parallel_rng_types_TEST [<int:nsamples>]"
      37             : 
      38           2 :    IF (nargs == 1) THEN
      39           2 :       CALL get_command_argument(1, arg)
      40           2 :       READ (arg, *, iostat=stat) nsamples
      41           2 :       IF (stat /= 0) &
      42           0 :          ERROR STOP "Usage: parallel_rng_types_TEST [<int:nsamples>]"
      43             :    END IF
      44             : 
      45           2 :    CALL mp_world_init(mpi_comm)
      46           2 :    ionode = mpi_comm%is_source()
      47             : 
      48           2 :    CALL check_rng(default_output_unit, ionode)
      49             : 
      50             :    ! Check performance
      51             : 
      52           2 :    IF (ionode) THEN
      53             :       WRITE (UNIT=default_output_unit, FMT="(/,/,T2,A,I10,A)") &
      54           1 :          "Check distributions using", nsamples, " random numbers:"
      55             :    END IF
      56             : 
      57             :    ! Test uniform distribution [0,1]
      58             : 
      59             :    rng_stream = rng_stream_type(name="Test uniform distribution [0,1]", &
      60             :                                 distribution_type=UNIFORM, &
      61           2 :                                 extended_precision=.TRUE.)
      62             : 
      63           2 :    IF (ionode) &
      64           1 :       CALL rng_stream%write(default_output_unit)
      65             : 
      66           2 :    tmax = -HUGE(0.0_dp)
      67           2 :    tmin = +HUGE(0.0_dp)
      68           2 :    tsum = 0.0_dp
      69           2 :    tsum2 = 0.0_dp
      70             : 
      71           2 :    tstart = m_walltime()
      72        2002 :    DO i = 1, nsamples
      73        2000 :       t = rng_stream%next()
      74        2000 :       tsum = tsum + t
      75        2000 :       tsum2 = tsum2 + t*t
      76        2000 :       IF (t > tmax) tmax = t
      77        2002 :       IF (t < tmin) tmin = t
      78             :    END DO
      79           2 :    tend = m_walltime()
      80             : 
      81           2 :    IF (ionode) THEN
      82           1 :       CALL rng_stream%write(default_output_unit, write_all=.TRUE.)
      83             :       WRITE (UNIT=default_output_unit, FMT="(/,(T4,A,F12.6))") &
      84           1 :          "Minimum: ", tmin, &
      85           1 :          "Maximum: ", tmax, &
      86           1 :          "Average: ", tsum/REAL(nsamples, KIND=dp), &
      87           1 :          "Variance:", tsum2/REAL(nsamples, KIND=dp), &
      88           2 :          "Time [s]:", tend - tstart
      89             :    END IF
      90             : 
      91             :    ! Test normal Gaussian distribution
      92             : 
      93             :    rng_stream = rng_stream_type(name="Test normal Gaussian distribution", &
      94             :                                 distribution_type=GAUSSIAN, &
      95           2 :                                 extended_precision=.TRUE.)
      96             : 
      97           2 :    IF (ionode) &
      98           1 :       CALL rng_stream%write(default_output_unit)
      99             : 
     100           2 :    tmax = -HUGE(0.0_dp)
     101           2 :    tmin = +HUGE(0.0_dp)
     102           2 :    tsum = 0.0_dp
     103           2 :    tsum2 = 0.0_dp
     104             : 
     105           2 :    tstart = m_walltime()
     106        2002 :    DO i = 1, nsamples
     107        2000 :       t = rng_stream%next()
     108        2000 :       tsum = tsum + t
     109        2000 :       tsum2 = tsum2 + t*t
     110        2000 :       IF (t > tmax) tmax = t
     111        2002 :       IF (t < tmin) tmin = t
     112             :    END DO
     113           2 :    tend = m_walltime()
     114             : 
     115           2 :    IF (ionode) THEN
     116           1 :       CALL rng_stream%write(default_output_unit)
     117             :       WRITE (UNIT=default_output_unit, FMT="(/,(T4,A,F12.6))") &
     118           1 :          "Minimum: ", tmin, &
     119           1 :          "Maximum: ", tmax, &
     120           1 :          "Average: ", tsum/REAL(nsamples, KIND=dp), &
     121           1 :          "Variance:", tsum2/REAL(nsamples, KIND=dp), &
     122           2 :          "Time [s]:", tend - tstart
     123             :    END IF
     124             : 
     125           2 :    IF (ionode) THEN
     126           1 :       CALL dump_reload_check()
     127           1 :       CALL shuffle_check()
     128             :    END IF
     129             : 
     130           2 :    CALL mp_world_finalize()
     131             : 
     132             : CONTAINS
     133             : ! **************************************************************************************************
     134             : !> \brief ...
     135             : ! **************************************************************************************************
     136           1 :    SUBROUTINE dump_reload_check()
     137             :       TYPE(rng_stream_type)            :: rng_stream
     138             :       CHARACTER(len=rng_record_length) :: rng_record
     139             :       REAL(KIND=dp), DIMENSION(3, 2)   :: ig, ig_orig, cg, cg_orig, bg, bg_orig
     140             :       CHARACTER(len=rng_name_length)   :: name, name_orig
     141             :       CHARACTER(len=*), PARAMETER      :: serialized_string = &
     142             :          "qtb_rng_gaussian                         1 F T F   0.0000000000000000E+00&
     143             :          &                12.0                12.0                12.0&
     144             :          &                12.0                12.0                12.0&
     145             :          &                12.0                12.0                12.0&
     146             :          &                12.0                12.0                12.0&
     147             :          &                12.0                12.0                12.0&
     148             :          &                12.0                12.0                12.0"
     149             : 
     150             :       WRITE (UNIT=default_output_unit, FMT="(/,/,T2,A)") &
     151           1 :          "Checking dump and load round trip:"
     152             : 
     153             :       rng_stream = rng_stream_type(name="Roundtrip for normal Gaussian distrib", &
     154             :                                    distribution_type=GAUSSIAN, &
     155           1 :                                    extended_precision=.TRUE.)
     156             : 
     157           1 :       CALL rng_stream%advance(7, 42)
     158           1 :       CALL rng_stream%get(ig=ig_orig, cg=cg_orig, bg=bg_orig, name=name_orig)
     159           1 :       CALL rng_stream%dump(rng_record)
     160             : 
     161           1 :       rng_stream = rng_stream_type_from_record(rng_record)
     162           1 :       CALL rng_stream%get(ig=ig, cg=cg, bg=bg, name=name)
     163             : 
     164             :       IF (ANY(ig /= ig_orig) .OR. ANY(cg /= cg_orig) .OR. ANY(bg /= bg_orig) &
     165          27 :           .OR. (name /= name_orig)) &
     166           0 :          ERROR STOP "Stream dump and load roundtrip failed"
     167             : 
     168             :       WRITE (UNIT=default_output_unit, FMT="(T4,A)") &
     169           1 :          "Roundtrip successful"
     170             : 
     171             :       WRITE (UNIT=default_output_unit, FMT="(/,/,T2,A)") &
     172           1 :          "Checking dumped format:"
     173             : 
     174           9 :       ig(:, :) = 12.0_dp
     175             :       rng_stream = rng_stream_type(name="qtb_rng_gaussian", &
     176             :                                    distribution_type=GAUSSIAN, &
     177             :                                    extended_precision=.TRUE., &
     178           1 :                                    seed=ig)
     179             : 
     180           1 :       CALL rng_stream%dump(rng_record)
     181             : 
     182             :       WRITE (UNIT=default_output_unit, FMT="(T4,A10,A433)") &
     183           1 :          "EXPECTED:", serialized_string
     184             : 
     185             :       WRITE (UNIT=default_output_unit, FMT="(T4,A10,A433)") &
     186           1 :          "GENERATED:", rng_record
     187             : 
     188           1 :       IF (rng_record /= serialized_string) &
     189           0 :          ERROR STOP "Serialized record does not match the expected output"
     190             : 
     191             :       WRITE (UNIT=default_output_unit, FMT="(T4,A)") &
     192           1 :          "Serialized record matches the expected output"
     193             : 
     194          25 :    END SUBROUTINE
     195             : 
     196             : ! **************************************************************************************************
     197             : !> \brief ...
     198             : ! **************************************************************************************************
     199           1 :    SUBROUTINE shuffle_check()
     200             :       TYPE(rng_stream_type)              :: rng_stream
     201             : 
     202             :       INTEGER, PARAMETER                 :: sz = 20
     203             :       INTEGER, DIMENSION(1:sz)           :: arr, arr2, orig
     204             :       LOGICAL, DIMENSION(1:sz)           :: mask
     205             :       INTEGER :: idx
     206             :       REAL(KIND=dp), DIMENSION(3, 2), PARAMETER :: ig = 12.0_dp
     207             : 
     208             :       WRITE (UNIT=default_output_unit, FMT="(/,/,T2,A)", ADVANCE="no") &
     209           1 :          "Checking shuffle()"
     210             : 
     211           1 :       rng_stream = rng_stream_type(name="shuffle() check", seed=ig)
     212           1 :       orig = [(idx, idx=1, sz)]
     213             : 
     214           1 :       arr = orig
     215           1 :       CALL rng_stream%shuffle(arr)
     216             : 
     217           1 :       IF (ALL(arr == orig)) &
     218           0 :          ERROR STOP "shuffle failed: array was left untouched"
     219           1 :       WRITE (UNIT=default_output_unit, FMT="(A)", ADVANCE="no") "."
     220             : 
     221          21 :       IF (ANY(arr /= orig(arr))) &
     222           0 :          ERROR STOP "shuffle failed: the shuffled original is not the shuffled original"
     223           1 :       WRITE (UNIT=default_output_unit, FMT="(A)", ADVANCE="no") "."
     224             : 
     225             :       ! sort and compare to orig
     226          21 :       mask = .TRUE.
     227          21 :       DO idx = 1, size(orig)
     228         420 :          IF (MINVAL(arr, mask) /= orig(idx)) &
     229           0 :             ERROR STOP "shuffle failed: there is at least one unknown index"
     230         481 :          mask(MINLOC(arr, mask)) = .FALSE.
     231             :       END DO
     232           1 :       WRITE (UNIT=default_output_unit, FMT="(A)", ADVANCE="no") "."
     233             : 
     234           1 :       arr2 = orig
     235           1 :       CALL rng_stream%reset()
     236           1 :       CALL rng_stream%shuffle(arr2)
     237             : 
     238          21 :       IF (ANY(arr2 /= arr)) &
     239           0 :          ERROR STOP "shuffle failed: array was shuffled differently with same rng state"
     240           1 :       WRITE (UNIT=default_output_unit, FMT="(A)", ADVANCE="no") "."
     241             : 
     242             :       WRITE (UNIT=default_output_unit, FMT="(T4,A)") &
     243           1 :          " successful"
     244          27 :    END SUBROUTINE
     245             : END PROGRAM parallel_rng_types_TEST
     246             : ! vim: set ts=3 sw=3 tw=132 :

Generated by: LCOV version 1.15