Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
! **************************************************************************************************
!> \brief Fortran wrappers around the torch_c_* C functions: dictionaries of named arrays,
!>        Torch models with their attributes and metadata, and global Torch settings.
! **************************************************************************************************
MODULE torch_api
   USE ISO_C_BINDING, ONLY: C_ASSOCIATED, C_BOOL, C_CHAR, C_DOUBLE, C_FLOAT, C_F_POINTER, &
                            C_INT, C_INT64_T, C_NULL_CHAR, C_NULL_PTR, C_PTR
   USE kinds, ONLY: default_string_length, dp, int_8, sp

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   ! Dictionary API.
   PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_release, &
             torch_dict_insert, torch_dict_get

   ! Model API and global settings.
   PUBLIC :: torch_model_type, torch_model_load, torch_model_eval, torch_model_release, &
             torch_model_get_attr, torch_model_read_metadata, &
             torch_cuda_is_available, torch_allow_tf32, torch_model_freeze

   ! Handle to a dictionary created by torch_c_dict_create; null while none is attached.
   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

   ! Handle to a model loaded by torch_c_model_load; null while none is attached.
   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type

   ! Highest array rank handled by torch_dict_insert / torch_dict_get.
#:set max_dim = 3

   ! One specific procedure per element type (float, int64, double) and rank 1..max_dim.
   INTERFACE torch_dict_insert
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_dict_insert_float_${ndims}$d
      MODULE PROCEDURE torch_dict_insert_int64_${ndims}$d
      MODULE PROCEDURE torch_dict_insert_double_${ndims}$d
#:endfor
   END INTERFACE torch_dict_insert

   ! Same type/rank combinations as torch_dict_insert.
   INTERFACE torch_dict_get
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_dict_get_float_${ndims}$d
      MODULE PROCEDURE torch_dict_get_int64_${ndims}$d
      MODULE PROCEDURE torch_dict_get_double_${ndims}$d
#:endfor
   END INTERFACE torch_dict_get

   ! Model attributes, dispatched on the type of the destination argument.
   INTERFACE torch_model_get_attr
      MODULE PROCEDURE torch_model_get_attr_string
      MODULE PROCEDURE torch_model_get_attr_double
      MODULE PROCEDURE torch_model_get_attr_int64
      MODULE PROCEDURE torch_model_get_attr_int32
      MODULE PROCEDURE torch_model_get_attr_strlist
   END INTERFACE torch_model_get_attr

CONTAINS
70 :
#:set typenames = ['float', 'int64', 'double']
#:set types_f = ['REAL(sp)','INTEGER(kind=int_8)', 'REAL(dp)']
#:set types_c = ['REAL(kind=C_FLOAT)','INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']

#:for ndims in range(1, max_dim+1)
#:for typename, type_f, type_c in zip(typenames, types_f, types_c)
#:set arraydims = ", ".join(":" for i in range(ndims))

! **************************************************************************************************
!> \brief Inserts array into Torch dictionary. The passed array has to outlive the dictionary!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param dict dictionary created by torch_dict_create
!> \param key name under which the array is stored
!> \param source allocated array; it must remain allocated for as long as the dictionary exists
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source

#if defined(__LIBTORCH)
      ! Declared with the C kind so it matches the C_INT64_T dummy of the interface exactly.
      INTEGER(kind=C_INT64_T), DIMENSION(${ndims}$)      :: sizes_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_${typename}$ (dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_${typename}$")
            IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T, C_FLOAT, C_DOUBLE
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            ${type_c}$, DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_${typename}$
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      ! SIZE() is undefined for an unallocated array, and there would be no data to pass on.
      CPASSERT(ALLOCATED(source))

      ! Fortran arrays are column-major while C arrays are row-major, so the axes are reversed.
#:for axis in range(ndims)
      sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$)
#:endfor

      CALL torch_c_dict_insert_${typename}$ (dict=dict%c_ptr, &
                                            key=TRIM(key)//C_NULL_CHAR, &
                                            ndims=INT(${ndims}$, kind=C_INT), &
                                            sizes=sizes_c, &
                                            source=source)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(source)
#endif
   END SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d

! **************************************************************************************************
!> \brief Retrieves array from Torch dictionary. The returned array has to be deallocated by caller!
!> \param dict dictionary holding the array
!> \param key name of the array
!> \param dest must not be associated on entry; on return it points at the array data
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_${typename}$_${ndims}$d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: dest

