LCOV - code coverage report
Current view: top level - src - torch_api.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 82 83 98.8 %
Date: 2024-11-21 06:45:46 Functions: 21 34 61.8 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : MODULE torch_api
       8             :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
       9             :                             C_BOOL, &
      10             :                             C_CHAR, &
      11             :                             C_FLOAT, &
      12             :                             C_DOUBLE, &
      13             :                             C_F_POINTER, &
      14             :                             C_INT, &
      15             :                             C_NULL_CHAR, &
      16             :                             C_NULL_PTR, &
      17             :                             C_PTR, &
      18             :                             C_INT64_T
      19             : 
      20             :    USE kinds, ONLY: sp, int_8, dp, default_string_length
      21             : 
      22             : #include "./base/base_uses.f90"
      23             : 
      24             :    IMPLICIT NONE
      25             : 
      26             :    PRIVATE
      27             : 
      ! Handle for a dictionary owned by the C/Torch side. The pointer is filled in by
   ! torch_dict_create and reset to C_NULL_PTR by torch_dict_release.
   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

   ! Handle for a model owned by the C/Torch side. The pointer is filled in by
   ! torch_model_load and reset to C_NULL_PTR by torch_model_release.
   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type

   ! Highest array rank for which torch_dict_insert / torch_dict_get specifics are generated.
   #:set max_dim = 3

   ! Generic insert over array rank (1..max_dim) and element type (float, int64, double).
   INTERFACE torch_dict_insert
      #:for ndims in range(1, max_dim+1)
         #:for elem in ['float', 'int64', 'double']
            MODULE PROCEDURE torch_dict_insert_${elem}$_${ndims}$d
         #:endfor
      #:endfor
   END INTERFACE torch_dict_insert

   ! Generic get over array rank (1..max_dim) and element type (float, int64, double).
   INTERFACE torch_dict_get
      #:for ndims in range(1, max_dim+1)
         #:for elem in ['float', 'int64', 'double']
            MODULE PROCEDURE torch_dict_get_${elem}$_${ndims}$d
         #:endfor
      #:endfor
   END INTERFACE torch_dict_get

   ! Generic attribute getter; the specific procedure is chosen by the type and rank of dest.
   INTERFACE torch_model_get_attr
      MODULE PROCEDURE torch_model_get_attr_string, torch_model_get_attr_double, &
         torch_model_get_attr_int64, torch_model_get_attr_int32, torch_model_get_attr_strlist
   END INTERFACE torch_model_get_attr

   PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_release, &
             torch_dict_insert, torch_dict_get
   PUBLIC :: torch_model_type, torch_model_load, torch_model_eval, torch_model_release, &
             torch_model_get_attr, torch_model_read_metadata, torch_model_freeze
   PUBLIC :: torch_cuda_is_available, torch_allow_tf32
      69             : CONTAINS
      70             : 
      71             :    #:set typenames = ['float', 'int64', 'double']
      72             :    #:set types_f = ['REAL(sp)','INTEGER(kind=int_8)', 'REAL(dp)']
      73             :    #:set types_c = ['REAL(kind=C_FLOAT)','INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
      74             : 
      75             :    #:for ndims in range(1, max_dim+1)
      76             :       #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
      77             : 
      ! **************************************************************************************************
!> \brief Inserts array into Torch dictionary. The passed array has to outlive the dictionary!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param dict   dictionary previously created with torch_dict_create
!> \param key    name under which the array is stored
!> \param source allocated array; it must stay allocated for as long as the dictionary is used
!> \author Ole Schuett
! **************************************************************************************************
         SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d(dict, key, source)
            TYPE(torch_dict_type), INTENT(INOUT)               :: dict
            CHARACTER(len=*), INTENT(IN)                       :: key
            #:set arraydims = ", ".join(":" for i in range(ndims))
            ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source

