LCOV - code coverage report
Current view: top level - src/dbt - dbt_reshape_ops.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:262480d) Lines: 56 57 98.2 %
Date: 2024-11-22 07:00:40 Functions: 2 4 50.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Routines to reshape / redistribute tensors
      10             : !> \author Patrick Seewald
      11             : ! **************************************************************************************************
      12             : MODULE dbt_reshape_ops
      13             :    #:include "dbt_macros.fypp"
      14             :    #:set maxdim = maxrank
      15             :    #:set ndims = range(2,maxdim+1)
      16             : 
      17             :    USE dbt_allocate_wrap, ONLY: allocate_any
      18             :    USE dbt_tas_base, ONLY: dbt_tas_copy, dbt_tas_get_info, dbt_tas_info
      19             :    USE dbt_block, ONLY: &
      20             :       block_nd, create_block, destroy_block, dbt_iterator_type, dbt_iterator_next_block, &
      21             :       dbt_iterator_blocks_left, dbt_iterator_start, dbt_iterator_stop, dbt_get_block, &
      22             :       dbt_reserve_blocks, dbt_put_block
      23             :    USE dbt_types, ONLY: dbt_blk_sizes, &
      24             :                         dbt_create, &
      25             :                         dbt_type, &
      26             :                         ndims_tensor, &
      27             :                         dbt_get_stored_coordinates, &
      28             :                         dbt_clear
      29             :    USE kinds, ONLY: default_string_length
      30             :    USE kinds, ONLY: dp, dp
      31             :    USE message_passing, ONLY: &
      32             :       mp_waitall, mp_comm_type, mp_request_type
      33             : 
      34             : #include "../base/base_uses.f90"
      35             : 
      36             :    IMPLICIT NONE
      37             :    PRIVATE
      38             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_reshape_ops'
      39             : 
      40             :    PUBLIC :: dbt_reshape
      41             : 
      42             :    TYPE block_buffer_type
      43             :       INTEGER, DIMENSION(:, :), ALLOCATABLE      :: blocks
      44             :       REAL(dp), DIMENSION(:), ALLOCATABLE        :: data
      45             :    END TYPE
      46             : 
      47             : CONTAINS
      48             : 
      49             : ! **************************************************************************************************
      50             : !> \brief copy data (involves reshape)
      51             : !>        tensor_out = tensor_out + tensor_in move_data memory optimization:
      52             : !>        transfer data from tensor_in to tensor_out s.t. tensor_in is empty on return
      53             : !> \author Ole Schuett
      54             : ! **************************************************************************************************
      55      191581 :    SUBROUTINE dbt_reshape(tensor_in, tensor_out, summation, move_data)
      56             : 
      57             :       TYPE(dbt_type), INTENT(INOUT)               :: tensor_in, tensor_out
      58             :       LOGICAL, INTENT(IN), OPTIONAL                    :: summation
      59             :       LOGICAL, INTENT(IN), OPTIONAL                    :: move_data
      60             : 
      61             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_reshape'
      62             : 
      63             :       INTEGER                                            :: iproc, numnodes, &
      64             :                                                             handle, iblk, jblk, offset, ndata, &
      65             :                                                             nblks_recv_mythread
      66      191581 :       INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: blks_to_allocate
      67             :       TYPE(dbt_iterator_type)                            :: iter
      68      191581 :       TYPE(block_nd)                                     :: blk_data
      69      191581 :       TYPE(block_buffer_type), ALLOCATABLE, DIMENSION(:) :: buffer_recv, buffer_send
      70      191581 :       INTEGER, DIMENSION(ndims_tensor(tensor_in))        :: blk_size, ind_nd
      71             :       LOGICAL :: found, summation_prv, move_prv
      72             : 
      73      191581 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: nblks_send_total, ndata_send_total, &
      74      191581 :                                                             nblks_recv_total, ndata_recv_total, &
      75      191581 :                                                             nblks_send_mythread, ndata_send_mythread
      76             :       TYPE(mp_comm_type) :: mp_comm
      77             : 
      78      191581 :       CALL timeset(routineN, handle)
      79             : 
      80      191581 :       IF (PRESENT(summation)) THEN
      81       65453 :          summation_prv = summation
      82             :       ELSE
      83             :          summation_prv = .FALSE.
      84             :       END IF
      85             : 
      86      191581 :       IF (PRESENT(move_data)) THEN
      87      191581 :          move_prv = move_data
      88             :       ELSE
      89             :          move_prv = .FALSE.
      90             :       END IF
      91             : 
      92      191581 :       CPASSERT(tensor_out%valid)
      93             : 
      94      191581 :       IF (.NOT. summation_prv) CALL dbt_clear(tensor_out)
      95             : 
      96      191581 :       mp_comm = tensor_in%pgrid%mp_comm_2d
      97      191581 :       numnodes = mp_comm%num_pe
      98     1473318 :       ALLOCATE (buffer_send(0:numnodes - 1), buffer_recv(0:numnodes - 1))
      99     1473318 :       ALLOCATE (nblks_send_total(0:numnodes - 1), ndata_send_total(0:numnodes - 1), source=0)
     100     1281737 :       ALLOCATE (nblks_recv_total(0:numnodes - 1), ndata_recv_total(0:numnodes - 1), source=0)
     101             : 
     102             : !$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
     103             : !$OMP SHARED(tensor_in,tensor_out,summation) &
     104             : !$OMP SHARED(buffer_send,buffer_recv,mp_comm,numnodes) &
     105             : !$OMP SHARED(nblks_send_total,ndata_send_total,nblks_recv_total,ndata_recv_total) &
     106             : !$OMP PRIVATE(nblks_send_mythread,ndata_send_mythread,nblks_recv_mythread) &
     107             : !$OMP PRIVATE(iter,ind_nd,blk_size,blk_data,found,iproc) &
     108      191581 : !$OMP PRIVATE(blks_to_allocate,offset,ndata,iblk,jblk)
     109             :       ALLOCATE (nblks_send_mythread(0:numnodes - 1), ndata_send_mythread(0:numnodes - 1), source=0)
     110             : 
     111             :       CALL dbt_iterator_start(iter, tensor_in)
     112             :       DO WHILE (dbt_iterator_blocks_left(iter))
     113             :          CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
     114             :          CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
     115             :          nblks_send_mythread(iproc) = nblks_send_mythread(iproc) + 1
     116             :          ndata_send_mythread(iproc) = ndata_send_mythread(iproc) + PRODUCT(blk_size)
     117             :       END DO
     118             :       CALL dbt_iterator_stop(iter)
     119             : !$OMP CRITICAL
     120             :       nblks_send_total(:) = nblks_send_total(:) + nblks_send_mythread(:)
     121             :       ndata_send_total(:) = ndata_send_total(:) + ndata_send_mythread(:)
     122             :       nblks_send_mythread(:) = nblks_send_total(:) ! current totals indicate slot for this thread
     123             :       ndata_send_mythread(:) = ndata_send_total(:)
     124             : !$OMP END CRITICAL
     125             : !$OMP BARRIER
     126             : 
     127             : !$OMP MASTER
     128             :       CALL mp_comm%alltoall(nblks_send_total, nblks_recv_total, 1)
     129             :       CALL mp_comm%alltoall(ndata_send_total, ndata_recv_total, 1)
     130             : !$OMP END MASTER
     131             : !$OMP BARRIER
     132             : 
     133             : !$OMP DO
     134             :       DO iproc = 0, numnodes - 1
     135             :          ALLOCATE (buffer_send(iproc)%data(ndata_send_total(iproc)))
     136             :          ALLOCATE (buffer_recv(iproc)%data(ndata_recv_total(iproc)))
     137             :          ! going to use buffer%blocks(:,0) to store data offsets
     138             :          ALLOCATE (buffer_send(iproc)%blocks(nblks_send_total(iproc), 0:ndims_tensor(tensor_in)))
     139             :          ALLOCATE (buffer_recv(iproc)%blocks(nblks_recv_total(iproc), 0:ndims_tensor(tensor_in)))
     140             :       END DO
     141             : !$OMP END DO
     142             : !$OMP BARRIER
     143             : 
     144             :       CALL dbt_iterator_start(iter, tensor_in)
     145             :       DO WHILE (dbt_iterator_blocks_left(iter))
     146             :          CALL dbt_iterator_next_block(iter, ind_nd, blk_size=blk_size)
     147             :          CALL dbt_get_stored_coordinates(tensor_out, ind_nd, iproc)
     148             :          CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
     149             :          CPASSERT(found)
     150             :          ! insert block data
     151             :          ndata = PRODUCT(blk_size)
     152             :          ndata_send_mythread(iproc) = ndata_send_mythread(iproc) - ndata
     153             :          offset = ndata_send_mythread(iproc)
     154             :          buffer_send(iproc)%data(offset + 1:offset + ndata) = blk_data%blk(:)
     155             :          ! insert block index
     156             :          nblks_send_mythread(iproc) = nblks_send_mythread(iproc) - 1
     157             :          iblk = nblks_send_mythread(iproc) + 1
     158             :          buffer_send(iproc)%blocks(iblk, 1:) = ind_nd(:)
     159             :          buffer_send(iproc)%blocks(iblk, 0) = offset
     160             :          CALL destroy_block(blk_data)
     161             :       END DO
     162             :       CALL dbt_iterator_stop(iter)
     163             : !$OMP BARRIER
     164             : 
     165             :       CALL dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
     166             : !$OMP BARRIER
     167             : 
     168             : !$OMP DO
     169             :       DO iproc = 0, numnodes - 1
     170             :          DEALLOCATE (buffer_send(iproc)%blocks, buffer_send(iproc)%data)
     171             :       END DO
     172             : !$OMP END DO
     173             : 
     174             :       nblks_recv_mythread = 0
     175             :       DO iproc = 0, numnodes - 1
     176             : !$OMP DO
     177             :          DO iblk = 1, nblks_recv_total(iproc)
     178             :             nblks_recv_mythread = nblks_recv_mythread + 1
     179             :          END DO
     180             : !$OMP END DO
     181             :       END DO
     182             :       ALLOCATE (blks_to_allocate(nblks_recv_mythread, ndims_tensor(tensor_in)))
     183             : 
     184             :       jblk = 0
     185             :       DO iproc = 0, numnodes - 1
     186             : !$OMP DO
     187             :          DO iblk = 1, nblks_recv_total(iproc)
     188             :             jblk = jblk + 1
     189             :             blks_to_allocate(jblk, :) = buffer_recv(iproc)%blocks(iblk, 1:)
     190             :          END DO
     191             : !$OMP END DO
     192             :       END DO
     193             :       CPASSERT(jblk == nblks_recv_mythread)
     194             :       CALL dbt_reserve_blocks(tensor_out, blks_to_allocate)
     195             :       DEALLOCATE (blks_to_allocate)
     196             : 
     197             :       DO iproc = 0, numnodes - 1
     198             : !$OMP DO
     199             :          DO iblk = 1, nblks_recv_total(iproc)
     200             :             ind_nd(:) = buffer_recv(iproc)%blocks(iblk, 1:)
     201             :             CALL dbt_blk_sizes(tensor_out, ind_nd, blk_size)
     202             :             offset = buffer_recv(iproc)%blocks(iblk, 0)
     203             :             ndata = PRODUCT(blk_size)
     204             :             CALL create_block(blk_data, blk_size, &
     205             :                               array=buffer_recv(iproc)%data(offset + 1:offset + ndata))
     206             :             CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
     207             :             CALL destroy_block(blk_data)
     208             :          END DO
     209             : !$OMP END DO
     210             :       END DO
     211             : 
     212             : !$OMP DO
     213             :       DO iproc = 0, numnodes - 1
     214             :          DEALLOCATE (buffer_recv(iproc)%blocks, buffer_recv(iproc)%data)
     215             :       END DO
     216             : !$OMP END DO
     217             : !$OMP END PARALLEL
     218             : 
     219      191581 :       DEALLOCATE (nblks_recv_total, ndata_recv_total)
     220      191581 :       DEALLOCATE (nblks_send_total, ndata_send_total)
     221      898575 :       DEALLOCATE (buffer_send, buffer_recv)
     222             : 
     223      191581 :       IF (move_prv) CALL dbt_clear(tensor_in)
     224             : 
     225      191581 :       CALL timestop(handle)
     226      383162 :    END SUBROUTINE dbt_reshape
     227             : 
     228             : ! **************************************************************************************************
     229             : !> \brief communicate buffer
     230             : !> \author Patrick Seewald
     231             : ! **************************************************************************************************
     232      191581 :    SUBROUTINE dbt_communicate_buffer(mp_comm, buffer_recv, buffer_send)
     233             :       TYPE(mp_comm_type), INTENT(IN)                    :: mp_comm
     234             :       TYPE(block_buffer_type), DIMENSION(0:), INTENT(INOUT) :: buffer_recv, buffer_send
     235             : 
     236             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_communicate_buffer'
     237             : 
     238             :       INTEGER                                               :: iproc, numnodes, &
     239             :                                                                rec_counter, send_counter, i
     240      191581 :       TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:, :)                 :: req_array
     241             :       INTEGER                                               :: handle
     242             : 
     243      191581 :       CALL timeset(routineN, handle)
     244      191581 :       numnodes = mp_comm%num_pe
     245             : 
     246      191581 :       IF (numnodes > 1) THEN
     247      161916 : !$OMP MASTER
     248      161916 :          send_counter = 0
     249      161916 :          rec_counter = 0
     250             : 
     251     2428740 :          ALLOCATE (req_array(1:numnodes, 4))
     252             : 
     253      485748 :          DO iproc = 0, numnodes - 1
     254     1133412 :             IF (SIZE(buffer_recv(iproc)%blocks) > 0) THEN
     255      204151 :                rec_counter = rec_counter + 1
     256      204151 :                CALL mp_comm%irecv(buffer_recv(iproc)%blocks, iproc, req_array(rec_counter, 3), tag=4)
     257      204151 :                CALL mp_comm%irecv(buffer_recv(iproc)%data, iproc, req_array(rec_counter, 4), tag=7)
     258             :             END IF
     259             :          END DO
     260             : 
     261      485748 :          DO iproc = 0, numnodes - 1
     262     1133412 :             IF (SIZE(buffer_send(iproc)%blocks) > 0) THEN
     263      204151 :                send_counter = send_counter + 1
     264      204151 :                CALL mp_comm%isend(buffer_send(iproc)%blocks, iproc, req_array(send_counter, 1), tag=4)
     265      204151 :                CALL mp_comm%isend(buffer_send(iproc)%data, iproc, req_array(send_counter, 2), tag=7)
     266             :             END IF
     267             :          END DO
     268             : 
     269      161916 :          IF (send_counter > 0) THEN
     270      143058 :             CALL mp_waitall(req_array(1:send_counter, 1:2))
     271             :          END IF
     272      161916 :          IF (rec_counter > 0) THEN
     273      135722 :             CALL mp_waitall(req_array(1:rec_counter, 3:4))
     274             :          END IF
     275             : !$OMP END MASTER
     276             : 
     277             :       ELSE
     278       29665 : !$OMP DO SCHEDULE(static, 512)
     279             :          DO i = 1, SIZE(buffer_send(0)%blocks, 1)
     280     3878925 :             buffer_recv(0)%blocks(i, :) = buffer_send(0)%blocks(i, :)
     281             :          END DO
     282             : !$OMP END DO
     283       29665 : !$OMP DO SCHEDULE(static, 512)
     284             :          DO i = 1, SIZE(buffer_send(0)%data)
     285   412866138 :             buffer_recv(0)%data(i) = buffer_send(0)%data(i)
     286             :          END DO
     287             : !$OMP END DO
     288             :       END IF
     289      191581 :       CALL timestop(handle)
     290             : 
     291      191581 :    END SUBROUTINE dbt_communicate_buffer
     292             : 
     293           0 : END MODULE dbt_reshape_ops

Generated by: LCOV version 1.15