#if defined(__LIBTORCH)
      ! Declared with the C kind so it matches the C_INT64_T dummy of the interface exactly.
      INTEGER(kind=C_INT64_T), DIMENSION(${ndims}$)      :: sizes_c
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f
      TYPE(C_PTR)                                        :: dest_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_${typename}$ (dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_${typename}$")
            IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_${typename}$
      END INTERFACE

      ! Sentinels, so that missing output from the C side is caught by the checks below.
      sizes_c(:) = -1
      dest_c = C_NULL_PTR
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(.NOT. ASSOCIATED(dest))
      CALL torch_c_dict_get_${typename}$ (dict=dict%c_ptr, &
                                         key=TRIM(key)//C_NULL_CHAR, &
                                         ndims=INT(${ndims}$, kind=C_INT), &
                                         sizes=sizes_c, &
                                         dest=dest_c)

      CPASSERT(ALL(sizes_c >= 0))
      CPASSERT(C_ASSOCIATED(dest_c))

      ! Reverse the axes again: row-major C extents become column-major Fortran extents.
#:for axis in range(ndims)
      sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$)
#:endfor
      CALL C_F_POINTER(dest_c, dest, shape=sizes_f)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(dest)
#endif
   END SUBROUTINE torch_dict_get_${typename}$_${ndims}$d

#:endfor
#:endfor
175 :
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict handle that must not refer to a dictionary yet; afterwards it refers to a new one
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: C_PTR
            TYPE(C_PTR) :: dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      ! Refuse to overwrite a handle that already refers to a dictionary.
      CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      ! The C side must have filled in the handle.
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_create
199 :
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict handle to release; it is reset to a null handle afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: dict
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      ! Releasing a handle that was never created (or was already released) is a usage error.
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Reset the handle so it no longer refers to the released dictionary.
      dict%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_release
223 :
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model handle that must not refer to a model yet; afterwards it refers to the loaded one
!> \param filename path of the model file (trailing blanks are ignored)
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      CHARACTER(LEN=:, KIND=C_CHAR), ALLOCATABLE         :: c_filename

      INTERFACE
         SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
            IMPORT :: C_PTR, C_CHAR
            TYPE(C_PTR) :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: filename
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      ! The C side expects a null-terminated string without the Fortran blank padding.
      c_filename = TRIM(filename)//C_NULL_CHAR
      CALL torch_c_model_load(model%c_ptr, c_filename)
      ! The C side must have filled in the handle.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(filename)
#endif
   END SUBROUTINE torch_model_load
250 :
! **************************************************************************************************
!> \brief Evaluates the given Torch model. (In Torch lingo this operation is called forward())
!> \param model loaded model
!> \param inputs dictionary handed to the model as input
!> \param outputs dictionary handed to the C side to receive the results
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_eval(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_eval(model, inputs, outputs) BIND(C, name="torch_c_model_eval")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: model
            TYPE(C_PTR), VALUE :: inputs
            TYPE(C_PTR), VALUE :: outputs
         END SUBROUTINE torch_c_model_eval
      END INTERFACE

      ! All three handles must already exist; only their C pointers are passed on.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(C_ASSOCIATED(outputs%c_ptr))
      CALL torch_c_model_eval(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(inputs)
      MARK_USED(outputs)
#endif
   END SUBROUTINE torch_model_eval
283 :
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model handle to release; it is reset to a null handle afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: model
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      ! Releasing a handle that was never loaded (or was already released) is a usage error.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Reset the handle so it no longer refers to the released model.
      model%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_release
307 :
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the model file
!> \param key name of the metadata entry
!> \return content of the entry, copied into a Fortran string of exactly its length
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: content_f
      INTEGER                                            :: i
      ! Declared with the C kind so it matches the C_INT dummy of the interface exactly.
      INTEGER(kind=C_INT)                                :: length
      TYPE(C_PTR)                                        :: content_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: C_CHAR, C_PTR, C_INT
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: filename, key
            TYPE(C_PTR) :: content
            INTEGER(kind=C_INT) :: length
         END SUBROUTINE torch_c_model_read_metadata

         ! C library free(), used to release the buffer returned by the C side.
         SUBROUTINE c_free(ptr) BIND(C, name="free")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: ptr
         END SUBROUTINE c_free
      END INTERFACE

      ! Sentinels, so that missing output from the C side is caught by the checks below.
      content_c = C_NULL_PTR
      length = -1
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
                                       key=TRIM(key)//C_NULL_CHAR, &
                                       content=content_c, &
                                       length=length)
      CPASSERT(C_ASSOCIATED(content_c))
      CPASSERT(length >= 0)

      ! View the C buffer, including its terminating null character, as a Fortran array.
      CALL C_F_POINTER(content_c, content_f, shape=(/length + 1/))
      CPASSERT(content_f(length + 1) == C_NULL_CHAR)

      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO i = 1, length
         CPASSERT(content_f(i) /= C_NULL_CHAR)
         res(i:i) = content_f(i)
      END DO

      ! The buffer was allocated on the C side. DEALLOCATE is only valid for memory obtained
      ! through a Fortran ALLOCATE statement, so hand it back to the C allocator instead.
      ! NOTE(review): assumes torch_c_model_read_metadata allocates the buffer with malloc() - confirm.
      NULLIFY (content_f)
      CALL c_free(content_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(filename)
      MARK_USED(key)
      MARK_USED(res)
#endif
   END FUNCTION torch_model_read_metadata
359 :
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return answer reported by torch_c_cuda_is_available, as a default LOGICAL
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: available_c

      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL) :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Query as C_BOOL, then convert to the default LOGICAL kind on assignment.
      available_c = torch_c_cuda_is_available()
      res = available_c
#else
      CPABORT("CP2K was compiled without Torch library.")
      res = .FALSE.
#endif
   END FUNCTION torch_cuda_is_available
381 :
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 flag forwarded to torch_c_allow_tf32
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: allow_tf32_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL), VALUE :: allow_tf32
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! Convert the default LOGICAL to the interoperable C_BOOL kind before the call.
      allow_tf32_c = allow_tf32
      CALL torch_c_allow_tf32(allow_tf32_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
405 :
! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimization that speed up model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model loaded model; the handle itself is left unchanged
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: model
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      ! A model must have been loaded before it can be frozen.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_freeze
429 :
#:set typenames = ['int64', 'double', 'string']
#:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
#:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']

#:for typename, type_f, type_c in zip(typenames, types_f, types_c)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded model
!> \param key name of the attribute
!> \param dest receives the attribute value
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      ${type_f}$, INTENT(OUT)                            :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_${typename}$")
            IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
            TYPE(C_PTR), VALUE :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            ${type_c}$ :: dest
         END SUBROUTINE torch_c_model_get_attr_${typename}$
      END INTERFACE

      ! Like the other wrappers, reject a null model handle before calling into C.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
#:if typename == 'string'
      ! Blank the INTENT(OUT) buffer first, as the strlist variant does for its entries, so that
      ! characters the C side does not write are blanks rather than undefined.
      ! NOTE(review): presumably the C side copies only the attribute's characters - confirm.
      dest = ""
#:endif

      CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                               key=TRIM(key)//C_NULL_CHAR, &
                                               dest=dest)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(model)
      MARK_USED(key)
      MARK_USED(dest)