#if defined(__LIBTORCH)
            INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c

            INTERFACE
               SUBROUTINE torch_c_dict_insert_${typename}$ (dict, key, ndims, sizes, source) &
                  BIND(C, name="torch_c_dict_insert_${typename}$")
                  IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T, C_FLOAT, C_DOUBLE
                  TYPE(C_PTR), VALUE                           :: dict
                  CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
                  INTEGER(kind=C_INT), VALUE                   :: ndims
                  INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
                  ${type_c}$, DIMENSION(*)                     :: source
               END SUBROUTINE torch_c_dict_insert_${typename}$
            END INTERFACE

            CPASSERT(C_ASSOCIATED(dict%c_ptr))
            ! Querying the shape of an unallocated array is undefined, so check allocation first.
            CPASSERT(ALLOCATED(source))

            ! C arrays are stored row-major, hence the Fortran axis order is reversed.
            sizes_c(${ndims}$:1:-1) = SHAPE(source, KIND=int_8)

            CALL torch_c_dict_insert_${typename}$ (dict=dict%c_ptr, &
                                                   key=TRIM(key)//C_NULL_CHAR, &
                                                   ndims=${ndims}$, &
                                                   sizes=sizes_c, &
                                                   source=source)
#else
            CPABORT("CP2K compiled without the Torch library.")
            MARK_USED(dict)
            MARK_USED(key)
            MARK_USED(source)
#endif
         END SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d
     121             : 
     ! **************************************************************************************************
!> \brief Retrieves array from Torch dictionary. The returned array has to be deallocated by caller!
!> \param dict  dictionary to read from
!> \param key   name of the entry to fetch
!> \param dest  must not be associated on entry; on return it points to the fetched array
!> \author Ole Schuett
! **************************************************************************************************
         SUBROUTINE torch_dict_get_${typename}$_${ndims}$d(dict, key, dest)
            TYPE(torch_dict_type), INTENT(IN)                  :: dict
            CHARACTER(len=*), INTENT(IN)                       :: key
            #:set arraydims = ", ".join(":" for i in range(ndims))
            ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: dest

#if defined(__LIBTORCH)
            INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: shape_c, shape_f
            TYPE(C_PTR)                                        :: buffer

            INTERFACE
               SUBROUTINE torch_c_dict_get_${typename}$ (dict, key, ndims, sizes, dest) &
                  BIND(C, name="torch_c_dict_get_${typename}$")
                  IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T
                  TYPE(C_PTR), VALUE                           :: dict
                  CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
                  INTEGER(kind=C_INT), VALUE                   :: ndims
                  INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
                  TYPE(C_PTR)                                  :: dest
               END SUBROUTINE torch_c_dict_get_${typename}$
            END INTERFACE

            ! Sentinel values, so the checks after the call detect an output the C side did not fill.
            shape_c(:) = -1
            buffer = C_NULL_PTR

            CPASSERT(C_ASSOCIATED(dict%c_ptr))
            CPASSERT(.NOT. ASSOCIATED(dest))

            CALL torch_c_dict_get_${typename}$ (dict=dict%c_ptr, &
                                                key=TRIM(key)//C_NULL_CHAR, &
                                                ndims=${ndims}$, &
                                                sizes=shape_c, &
                                                dest=buffer)

            CPASSERT(ALL(shape_c >= 0))
            CPASSERT(C_ASSOCIATED(buffer))

            ! Row-major (C) sizes -> column-major (Fortran) shape: reverse the axis order.
            shape_f(:) = shape_c(${ndims}$:1:-1)
            CALL C_F_POINTER(buffer, dest, shape=shape_f)
#else
            CPABORT("CP2K compiled without the Torch library.")
            MARK_USED(dict)
            MARK_USED(key)
            MARK_USED(dest)
#endif
         END SUBROUTINE torch_dict_get_${typename}$_${ndims}$d
     172             : 
     173             :       #:endfor
     174             :    #:endfor
     175             : 
     ! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict handle that must not be associated yet; on return it refers to the new dictionary
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: C_PTR
            TYPE(C_PTR)                               :: dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      ! Refuse to overwrite a handle that is still associated.
      CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      ! The C side must have handed back a valid handle.
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_create
     199             : 
     ! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict associated handle; reset to C_NULL_PTR on return
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                        :: dict
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Clear the handle, so a later create/release sees it as unassociated.
      dict%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_release
     223             : 
     ! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model    handle that must not be associated yet; on return it refers to the loaded model
