LCOV - code coverage report
Current view: top level - src - torch_api.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4c33f95) Lines: 122 123 99.2 %
Date: 2025-01-30 06:53:08 Functions: 26 40 65.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : MODULE torch_api
       8             :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
       9             :                             C_BOOL, &
      10             :                             C_CHAR, &
      11             :                             C_FLOAT, &
      12             :                             C_DOUBLE, &
      13             :                             C_F_POINTER, &
      14             :                             C_INT, &
      15             :                             C_NULL_CHAR, &
      16             :                             C_NULL_PTR, &
      17             :                             C_PTR, &
      18             :                             C_INT64_T
      19             : 
      20             :    USE kinds, ONLY: sp, int_8, dp, default_string_length
      21             : 
      22             : #include "./base/base_uses.f90"
      23             : 
      24             :    IMPLICIT NONE
      25             : 
      26             :    PRIVATE
      27             : 
       28             :    TYPE torch_tensor_type ! Opaque handle for a tensor managed via the torch_c_tensor_* routines.
       29             :       PRIVATE
       30             :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR ! C_NULL_PTR means no tensor is attached.
       31             :    END TYPE torch_tensor_type
      32             : 
       33             :    TYPE torch_dict_type ! Opaque handle for a dictionary managed via the torch_c_dict_* routines.
       34             :       PRIVATE
       35             :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR ! C_NULL_PTR means no dictionary is attached.
       36             :    END TYPE torch_dict_type
      37             : 
       38             :    TYPE torch_model_type ! Opaque handle for a model managed via the torch_c_model_* routines.
       39             :       PRIVATE
       40             :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR ! C_NULL_PTR means no model is loaded.
       41             :    END TYPE torch_model_type
      42             : 
      43             :    #:set max_dim = 3
       44             :    INTERFACE torch_tensor_from_array ! Generic over element type (float, int64, double) and rank 1..max_dim.
       45             :       #:for ndims  in range(1, max_dim+1)
       46             :          MODULE PROCEDURE torch_tensor_from_array_float_${ndims}$d
       47             :          MODULE PROCEDURE torch_tensor_from_array_int64_${ndims}$d
       48             :          MODULE PROCEDURE torch_tensor_from_array_double_${ndims}$d
       49             :       #:endfor
       50             :    END INTERFACE torch_tensor_from_array
      51             : 
       52             :    INTERFACE torch_tensor_data_ptr ! Generic over element type (float, int64, double) and rank 1..max_dim.
       53             :       #:for ndims  in range(1, max_dim+1)
       54             :          MODULE PROCEDURE torch_tensor_data_ptr_float_${ndims}$d
       55             :          MODULE PROCEDURE torch_tensor_data_ptr_int64_${ndims}$d
       56             :          MODULE PROCEDURE torch_tensor_data_ptr_double_${ndims}$d
       57             :       #:endfor
       58             :    END INTERFACE torch_tensor_data_ptr
      59             : 
       60             :    INTERFACE torch_model_get_attr ! Generic name; one specific procedure per attribute kind listed below.
       61             :       MODULE PROCEDURE torch_model_get_attr_string
       62             :       MODULE PROCEDURE torch_model_get_attr_double
       63             :       MODULE PROCEDURE torch_model_get_attr_int64
       64             :       MODULE PROCEDURE torch_model_get_attr_int32
       65             :       MODULE PROCEDURE torch_model_get_attr_strlist
       66             :    END INTERFACE torch_model_get_attr
      67             : 
      68             :    PUBLIC :: torch_tensor_type, torch_tensor_from_array, torch_tensor_release
      69             :    PUBLIC :: torch_tensor_data_ptr, torch_tensor_backward, torch_tensor_grad
      70             :    PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_insert, torch_dict_get, torch_dict_release
      71             :    PUBLIC :: torch_model_type, torch_model_load, torch_model_forward, torch_model_release
      72             :    PUBLIC :: torch_model_get_attr, torch_model_read_metadata
      73             :    PUBLIC :: torch_cuda_is_available, torch_allow_tf32, torch_model_freeze
      74             : 
      75             : CONTAINS
      76             : 
      77             :    #:set typenames = ['float', 'int64', 'double']
      78             :    #:set types_f = ['REAL(sp)','INTEGER(kind=int_8)', 'REAL(dp)']
      79             :    #:set types_c = ['REAL(kind=C_FLOAT)','INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
      80             : 
      81             :    #:for ndims in range(1, max_dim+1)
      82             :       #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
      83             : 
       84             : ! **************************************************************************************************
       85             : !> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
       86             : !>        The source must be an ALLOCATABLE to prevent passing a temporary array (its allocation is not checked).
       87             : !> \author Ole Schuett
       88             : ! **************************************************************************************************
       89         164 :          SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
       90             :             TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor ! must not hold a handle on entry
       91             :             #:set arraydims = ", ".join(":" for i in range(ndims))
       92             :             ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source
       93             :             LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad ! treated as .FALSE. if absent
       94             : 
       95             : #if defined(__LIBTORCH)
       96             :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c ! extents of source in reversed order
       97             :             LOGICAL                                            :: my_req_grad
       98             : 
       99             :             INTERFACE
      100             :                SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
      101             :                   BIND(C, name="torch_c_tensor_from_array_${typename}$")
      102             :                   IMPORT :: C_PTR, C_INT, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
      103             :                   TYPE(C_PTR)                                  :: tensor ! by reference, expected to be set by C
      104             :                   LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
      105             :                   INTEGER(kind=C_INT), VALUE                   :: ndims
      106             :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
      107             :                   ${type_c}$, DIMENSION(*)                     :: source
      108             :                END SUBROUTINE torch_c_tensor_from_array_${typename}$
      109             :             END INTERFACE
      110             : 
      111         164 :             my_req_grad = .FALSE.
      112         164 :             IF (PRESENT(requires_grad)) my_req_grad = requires_grad
      113             : 
      114             :             #:for axis in range(ndims)
      115         164 :                sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
      116             :             #:endfor
      117             : 
      118         164 :             CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr)) ! refuse to overwrite an existing handle
      119             :             CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
      120             :                                                          req_grad=LOGICAL(my_req_grad, C_BOOL), &
      121             :                                                          ndims=${ndims}$, &
      122             :                                                          sizes=sizes_c, &
      123         164 :                                                          source=source)
      124         164 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr)) ! the C routine must have produced a handle
      125             : #else
      126             :             CPABORT("CP2K compiled without the Torch library.")
      127             :             MARK_USED(tensor)
      128             :             MARK_USED(source)
      129             :             MARK_USED(requires_grad)
      130             : #endif
      131         164 :          END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d
     132             : 
      133             : ! **************************************************************************************************
      134             : !> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
      135             : !>        The returned pointer is only valid during the tensor's lifetime!
      136             : !> \author Ole Schuett
      137             : ! **************************************************************************************************
      138          88 :          SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
      139             :             TYPE(torch_tensor_type), INTENT(IN)                :: tensor ! must hold a handle
      140             :             #:set arraydims = ", ".join(":" for i in range(ndims))
      141             :             ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr
      142             : 
      143             : #if defined(__LIBTORCH)
      144             :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
      145             :             TYPE(C_PTR)                                        :: data_ptr_c
      146             : 
      147             :             INTERFACE
      148             :                SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
      149             :                   BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
      150             :                   IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T ! NOTE(review): C_CHAR looks unused here
      151             :                   TYPE(C_PTR), VALUE                           :: tensor
      152             :                   INTEGER(kind=C_INT), VALUE                   :: ndims
      153             :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
      154             :                   TYPE(C_PTR)                                  :: data_ptr
      155             :                END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
      156             :             END INTERFACE
      157             : 
      158         266 :             sizes_c(:) = -1 ! sentinel so the check after the call can detect extents that were not set
      159          88 :             data_ptr_c = C_NULL_PTR
      160          88 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      161          88 :             CPASSERT(.NOT. ASSOCIATED(data_ptr)) ! caller must pass a disassociated pointer
      162             :             CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
      163             :                                                        ndims=${ndims}$, &
      164             :                                                        sizes=sizes_c, &
      165          88 :                                                        data_ptr=data_ptr_c)
      166             : 
      167         266 :             CPASSERT(ALL(sizes_c >= 0))
      168          88 :             CPASSERT(C_ASSOCIATED(data_ptr_c))
      169             : 
      170             :             #:for axis in range(ndims)
      171          88 :                sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
      172             :             #:endfor
      173         266 :             CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f) ! data_ptr now aliases the C buffer
      174             : #else
      175             :             CPABORT("CP2K compiled without the Torch library.")
      176             :             MARK_USED(tensor)
      177             :             MARK_USED(data_ptr)
      178             : #endif
      179          88 :          END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d
     180             : 
     181             :       #:endfor
     182             :    #:endfor
     183             : 
      184             : ! **************************************************************************************************
      185             : !> \brief Runs autograd on a Torch tensor by passing it and outer_grad to torch_c_tensor_backward.
      186             : !> \author Ole Schuett
      187             : ! **************************************************************************************************
      188           6 :    SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      189             :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor ! must hold a handle
      190             :       TYPE(torch_tensor_type), INTENT(IN)                :: outer_grad ! must hold a handle
      191             : 
      192             : #if defined(__LIBTORCH)
      193             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_tensor_backward'
      194             :       INTEGER                                            :: handle ! links timeset to timestop
      195             : 
      196             :       INTERFACE
      197             :          SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
      198             :             BIND(C, name="torch_c_tensor_backward")
      199             :             IMPORT :: C_CHAR, C_PTR ! NOTE(review): C_CHAR looks unused here
      200             :             TYPE(C_PTR), VALUE                           :: tensor
      201             :             TYPE(C_PTR), VALUE                           :: outer_grad
      202             :          END SUBROUTINE torch_c_tensor_backward
      203             :       END INTERFACE
      204             : 
      205           6 :       CALL timeset(routineN, handle)
      206           6 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      207           6 :       CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
      208           6 :       CALL torch_c_tensor_backward(tensor=tensor%c_ptr, outer_grad=outer_grad%c_ptr)
      209           6 :       CALL timestop(handle)
      210             : #else
      211             :       CPABORT("CP2K compiled without the Torch library.")
      212             :       MARK_USED(tensor)
      213             :       MARK_USED(outer_grad)
      214             : #endif
      215           6 :    END SUBROUTINE torch_tensor_backward
     216             : 
      217             : ! **************************************************************************************************
      218             : !> \brief Returns the gradient of a Torch tensor which was computed by autograd, as a new tensor handle.
      219             : !> \author Ole Schuett
      220             : ! **************************************************************************************************
      221           6 :    SUBROUTINE torch_tensor_grad(tensor, grad)
      222             :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor ! must hold a handle
      223             :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grad ! must not hold a handle on entry
      224             : 
      225             : #if defined(__LIBTORCH)
      226             :       INTERFACE
      227             :          SUBROUTINE torch_c_tensor_grad(tensor, grad) &
      228             :             BIND(C, name="torch_c_tensor_grad")
      229             :             IMPORT :: C_PTR
      230             :             TYPE(C_PTR), VALUE                           :: tensor
      231             :             TYPE(C_PTR)                                  :: grad ! by reference, expected to be set by C
      232             :          END SUBROUTINE torch_c_tensor_grad
      233             :       END INTERFACE
      234             : 
      235           6 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      236           6 :       CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
      237           6 :       CALL torch_c_tensor_grad(tensor=tensor%c_ptr, grad=grad%c_ptr)
      238           6 :       CPASSERT(C_ASSOCIATED(grad%c_ptr)) ! the C routine must have produced a handle
      239             : #else
      240             :       CPABORT("CP2K compiled without the Torch library.")
      241             :       MARK_USED(tensor)
      242             :       MARK_USED(grad)
      243             : #endif
      244           6 :    END SUBROUTINE torch_tensor_grad
     245             : 
      246             : ! **************************************************************************************************
      247             : !> \brief Releases a Torch tensor and all its resources.
      248             : !> \author Ole Schuett
      249             : ! **************************************************************************************************
      250         252 :    SUBROUTINE torch_tensor_release(tensor)
      251             :       TYPE(torch_tensor_type), INTENT(INOUT)               :: tensor ! must hold a handle; null on return
      252             : 
      253             : #if defined(__LIBTORCH)
      254             :       INTERFACE
      255             :          SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
      256             :             IMPORT :: C_PTR
      257             :             TYPE(C_PTR), VALUE                        :: tensor
      258             :          END SUBROUTINE torch_c_tensor_release
      259             :       END INTERFACE
      260             : 
      261         252 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      262         252 :       CALL torch_c_tensor_release(tensor=tensor%c_ptr)
      263         252 :       tensor%c_ptr = C_NULL_PTR ! a second release of the same tensor trips the assertion above
      264             : #else
      265             :       CPABORT("CP2K was compiled without Torch library.")
      266             :       MARK_USED(tensor)
      267             : #endif
      268         252 :    END SUBROUTINE torch_tensor_release
     269             : 
      270             : ! **************************************************************************************************
      271             : !> \brief Creates an empty Torch dictionary and stores its handle in dict.
      272             : !> \author Ole Schuett
      273             : ! **************************************************************************************************
      274         128 :    SUBROUTINE torch_dict_create(dict)
      275             :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict ! must not hold a handle on entry
      276             : 
      277             : #if defined(__LIBTORCH)
      278             :       INTERFACE
      279             :          SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
      280             :             IMPORT :: C_PTR
      281             :             TYPE(C_PTR)                               :: dict ! by reference, expected to be set by C
      282             :          END SUBROUTINE torch_c_dict_create
      283             :       END INTERFACE
      284             : 
      285         128 :       CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      286         128 :       CALL torch_c_dict_create(dict=dict%c_ptr)
      287         128 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
      288             : #else
      289             :       CPABORT("CP2K was compiled without Torch library.")
      290             :       MARK_USED(dict)
      291             : #endif
      292         128 :    END SUBROUTINE torch_dict_create
     293             : 
      294             : ! **************************************************************************************************
      295             : !> \brief Inserts a Torch tensor into a Torch dictionary under the given key.
      296             : !> \author Ole Schuett
      297             : ! **************************************************************************************************
      298         158 :    SUBROUTINE torch_dict_insert(dict, key, tensor)
      299             :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict ! must hold a handle
      300             :       CHARACTER(len=*), INTENT(IN)                       :: key ! trimmed and NUL-terminated before the C call
      301             :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor ! must hold a handle
      302             : 
      303             : #if defined(__LIBTORCH)
      304             : 
      305             :       INTERFACE
      306             :          SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
      307             :             BIND(C, name="torch_c_dict_insert")
      308             :             IMPORT :: C_CHAR, C_PTR
      309             :             TYPE(C_PTR), VALUE                           :: dict
      310             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
      311             :             TYPE(C_PTR), VALUE                           :: tensor
      312             :          END SUBROUTINE torch_c_dict_insert
      313             :       END INTERFACE
      314             : 
      315         158 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
      316         158 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      317         158 :       CALL torch_c_dict_insert(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
      318             : #else
      319             :       CPABORT("CP2K compiled without the Torch library.")
      320             :       MARK_USED(dict)
      321             :       MARK_USED(key)
      322             :       MARK_USED(tensor)
      323             : #endif
      324         158 :    END SUBROUTINE torch_dict_insert
     325             : 
      326             : ! **************************************************************************************************
      327             : !> \brief Retrieves a Torch tensor from a Torch dictionary by key, as a new tensor handle.
      328             : !> \author Ole Schuett
      329             : ! **************************************************************************************************
      330          82 :    SUBROUTINE torch_dict_get(dict, key, tensor)
      331             :       TYPE(torch_dict_type), INTENT(IN)                  :: dict ! must hold a handle
      332             :       CHARACTER(len=*), INTENT(IN)                       :: key ! trimmed and NUL-terminated before the C call
      333             :       TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor ! must not hold a handle on entry
      334             : 
      335             : #if defined(__LIBTORCH)
      336             : 
      337             :       INTERFACE
      338             :          SUBROUTINE torch_c_dict_get(dict, key, tensor) &
      339             :             BIND(C, name="torch_c_dict_get")
      340             :             IMPORT :: C_CHAR, C_PTR
      341             :             TYPE(C_PTR), VALUE                           :: dict
      342             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
      343             :             TYPE(C_PTR)                                  :: tensor ! by reference, expected to be set by C
      344             :          END SUBROUTINE torch_c_dict_get
      345             :       END INTERFACE
      346             : 
      347          82 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
      348          82 :       CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
      349          82 :       CALL torch_c_dict_get(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
      350          82 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      351             : 
      352             : #else
      353             :       CPABORT("CP2K compiled without the Torch library.")
      354             :       MARK_USED(dict)
      355             :       MARK_USED(key)
      356             :       MARK_USED(tensor)
      357             : #endif
      358          82 :    END SUBROUTINE torch_dict_get
     359             : 
      360             : ! **************************************************************************************************
      361             : !> \brief Releases a Torch dictionary and all its resources.
      362             : !> \author Ole Schuett
      363             : ! **************************************************************************************************
      364         128 :    SUBROUTINE torch_dict_release(dict)
      365             :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict ! must hold a handle; null on return
      366             : 
      367             : #if defined(__LIBTORCH)
      368             :       INTERFACE
      369             :          SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
      370             :             IMPORT :: C_PTR
      371             :             TYPE(C_PTR), VALUE                        :: dict
      372             :          END SUBROUTINE torch_c_dict_release
      373             :       END INTERFACE
      374             : 
      375         128 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
      376         128 :       CALL torch_c_dict_release(dict=dict%c_ptr)
      377         128 :       dict%c_ptr = C_NULL_PTR ! a second release of the same dict trips the assertion above
      378             : #else
      379             :       CPABORT("CP2K was compiled without Torch library.")
      380             :       MARK_USED(dict)
      381             : #endif
      382         128 :    END SUBROUTINE torch_dict_release
     383             : 
      384             : ! **************************************************************************************************
      385             : !> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
      386             : !> \author Ole Schuett
      387             : ! **************************************************************************************************
      388          18 :    SUBROUTINE torch_model_load(model, filename)
      389             :       TYPE(torch_model_type), INTENT(INOUT)              :: model ! must not hold a handle on entry
      390             :       CHARACTER(len=*), INTENT(IN)                       :: filename ! trimmed and NUL-terminated for C
      391             : 
      392             : #if defined(__LIBTORCH)
      393             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_load'
      394             :       INTEGER                                            :: handle ! links timeset to timestop
      395             : 
      396             :       INTERFACE
      397             :          SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
      398             :             IMPORT :: C_PTR, C_CHAR
      399             :             TYPE(C_PTR)                               :: model ! by reference, expected to be set by C
      400             :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename
      401             :          END SUBROUTINE torch_c_model_load
      402             :       END INTERFACE
      403             : 
      404          18 :       CALL timeset(routineN, handle)
      405          18 :       CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      406          18 :       CALL torch_c_model_load(model=model%c_ptr, filename=TRIM(filename)//C_NULL_CHAR)
      407          18 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
      408          18 :       CALL timestop(handle)
      409             : #else
      410             :       CPABORT("CP2K was compiled without Torch library.")
      411             :       MARK_USED(model)
      412             :       MARK_USED(filename)
      413             : #endif
      414          18 :    END SUBROUTINE torch_model_load
     415             : 
      416             : ! **************************************************************************************************
      417             : !> \brief Evaluates the given Torch model on the inputs dictionary; outputs must already be created.
      418             : !> \author Ole Schuett
      419             : ! **************************************************************************************************
      420          64 :    SUBROUTINE torch_model_forward(model, inputs, outputs)
      421             :       TYPE(torch_model_type), INTENT(INOUT)              :: model ! must hold a handle
      422             :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs ! must hold a handle
      423             :       TYPE(torch_dict_type), INTENT(INOUT)               :: outputs ! must hold a handle (not created here)
      424             : 
      425             : #if defined(__LIBTORCH)
      426             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_forward'
      427             :       INTEGER                                            :: handle ! links timeset to timestop
      428             : 
      429             :       INTERFACE
      430             :          SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
      431             :             IMPORT :: C_PTR
      432             :             TYPE(C_PTR), VALUE                        :: model
      433             :             TYPE(C_PTR), VALUE                        :: inputs
      434             :             TYPE(C_PTR), VALUE                        :: outputs
      435             :          END SUBROUTINE torch_c_model_forward
      436             :       END INTERFACE
      437             : 
      438          64 :       CALL timeset(routineN, handle)
      439          64 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
      440          64 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      441          64 :       CPASSERT(C_ASSOCIATED(outputs%c_ptr))
      442          64 :       CALL torch_c_model_forward(model=model%c_ptr, inputs=inputs%c_ptr, outputs=outputs%c_ptr)
      443          64 :       CALL timestop(handle)
      444             : #else
      445             :       CPABORT("CP2K was compiled without Torch library.")
      446             :       MARK_USED(model)
      447             :       MARK_USED(inputs)
      448             :       MARK_USED(outputs)
      449             : #endif
      450          64 :    END SUBROUTINE torch_model_forward
     451             : 
      452             : ! **************************************************************************************************
      453             : !> \brief Releases a Torch model and all its resources.
      454             : !> \author Ole Schuett
      455             : ! **************************************************************************************************
      456          18 :    SUBROUTINE torch_model_release(model)
      457             :       TYPE(torch_model_type), INTENT(INOUT)              :: model ! must hold a handle; null on return
      458             : 
      459             : #if defined(__LIBTORCH)
      460             :       INTERFACE
      461             :          SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
      462             :             IMPORT :: C_PTR
      463             :             TYPE(C_PTR), VALUE                        :: model
      464             :          END SUBROUTINE torch_c_model_release
      465             :       END INTERFACE
      466             : 
      467          18 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
      468          18 :       CALL torch_c_model_release(model=model%c_ptr)
      469          18 :       model%c_ptr = C_NULL_PTR ! a second release of the same model trips the assertion above
      470             : #else
      471             :       CPABORT("CP2K was compiled without Torch library.")
      472             :       MARK_USED(model)
      473             : #endif
      474          18 :    END SUBROUTINE torch_model_release
     475             : 
     476             : ! **************************************************************************************************
     477             : !> \brief Reads the metadata entry stored under key in the given "*.pth" file and returns it as a string. (In Torch lingo these entries are called extra files)
     478             : !> \author Ole Schuett
     479             : ! **************************************************************************************************
     480          52 :    FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
     481             :       CHARACTER(len=*), INTENT(IN)                       :: filename, key
     482             :       CHARACTER(:), ALLOCATABLE                           :: res
     483             : 
     484             : #if defined(__LIBTORCH)
     485             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_read_metadata'
     486             :       INTEGER                                            :: handle
     487             :       ! content_c receives the buffer address from the C side; content_f is the Fortran view onto it.
     488             :       CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
     489          52 :          POINTER                                         :: content_f
     490             :       INTEGER                                            :: i
     491             :       INTEGER                                            :: length
     492             :       TYPE(C_PTR)                                        :: content_c
     493             : 
     494             :       INTERFACE
     495             :          SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
     496             :             BIND(C, name="torch_c_model_read_metadata")
     497             :             IMPORT :: C_CHAR, C_PTR, C_INT
     498             :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
     499             :             TYPE(C_PTR)                               :: content ! passed by reference (no VALUE)
     500             :             INTEGER(kind=C_INT)                       :: length ! passed by reference (no VALUE)
     501             :          END SUBROUTINE torch_c_model_read_metadata
     502             :       END INTERFACE
     503             :       ! Sentinel values (NULL, -1) let the asserts after the call detect outputs that were never set.
     504          52 :       CALL timeset(routineN, handle)
     505          52 :       content_c = C_NULL_PTR
     506          52 :       length = -1
     507             :       CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
     508             :                                        key=TRIM(key)//C_NULL_CHAR, &
     509             :                                        content=content_c, &
     510          52 :                                        length=length)
     511          52 :       CPASSERT(C_ASSOCIATED(content_c))
     512          52 :       CPASSERT(length >= 0)
     513             :       ! Map length + 1 characters so that the terminating NUL can be checked as well.
     514         104 :       CALL C_F_POINTER(content_c, content_f, shape=(/length + 1/))
     515          52 :       CPASSERT(content_f(length + 1) == C_NULL_CHAR)
     516             :       ! Copy character by character; a NUL inside the reported length fails the assert in the loop.
     517          52 :       ALLOCATE (CHARACTER(LEN=length) :: res)
     518         278 :       DO i = 1, length
     519         226 :          CPASSERT(content_f(i) /= C_NULL_CHAR)
     520         278 :          res(i:i) = content_f(i)
     521             :       END DO
     522             :       ! NOTE(review): frees C-side memory via Fortran DEALLOCATE - confirm this is valid for all supported compilers.
     523          52 :       DEALLOCATE (content_f) ! Was allocated on the C side.
     524          52 :       CALL timestop(handle)
     525             : #else
     526             :       CPABORT("CP2K was compiled without Torch library.")
     527             :       MARK_USED(filename)
     528             :       MARK_USED(key)
     529             :       MARK_USED(res)
     530             : #endif
     531          52 :    END FUNCTION torch_model_read_metadata
     532             : 
     533             : ! **************************************************************************************************
     534             : !> \brief Returns true iff the Torch CUDA backend is available, as reported by torch_c_cuda_is_available.
     535             : !> \author Ole Schuett
     536             : ! **************************************************************************************************
     537           2 :    FUNCTION torch_cuda_is_available() RESULT(res)
     538             :       LOGICAL                                            :: res
     539             : 
     540             : #if defined(__LIBTORCH)
     541             :       INTERFACE
     542             :          FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
     543             :             IMPORT :: C_BOOL
     544             :             LOGICAL(C_BOOL)                           :: torch_c_cuda_is_available
     545             :          END FUNCTION torch_c_cuda_is_available
     546             :       END INTERFACE
     547             :       ! The assignment converts the LOGICAL(C_BOOL) result to the default LOGICAL kind of res.
     548           2 :       res = torch_c_cuda_is_available()
     549             : #else
     550             :       CPABORT("CP2K was compiled without Torch library.")
     551             :       res = .FALSE. ! gives res a defined value (presumably CPABORT does not return - confirm)
     552             : #endif
     553           2 :    END FUNCTION torch_cuda_is_available
     554             : 
     555             : ! **************************************************************************************************
     556             : !> \brief Set whether to allow the use of TF32.
     557             : !>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     558             : !>        See https://pytorch.org/docs/stable/notes/cuda.html
     559             : !> \author Gabriele Tocci
     560             : ! **************************************************************************************************
     561           8 :    SUBROUTINE torch_allow_tf32(allow_tf32)
     562             :       LOGICAL, INTENT(IN)                                  :: allow_tf32
     563             : 
     564             : #if defined(__LIBTORCH)
     565             :       INTERFACE
     566             :          SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
     567             :             IMPORT :: C_BOOL
     568             :             LOGICAL(C_BOOL), VALUE                  :: allow_tf32 ! flag is passed by value
     569             :          END SUBROUTINE torch_c_allow_tf32
     570             :       END INTERFACE
     571             :       ! Convert the default LOGICAL to kind C_BOOL, as required by the C interface above.
     572           8 :       CALL torch_c_allow_tf32(allow_tf32=LOGICAL(allow_tf32, C_BOOL))
     573             : #else
     574             :       CPABORT("CP2K was compiled without Torch library.")
     575             :       MARK_USED(allow_tf32)
     576             : #endif
     577           8 :    END SUBROUTINE torch_allow_tf32
     578             : 
     579             : ! **************************************************************************************************
     580             : !> \brief Freezes the given Torch model: applies generic optimizations that speed up the model.
     581             : !>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     582             : !> \author Gabriele Tocci
     583             : ! **************************************************************************************************
     584           8 :    SUBROUTINE torch_model_freeze(model)
     585             :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     586             : 
     587             : #if defined(__LIBTORCH)
     588             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_freeze'
     589             :       INTEGER                                            :: handle
     590             : 
     591             :       INTERFACE
     592             :          SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
     593             :             IMPORT :: C_PTR
     594             :             TYPE(C_PTR), VALUE                        :: model ! handle is passed by value
     595             :          END SUBROUTINE torch_c_model_freeze
     596             :       END INTERFACE
     597             :       ! Timed region: the handle must be associated before it is handed to the C routine.
     598           8 :       CALL timeset(routineN, handle)
     599           8 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     600           8 :       CALL torch_c_model_freeze(model=model%c_ptr)
     601           8 :       CALL timestop(handle)
     602             : #else
     603             :       CPABORT("CP2K was compiled without Torch library.")
     604             :       MARK_USED(model)
     605             : #endif
     606           8 :    END SUBROUTINE torch_model_freeze
     607             : 
     608             :    #:set typenames = ['int64', 'double', 'string']
     609             :    #:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
     610             :    #:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
     611             :    ! The three lists above are zipped: one getter is generated per (name, Fortran type, C type) triple.
     612             :    #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
     613             : ! **************************************************************************************************
     614             : !> \brief Retrieves a scalar attribute from a Torch model. Must be called before torch_model_freeze.
     615             : !> \author Ole Schuett
     616             : ! **************************************************************************************************
     617          64 :       SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
     618             :          TYPE(torch_model_type), INTENT(IN)                 :: model
     619             :          CHARACTER(len=*), INTENT(IN)                       :: key
     620             :          ${type_f}$, INTENT(OUT)                            :: dest
     621             : 
     622             : #if defined(__LIBTORCH)
     623             :          ! NOTE(review): model%c_ptr is not asserted to be associated here, unlike in torch_model_freeze.
     624             :          INTERFACE
     625             :             SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
     626             :                BIND(C, name="torch_c_model_get_attr_${typename}$")
     627             :                IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
     628             :                TYPE(C_PTR), VALUE                           :: model
     629             :                CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     630             :                ${type_c}$                                   :: dest
     631             :             END SUBROUTINE torch_c_model_get_attr_${typename}$
     632             :          END INTERFACE
     633             :          ! The key is passed NUL-terminated; dest is handed over by reference to receive the value.
     634             :          CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
     635             :                                                    key=TRIM(key)//C_NULL_CHAR, &
     636          64 :                                                    dest=dest)
     637             : #else
     638             :          CPABORT("CP2K compiled without the Torch library.")
     639             :          MARK_USED(model)
     640             :          MARK_USED(key)
     641             :          MARK_USED(dest)
     642             : #endif
     643          64 :       END SUBROUTINE torch_model_get_attr_${typename}$
     644             :    #:endfor
     645             : 
     646             : ! **************************************************************************************************
     647             : !> \brief Retrieves an int64 attribute from a Torch model and narrows it to a default INTEGER. Must be called before torch_model_freeze.
     648             : !> \author Ole Schuett
     649             : ! **************************************************************************************************
     650          40 :    SUBROUTINE torch_model_get_attr_int32(model, key, dest)
     651             :       TYPE(torch_model_type), INTENT(IN)                 :: model
     652             :       CHARACTER(len=*), INTENT(IN)                       :: key
     653             :       INTEGER, INTENT(OUT)                               :: dest
     654             :       ! Fetch as 64-bit first; the assert below requires the magnitude to be strictly below HUGE(dest).
     655             :       INTEGER(kind=int_8)                                :: temp
     656          40 :       CALL torch_model_get_attr_int64(model, key, temp)
     657          40 :       CPASSERT(ABS(temp) < HUGE(dest))
     658          40 :       dest = INT(temp) ! narrowing conversion, guarded by the range assert above
     659          40 :    END SUBROUTINE torch_model_get_attr_int32
     660             : 
     661             : ! **************************************************************************************************
     662             : !> \brief Retrieves a list attribute from a Torch model into a freshly allocated dest. Must be called before torch_model_freeze.
     663             : !> \author Ole Schuett
     664             : ! **************************************************************************************************
     665           8 :    SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
     666             :       TYPE(torch_model_type), INTENT(IN)                 :: model
     667             :       CHARACTER(len=*), INTENT(IN)                       :: key
     668             :       CHARACTER(LEN=default_string_length), &
     669             :          ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: dest
     670             :       ! INTENT(OUT) deallocates an already allocated dest on entry, so the ALLOCATE below cannot fail on reuse.
     671             : #if defined(__LIBTORCH)
     672             : 
     673             :       INTEGER :: num_items, i
     674             : 
     675             :       INTERFACE
     676             :          SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
     677             :             BIND(C, name="torch_c_model_get_attr_list_size")
     678             :             IMPORT :: C_PTR, C_CHAR, C_INT
     679             :             TYPE(C_PTR), VALUE                           :: model
     680             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     681             :             INTEGER(kind=C_INT)                          :: size ! passed by reference (no VALUE)
     682             :          END SUBROUTINE torch_c_model_get_attr_list_size
     683             :       END INTERFACE
     684             : 
     685             :       INTERFACE
     686             :          SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
     687             :             BIND(C, name="torch_c_model_get_attr_strlist")
     688             :             IMPORT :: C_PTR, C_CHAR, C_INT
     689             :             TYPE(C_PTR), VALUE                           :: model
     690             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     691             :             INTEGER(kind=C_INT), VALUE                   :: index
     692             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
     693             :          END SUBROUTINE torch_c_model_get_attr_strlist
     694             :       END INTERFACE
     695             :       ! First query the number of list items, then allocate dest and blank-fill every entry.
     696             :       CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
     697             :                                             key=TRIM(key)//C_NULL_CHAR, &
     698           8 :                                             size=num_items)
     699          24 :       ALLOCATE (dest(num_items))
     700          24 :       dest(:) = ""
     701             :       ! Items are fetched one by one and requested with zero-based indices (i - 1).
     702          24 :       DO i = 1, num_items
     703             :          CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
     704             :                                              key=TRIM(key)//C_NULL_CHAR, &
     705             :                                              index=i - 1, &
     706          24 :                                              dest=dest(i))
     707             :          ! NOTE(review): no buffer length is passed; presumably the C side respects default_string_length - confirm.
     708             :       END DO
     709             : #else
     710             :       CPABORT("CP2K compiled without the Torch library.")
     711             :       MARK_USED(model)
     712             :       MARK_USED(key)
     713             :       MARK_USED(dest)
     714             : #endif
     715             : 
     716           8 :    END SUBROUTINE torch_model_get_attr_strlist
     717             : 
     718           0 : END MODULE torch_api

Generated by: LCOV version 1.15