#endif
   END SUBROUTINE torch_model_get_attr_${typename}$
#:endfor
467 :
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!>        Reads the attribute as a 64-bit integer and narrows it to a default INTEGER.
!> \param model loaded model
!> \param key name of the attribute
!> \param dest receives the attribute value; the value must fit into a default INTEGER
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int32(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER, INTENT(OUT)                               :: dest

      INTEGER(kind=int_8)                                :: temp

      CALL torch_model_get_attr_int64(model, key, temp)
      ! Range check done in int_8 without ABS(), which would overflow for the most negative
      ! int_8 value. HUGE(dest) itself is accepted, since it is a valid default INTEGER.
      CPASSERT(temp <= INT(HUGE(dest), kind=int_8))
      CPASSERT(temp >= -INT(HUGE(dest), kind=int_8))
      dest = INT(temp)
   END SUBROUTINE torch_model_get_attr_int32
482 :
! **************************************************************************************************
!> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded model
!> \param key name of the list attribute
!> \param dest must be unallocated on entry; allocated here with one entry per list item
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:)                       :: dest

#if defined(__LIBTORCH)

      INTEGER                                            :: i
      ! Declared with the C kind so it matches the C_INT dummy of the interface exactly.
      INTEGER(kind=C_INT)                                :: num_items

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
            BIND(C, name="torch_c_model_get_attr_list_size")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT) :: size
         END SUBROUTINE torch_c_model_get_attr_list_size

         SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
            BIND(C, name="torch_c_model_get_attr_strlist")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: index
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: dest
         END SUBROUTINE torch_c_model_get_attr_strlist
      END INTERFACE

      ! Like the other wrappers, reject a null model handle before calling into C.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      ! Same convention as torch_dict_get: the caller passes an empty destination.
      CPASSERT(.NOT. ALLOCATED(dest))

      ! Sentinel, so that a missing size from the C side is caught below.
      num_items = -1
      CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
                                            key=TRIM(key)//C_NULL_CHAR, &
                                            size=num_items)
      CPASSERT(num_items >= 0)

      ALLOCATE (dest(num_items))
      ! Blank all entries so characters the C side does not write are blanks, not undefined.
      dest(:) = ""

      DO i = 1, num_items
         ! The C side indexes list items from zero.
         CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                             key=TRIM(key)//C_NULL_CHAR, &
                                             index=INT(i - 1, kind=C_INT), &
                                             dest=dest(i))
      END DO
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(model)
      MARK_USED(key)
      MARK_USED(dest)
#endif

   END SUBROUTINE torch_model_get_attr_strlist

END MODULE torch_api
|