!> \param filename path of the model file (trailing blanks are trimmed)
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
            IMPORT :: C_PTR, C_CHAR
            TYPE(C_PTR)                               :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      ! The C side expects a NUL-terminated string.
      CALL torch_c_model_load(model%c_ptr, TRIM(filename)//C_NULL_CHAR)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(filename)
#endif
   END SUBROUTINE torch_model_load
     250             : 
     ! **************************************************************************************************
!> \brief Evaluates the given Torch model. (In Torch lingo this operation is called forward())
!> \param model   loaded model
!> \param inputs  dictionary holding the model inputs
!> \param outputs dictionary handed to the C side to receive the results
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_eval(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_eval(model, inputs, outputs) BIND(C, name="torch_c_model_eval")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                        :: model
            TYPE(C_PTR), VALUE                        :: inputs
            TYPE(C_PTR), VALUE                        :: outputs
         END SUBROUTINE torch_c_model_eval
      END INTERFACE

      ! All three handles must refer to live C-side objects.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(C_ASSOCIATED(outputs%c_ptr))
      CALL torch_c_model_eval(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(inputs)
      MARK_USED(outputs)
#endif
   END SUBROUTINE torch_model_eval
     283             : 
     ! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model associated handle; reset to C_NULL_PTR on return
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                        :: model
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Clear the handle, so a later load/release sees it as unassociated.
      model%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_release
     307             : 
     ! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the model file (trailing blanks are trimmed)
!> \param key      name of the metadata entry (trailing blanks are trimmed)
!> \return the entry's content as a deferred-length string
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: chars
      INTEGER                                            :: nchars, pos
      TYPE(C_PTR)                                        :: buffer

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: C_CHAR, C_PTR, C_INT
            CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
            TYPE(C_PTR)                               :: content
            INTEGER(kind=C_INT)                       :: length
         END SUBROUTINE torch_c_model_read_metadata
      END INTERFACE

      ! Sentinel values, so the checks after the call detect an output the C side did not fill.
      buffer = C_NULL_PTR
      nchars = -1
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
                                       key=TRIM(key)//C_NULL_CHAR, &
                                       content=buffer, &
                                       length=nchars)
      CPASSERT(C_ASSOCIATED(buffer))
      CPASSERT(nchars >= 0)

      ! View the C buffer as nchars characters followed by a terminating NUL.
      CALL C_F_POINTER(buffer, chars, shape=[nchars + 1])
      CPASSERT(chars(nchars + 1) == C_NULL_CHAR)

      ALLOCATE (CHARACTER(LEN=nchars) :: res)
      DO pos = 1, nchars
         CPASSERT(chars(pos) /= C_NULL_CHAR) ! embedded NUL characters are rejected
         res(pos:pos) = chars(pos)
      END DO

      ! The buffer was allocated on the C side.
      ! NOTE(review): releasing C-allocated memory through Fortran DEALLOCATE is not sanctioned by
      ! the Fortran standard and relies on compatible allocators - confirm against the C side.
      DEALLOCATE (chars)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(filename)
      MARK_USED(key)
      MARK_USED(res)
#endif
   END FUNCTION torch_model_read_metadata
     359             : 
     ! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return .TRUE. when the C side reports CUDA support; aborts when built without Torch
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL)                           :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Convert the C_BOOL answer into a default-kind LOGICAL.
      res = LOGICAL(torch_c_cuda_is_available())
#else
      CPABORT("CP2K was compiled without Torch library.")
      res = .FALSE.
#endif
   END FUNCTION torch_cuda_is_available
     381             : 
     ! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 flag forwarded to the C side
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                  :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                      :: flag_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL), VALUE                  :: allow_tf32
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! The C interface takes a C_BOOL by value, so convert the default LOGICAL first.
      flag_c = LOGICAL(allow_tf32, C_BOOL)
      CALL torch_c_allow_tf32(flag_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
     405             : 
     ! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimization that speed up model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model loaded model; attributes should be read with torch_model_get_attr before freezing
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                        :: model
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      ! Only a loaded model can be frozen.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_freeze
     429             : 
     #:set typenames = ['int64', 'double', 'string']
   #:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
   #:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']

   #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded model
!> \param key   attribute name (trailing blanks are trimmed)
!> \param dest  receives the attribute value
!> \author Ole Schuett
! **************************************************************************************************
      SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
         TYPE(torch_model_type), INTENT(IN)                 :: model
         CHARACTER(len=*), INTENT(IN)                       :: key
         ${type_f}$, INTENT(OUT)                            :: dest

#if defined(__LIBTORCH)

         INTERFACE
            SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
               BIND(C, name="torch_c_model_get_attr_${typename}$")
               IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
               TYPE(C_PTR), VALUE                           :: model
               CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
               ${type_c}$                                   :: dest
            END SUBROUTINE torch_c_model_get_attr_${typename}$
         END INTERFACE

         ! Same precondition as every other model procedure: never hand a null handle to C.
         CPASSERT(C_ASSOCIATED(model%c_ptr))
         CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                                   key=TRIM(key)//C_NULL_CHAR, &
                                                   dest=dest)
#else
         CPABORT("CP2K compiled without the Torch library.")
         MARK_USED(model)
         MARK_USED(key)
         MARK_USED(dest)
#endif
      END SUBROUTINE torch_model_get_attr_${typename}$
   #:endfor
     467             : 
     ! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!>        The attribute is fetched as a 64-bit integer and must fit into a default INTEGER.
!> \param model loaded model
!> \param key   attribute name
!> \param dest  receives the attribute value converted to a default INTEGER
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int32(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER, INTENT(OUT)                               :: dest

      INTEGER(kind=int_8)                                :: temp

      CALL torch_model_get_attr_int64(model, key, temp)
      ! Explicit two-sided range check. The former ABS(temp) < HUGE(dest) wrongly rejected the
      ! representable value HUGE(dest) and overflowed in ABS for the most negative int_8 value.
      CPASSERT(temp >= -HUGE(dest) .AND. temp <= HUGE(dest))
      dest = INT(temp)
   END SUBROUTINE torch_model_get_attr_int32
     482             : 
     483             : ! **************************************************************************************************
     484             : !> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
     485             : !> \author Ole Schuett
     486             : ! **************************************************************************************************
     487           4 :    SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
     488             :       TYPE(torch_model_type), INTENT(IN)                 :: model
     489             :       CHARACTER(len=*), INTENT(IN)                       :: key
     490             :       CHARACTER(LEN=default_string_length), &
     491             :          ALLOCATABLE, DIMENSION(:)                       :: dest
     492             : 
     493             : #if defined(__LIBTORCH)
     494             : 
     495             :       INTEGER :: num_items, i
     496             : 
     497             :       INTERFACE
     498             :          SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
     499             :             BIND(C, name="torch_c_model_get_attr_list_size")
     500             :             IMPORT :: C_PTR, C_CHAR, C_INT
     501             :             TYPE(C_PTR), VALUE                           :: model
     502             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     503             :             INTEGER(kind=C_INT)                          :: size
     504             :          END SUBROUTINE torch_c_model_get_attr_list_size
     505             :       END INTERFACE
     506             : 
     507             :       INTERFACE
     508             :          SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
     509             :             BIND(C, name="torch_c_model_get_attr_strlist")
     510             :             IMPORT :: C_PTR, C_CHAR, C_INT
     511             :             TYPE(C_PTR), VALUE                           :: model
     512             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     513             :             INTEGER(kind=C_INT), VALUE                   :: index
     514             :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
     515             :          END SUBROUTINE torch_c_model_get_attr_strlist
     516             :       END INTERFACE
     517             : 
     518             :       CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
     519             :                                             key=TRIM(key)//C_NULL_CHAR, &
     520           4 :                                             size=num_items)
     521          12 :       ALLOCATE (dest(num_items))
     522          12 :       dest(:) = ""
     523             : 
     524          12 :       DO i = 1, num_items
     525             :          CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
     526             :                                              key=TRIM(key)//C_NULL_CHAR, &
     527             :                                              index=i - 1, &
     528          12 :                                              dest=dest(i))
     529             : 
     530             :       END DO
     531             : #else
     532             :       CPABORT("CP2K compiled without the Torch library.")
     533             :       MARK_USED(model)
     534             :       MARK_USED(key)
     535             :       MARK_USED(dest)
     536             : #endif
     537             : 
     538           4 :    END SUBROUTINE torch_model_get_attr_strlist
     539             : 
     540           0 : END MODULE torch_api

Generated by: LCOV version 1.15