Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
MODULE torch_api
   ! Fortran bindings for the torch_c_* C interface to libtorch.
   ! Tensors, dictionaries and models are represented by derived types that each wrap a single
   ! opaque C pointer; the pointers are created and released through the torch_c_* routines.
   USE ISO_C_BINDING, ONLY: C_ASSOCIATED, C_BOOL, C_CHAR, C_DOUBLE, C_F_POINTER, C_FLOAT, &
                            C_INT, C_INT64_T, C_NULL_CHAR, C_NULL_PTR, C_PTR
   USE kinds, ONLY: default_string_length, dp, int_8, sp

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   ! Handle to a Torch tensor (null while unset).
   TYPE torch_tensor_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_tensor_type

   ! Handle to a Torch dictionary (null while unset).
   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

   ! Handle to a Torch model (null while unset).
   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type

   ! Highest array rank supported by the tensor <-> array conversions.
#:set max_dim = 3
   INTERFACE torch_tensor_from_array
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_tensor_from_array_float_${ndims}$d, &
         torch_tensor_from_array_int64_${ndims}$d, &
         torch_tensor_from_array_double_${ndims}$d
#:endfor
   END INTERFACE torch_tensor_from_array

   INTERFACE torch_tensor_data_ptr
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_tensor_data_ptr_float_${ndims}$d, &
         torch_tensor_data_ptr_int64_${ndims}$d, &
         torch_tensor_data_ptr_double_${ndims}$d
#:endfor
   END INTERFACE torch_tensor_data_ptr

   INTERFACE torch_model_get_attr
      MODULE PROCEDURE torch_model_get_attr_string, torch_model_get_attr_double, &
         torch_model_get_attr_int64, torch_model_get_attr_int32, &
         torch_model_get_attr_strlist
   END INTERFACE torch_model_get_attr

   ! Handle types
   PUBLIC :: torch_tensor_type, torch_dict_type, torch_model_type
   ! Tensors
   PUBLIC :: torch_tensor_from_array, torch_tensor_data_ptr, torch_tensor_release
   PUBLIC :: torch_tensor_backward, torch_tensor_grad
   ! Dictionaries
   PUBLIC :: torch_dict_create, torch_dict_insert, torch_dict_get, torch_dict_release
   ! Models
   PUBLIC :: torch_model_load, torch_model_forward, torch_model_release, torch_model_freeze
   PUBLIC :: torch_model_get_attr, torch_model_read_metadata
   ! Global settings
   PUBLIC :: torch_cuda_is_available, torch_allow_tf32

CONTAINS
76 :
77 : #:set typenames = ['float', 'int64', 'double']
78 : #:set types_f = ['REAL(sp)','INTEGER(kind=int_8)', 'REAL(dp)']
79 : #:set types_c = ['REAL(kind=C_FLOAT)','INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
80 :
81 : #:for ndims in range(1, max_dim+1)
82 : #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
83 :
! **************************************************************************************************
!> \brief Creates a Torch tensor that refers to the given array. The passed array has to outlive
!>        the tensor! The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor empty tensor handle which receives the new tensor
!> \param source array providing the data of the tensor
!> \param requires_grad whether the tensor requires a gradient (defaults to false)
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
   TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
   ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source
   LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if !defined(__LIBTORCH)
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(tensor)
   MARK_USED(source)
   MARK_USED(requires_grad)
#else
   INTEGER                                            :: axis
   INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
   LOGICAL                                            :: req_grad_f

   INTERFACE
      SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
         BIND(C, name="torch_c_tensor_from_array_${typename}$")
         IMPORT :: C_PTR, C_INT, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
         TYPE(C_PTR)                                  :: tensor
         LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
         INTEGER(kind=C_INT), VALUE                   :: ndims
         INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
         ${type_c}$, DIMENSION(*)                     :: source
      END SUBROUTINE torch_c_tensor_from_array_${typename}$
   END INTERFACE

   IF (PRESENT(requires_grad)) THEN
      req_grad_f = requires_grad
   ELSE
      req_grad_f = .FALSE.
   END IF

   ! C arrays are stored row-major, hence the Fortran extents are handed over in reverse order.
   DO axis = 1, ${ndims}$
      sizes_c(axis) = SIZE(source, ${ndims}$ + 1 - axis)
   END DO

   CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
   CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
                                                 req_grad=LOGICAL(req_grad_f, C_BOOL), &
                                                 ndims=${ndims}$, &
                                                 sizes=sizes_c, &
                                                 source=source)
   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#endif
END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d
132 :
! **************************************************************************************************
!> \brief Points a Fortran array pointer at the data of a Torch tensor. No data is copied: the
!>        returned pointer refers to the tensor's memory and is only valid during the tensor's
!>        lifetime!
!> \param tensor tensor whose data is accessed; must be associated
!> \param data_ptr must not be associated on entry; on return it points to the tensor's data
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
   TYPE(torch_tensor_type), INTENT(IN)                :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
   ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr

#if defined(__LIBTORCH)
   ! Same kind as the BIND(C) interface below, so no kind mismatch can arise.
   INTEGER(kind=C_INT64_T), DIMENSION(${ndims}$)      :: sizes_c, sizes_f
   TYPE(C_PTR)                                        :: data_ptr_c

   INTERFACE
      SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
         BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
         IMPORT :: C_PTR, C_INT, C_INT64_T
         TYPE(C_PTR), VALUE                           :: tensor
         INTEGER(kind=C_INT), VALUE                   :: ndims
         INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
         TYPE(C_PTR)                                  :: data_ptr
      END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
   END INTERFACE

   ! Sentinels, so that it can be verified below that the C side filled in all outputs.
   sizes_c(:) = -1
   data_ptr_c = C_NULL_PTR
   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
   CPASSERT(.NOT. ASSOCIATED(data_ptr))
   CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
                                               ndims=${ndims}$, &
                                               sizes=sizes_c, &
                                               data_ptr=data_ptr_c)

   CPASSERT(ALL(sizes_c >= 0))
   CPASSERT(C_ASSOCIATED(data_ptr_c))

#:for axis in range(ndims)
   sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
#:endfor
   CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
#else
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(tensor)
   MARK_USED(data_ptr)
#endif
END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d
180 :
181 : #:endfor
182 : #:endfor
183 :
! **************************************************************************************************
!> \brief Runs autograd on a Torch tensor, handing outer_grad to the backward pass.
!> \param tensor tensor to run autograd on; must be associated
!> \param outer_grad outer gradient passed to the backward pass; must be associated
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_tensor_backward(tensor, outer_grad)
   TYPE(torch_tensor_type), INTENT(IN)                :: tensor, outer_grad

#if !defined(__LIBTORCH)
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(tensor)
   MARK_USED(outer_grad)
#else
   CHARACTER(len=*), PARAMETER :: routineN = 'torch_tensor_backward'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
         BIND(C, name="torch_c_tensor_backward")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: tensor, outer_grad
      END SUBROUTINE torch_c_tensor_backward
   END INTERFACE

   CALL timeset(routineN, timer_handle)
   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
   CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
   CALL torch_c_tensor_backward(tensor%c_ptr, outer_grad%c_ptr)
   CALL timestop(timer_handle)
#endif
END SUBROUTINE torch_tensor_backward
216 :
! **************************************************************************************************
!> \brief Returns the gradient of a Torch tensor which was computed by autograd.
!> \param tensor tensor whose gradient is requested; must be associated
!> \param grad empty tensor handle which receives the gradient
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_tensor_grad(tensor, grad)
   TYPE(torch_tensor_type), INTENT(IN)                :: tensor
   TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if !defined(__LIBTORCH)
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(tensor)
   MARK_USED(grad)
#else
   INTERFACE
      SUBROUTINE torch_c_tensor_grad(tensor, grad) BIND(C, name="torch_c_tensor_grad")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: tensor
         TYPE(C_PTR)                                  :: grad
      END SUBROUTINE torch_c_tensor_grad
   END INTERFACE

   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
   CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
   CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)
   CPASSERT(C_ASSOCIATED(grad%c_ptr))
#endif
END SUBROUTINE torch_tensor_grad
245 :
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources. The handle is reset to null afterwards.
!> \param tensor tensor to release; must be associated
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_tensor_release(tensor)
   TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(tensor)
#else
   INTERFACE
      SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: tensor
      END SUBROUTINE torch_c_tensor_release
   END INTERFACE

   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
   CALL torch_c_tensor_release(tensor%c_ptr)
   tensor%c_ptr = C_NULL_PTR ! mark the handle as empty
#endif
END SUBROUTINE torch_tensor_release
269 :
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict empty dictionary handle which receives the new dictionary
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_create(dict)
   TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(dict)
#else
   INTERFACE
      SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
         IMPORT :: C_PTR
         TYPE(C_PTR)                                  :: dict
      END SUBROUTINE torch_c_dict_create
   END INTERFACE

   CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
   CALL torch_c_dict_create(dict%c_ptr)
   CPASSERT(C_ASSOCIATED(dict%c_ptr))
#endif
END SUBROUTINE torch_dict_create
293 :
! **************************************************************************************************
!> \brief Inserts a Torch tensor into a Torch dictionary.
!> \param dict dictionary to insert into; must be associated
!> \param key name of the entry (trailing blanks are ignored)
!> \param tensor tensor to insert; must be associated
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_insert(dict, key, tensor)
   TYPE(torch_dict_type), INTENT(INOUT)               :: dict
   CHARACTER(len=*), INTENT(IN)                       :: key
   TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if !defined(__LIBTORCH)
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(dict)
   MARK_USED(key)
   MARK_USED(tensor)
#else
   ! Null-terminated copy of the trimmed key for the C side.
   CHARACTER(LEN=LEN_TRIM(key) + 1)                   :: key_c

   INTERFACE
      SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
         BIND(C, name="torch_c_dict_insert")
         IMPORT :: C_CHAR, C_PTR
         TYPE(C_PTR), VALUE                           :: dict
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
         TYPE(C_PTR), VALUE                           :: tensor
      END SUBROUTINE torch_c_dict_insert
   END INTERFACE

   CPASSERT(C_ASSOCIATED(dict%c_ptr))
   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
   key_c = TRIM(key)//C_NULL_CHAR
   CALL torch_c_dict_insert(dict%c_ptr, key_c, tensor%c_ptr)
#endif
END SUBROUTINE torch_dict_insert
325 :
! **************************************************************************************************
!> \brief Retrieves a Torch tensor from a Torch dictionary.
!> \param dict dictionary to read from; must be associated
!> \param key name of the entry (trailing blanks are ignored)
!> \param tensor empty tensor handle which receives the entry
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_get(dict, key, tensor)
   TYPE(torch_dict_type), INTENT(IN)                  :: dict
   CHARACTER(len=*), INTENT(IN)                       :: key
   TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if !defined(__LIBTORCH)
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(dict)
   MARK_USED(key)
   MARK_USED(tensor)
#else
   ! Null-terminated copy of the trimmed key for the C side.
   CHARACTER(LEN=LEN_TRIM(key) + 1)                   :: key_c

   INTERFACE
      SUBROUTINE torch_c_dict_get(dict, key, tensor) &
         BIND(C, name="torch_c_dict_get")
         IMPORT :: C_CHAR, C_PTR
         TYPE(C_PTR), VALUE                           :: dict
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
         TYPE(C_PTR)                                  :: tensor
      END SUBROUTINE torch_c_dict_get
   END INTERFACE

   CPASSERT(C_ASSOCIATED(dict%c_ptr))
   CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
   key_c = TRIM(key)//C_NULL_CHAR
   CALL torch_c_dict_get(dict%c_ptr, key_c, tensor%c_ptr)
   CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#endif
END SUBROUTINE torch_dict_get
359 :
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources. The handle is reset to null afterwards.
!> \param dict dictionary to release; must be associated
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_release(dict)
   TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(dict)
#else
   INTERFACE
      SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: dict
      END SUBROUTINE torch_c_dict_release
   END INTERFACE

   CPASSERT(C_ASSOCIATED(dict%c_ptr))
   CALL torch_c_dict_release(dict%c_ptr)
   dict%c_ptr = C_NULL_PTR ! mark the handle as empty
#endif
END SUBROUTINE torch_dict_release
383 :
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model empty model handle which receives the loaded model
!> \param filename path of the model file (trailing blanks are ignored)
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_load(model, filename)
   TYPE(torch_model_type), INTENT(INOUT)              :: model
   CHARACTER(len=*), INTENT(IN)                       :: filename

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(model)
   MARK_USED(filename)
#else
   CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_load'

   ! Null-terminated copy of the trimmed filename for the C side.
   CHARACTER(LEN=LEN_TRIM(filename) + 1)              :: filename_c
   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
         IMPORT :: C_PTR, C_CHAR
         TYPE(C_PTR)                                  :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: filename
      END SUBROUTINE torch_c_model_load
   END INTERFACE

   CALL timeset(routineN, timer_handle)
   CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
   filename_c = TRIM(filename)//C_NULL_CHAR
   CALL torch_c_model_load(model%c_ptr, filename_c)
   CPASSERT(C_ASSOCIATED(model%c_ptr))
   CALL timestop(timer_handle)
#endif
END SUBROUTINE torch_model_load
415 :
! **************************************************************************************************
!> \brief Evaluates the given Torch model.
!> \param model model to evaluate; must be associated
!> \param inputs dictionary handed to the model as input; must be associated
!> \param outputs dictionary handed to the model for its outputs; must be associated
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_forward(model, inputs, outputs)
   TYPE(torch_model_type), INTENT(INOUT)              :: model
   TYPE(torch_dict_type), INTENT(IN)                  :: inputs
   TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(model)
   MARK_USED(inputs)
   MARK_USED(outputs)
#else
   CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_forward'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: model, inputs, outputs
      END SUBROUTINE torch_c_model_forward
   END INTERFACE

   CALL timeset(routineN, timer_handle)
   CPASSERT(C_ASSOCIATED(model%c_ptr))
   CPASSERT(C_ASSOCIATED(inputs%c_ptr))
   CPASSERT(C_ASSOCIATED(outputs%c_ptr))
   CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
   CALL timestop(timer_handle)
#endif
END SUBROUTINE torch_model_forward
451 :
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources. The handle is reset to null afterwards.
!> \param model model to release; must be associated
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_release(model)
   TYPE(torch_model_type), INTENT(INOUT)              :: model

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(model)
#else
   INTERFACE
      SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: model
      END SUBROUTINE torch_c_model_release
   END INTERFACE

   CPASSERT(C_ASSOCIATED(model%c_ptr))
   CALL torch_c_model_release(model%c_ptr)
   model%c_ptr = C_NULL_PTR ! mark the handle as empty
#endif
END SUBROUTINE torch_model_release
475 :
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the model file (trailing blanks are ignored)
!> \param key name of the metadata entry (trailing blanks are ignored)
!> \return content of the metadata entry
!> \author Ole Schuett
! **************************************************************************************************
FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
   CHARACTER(len=*), INTENT(IN)                       :: filename, key
   CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
   CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_read_metadata'

   CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
      POINTER                                         :: content_f
   INTEGER                                            :: handle, i
   INTEGER(kind=C_INT)                                :: length
   TYPE(C_PTR)                                        :: content_c

   INTERFACE
      SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
         BIND(C, name="torch_c_model_read_metadata")
         IMPORT :: C_CHAR, C_PTR, C_INT
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: filename, key
         TYPE(C_PTR)                                  :: content
         INTEGER(kind=C_INT)                          :: length
      END SUBROUTINE torch_c_model_read_metadata
   END INTERFACE

   ! The content buffer is allocated on the C side, so it has to be released on the C side too.
   ! NOTE(review): assumes torch_c_model_read_metadata allocates the buffer with malloc - confirm.
   INTERFACE
      SUBROUTINE c_free(ptr) BIND(C, name="free")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: ptr
      END SUBROUTINE c_free
   END INTERFACE

   CALL timeset(routineN, handle)
   ! Sentinels, so that it can be verified below that the C side filled in all outputs.
   content_c = C_NULL_PTR
   length = -1
   CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
                                    key=TRIM(key)//C_NULL_CHAR, &
                                    content=content_c, &
                                    length=length)
   CPASSERT(C_ASSOCIATED(content_c))
   CPASSERT(length >= 0)

   ! View the C buffer including its trailing null terminator.
   CALL C_F_POINTER(content_c, content_f, shape=[length + 1])
   CPASSERT(content_f(length + 1) == C_NULL_CHAR)

   ALLOCATE (CHARACTER(LEN=length) :: res)
   DO i = 1, length
      CPASSERT(content_f(i) /= C_NULL_CHAR)
      res(i:i) = content_f(i)
   END DO

   ! DEALLOCATE must not be applied to content_f: its target was not created by a Fortran
   ! ALLOCATE statement, so the buffer is handed back to C's free() instead.
   NULLIFY (content_f)
   CALL c_free(content_c)
   content_c = C_NULL_PTR
   CALL timestop(handle)
#else
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(filename)
   MARK_USED(key)
   MARK_USED(res)
#endif
END FUNCTION torch_model_read_metadata
532 :
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return whether the CUDA backend is available
!> \author Ole Schuett
! **************************************************************************************************
FUNCTION torch_cuda_is_available() RESULT(res)
   LOGICAL                                            :: res

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   res = .FALSE.
#else
   INTERFACE
      FUNCTION torch_c_cuda_is_available() RESULT(available) &
         BIND(C, name="torch_c_cuda_is_available")
         IMPORT :: C_BOOL
         LOGICAL(C_BOOL)                              :: available
      END FUNCTION torch_c_cuda_is_available
   END INTERFACE

   ! Convert from the C_BOOL kind to the default logical kind.
   res = LOGICAL(torch_c_cuda_is_available())
#endif
END FUNCTION torch_cuda_is_available
554 :
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 whether TF32 may be used
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE torch_allow_tf32(allow_tf32)
   LOGICAL, INTENT(IN)                                :: allow_tf32

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(allow_tf32)
#else
   LOGICAL(C_BOOL)                                    :: allow_tf32_c

   INTERFACE
      SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
         IMPORT :: C_BOOL
         LOGICAL(C_BOOL), VALUE                       :: allow_tf32
      END SUBROUTINE torch_c_allow_tf32
   END INTERFACE

   allow_tf32_c = allow_tf32 ! kind conversion to C_BOOL
   CALL torch_c_allow_tf32(allow_tf32_c)
#endif
END SUBROUTINE torch_allow_tf32
578 :
! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimization that speed up model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model model to freeze; must be associated
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE torch_model_freeze(model)
   TYPE(torch_model_type), INTENT(INOUT)              :: model

#if !defined(__LIBTORCH)
   CPABORT("CP2K was compiled without Torch library.")
   MARK_USED(model)
#else
   CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_freeze'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
         IMPORT :: C_PTR
         TYPE(C_PTR), VALUE                           :: model
      END SUBROUTINE torch_c_model_freeze
   END INTERFACE

   CALL timeset(routineN, timer_handle)
   CPASSERT(C_ASSOCIATED(model%c_ptr))
   CALL torch_c_model_freeze(model%c_ptr)
   CALL timestop(timer_handle)
#endif
END SUBROUTINE torch_model_freeze
607 :
608 : #:set typenames = ['int64', 'double', 'string']
609 : #:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
610 : #:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
611 :
612 : #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model model to query; must be associated
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest receives the value of the attribute
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   ${type_f}$, INTENT(OUT)                            :: dest

#if defined(__LIBTORCH)

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
         BIND(C, name="torch_c_model_get_attr_${typename}$")
         IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
         TYPE(C_PTR), VALUE                           :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
         ${type_c}$ :: dest
      END SUBROUTINE torch_c_model_get_attr_${typename}$
   END INTERFACE

   ! Consistent with all other routines of this module: never hand a null handle to the C side.
   CPASSERT(C_ASSOCIATED(model%c_ptr))
#:if typename == 'string'
   ! dest is INTENT(OUT) and therefore undefined on entry. Blank-fill it so that characters not
   ! written by the C side are well-defined, as done in torch_model_get_attr_strlist.
   dest = ""
#:endif
   CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                              key=TRIM(key)//C_NULL_CHAR, &
                                              dest=dest)
#else
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(model)
   MARK_USED(key)
   MARK_USED(dest)
#endif
END SUBROUTINE torch_model_get_attr_${typename}$
644 : #:endfor
645 :
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!>        The attribute is read as a 64-bit integer and converted to a default integer.
!> \param model model to query
!> \param key name of the attribute
!> \param dest receives the value; the attribute must fit into a default integer
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_int32(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   INTEGER, INTENT(OUT)                               :: dest

   INTEGER(kind=int_8)                                :: temp

   CALL torch_model_get_attr_int64(model, key, temp)
   ! Range check against [-HUGE(dest), HUGE(dest)]. Unlike ABS(temp) < HUGE(dest), this accepts
   ! HUGE(dest) itself and cannot overflow when temp is the most negative int_8 value.
   CPASSERT(temp >= -INT(HUGE(dest), int_8))
   CPASSERT(temp <= INT(HUGE(dest), int_8))
   dest = INT(temp)
END SUBROUTINE torch_model_get_attr_int32
660 :
! **************************************************************************************************
!> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model model to query; must be associated
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest receives one blank-padded string per list item; reallocated on every call
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   ! INTENT(OUT) deallocates dest on entry, so callers may pass an already allocated array.
   CHARACTER(LEN=default_string_length), &
      ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: dest

#if defined(__LIBTORCH)

   ! Null-terminated copy of the trimmed key, built once and reused for every C call.
   CHARACTER(LEN=LEN_TRIM(key) + 1)                   :: key_c
   INTEGER(kind=C_INT)                                :: i, num_items

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
         BIND(C, name="torch_c_model_get_attr_list_size")
         IMPORT :: C_PTR, C_CHAR, C_INT
         TYPE(C_PTR), VALUE                           :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
         INTEGER(kind=C_INT)                          :: size
      END SUBROUTINE torch_c_model_get_attr_list_size
   END INTERFACE

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
         BIND(C, name="torch_c_model_get_attr_strlist")
         IMPORT :: C_PTR, C_CHAR, C_INT
         TYPE(C_PTR), VALUE                           :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
         INTEGER(kind=C_INT), VALUE                   :: index
         CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
      END SUBROUTINE torch_c_model_get_attr_strlist
   END INTERFACE

   CPASSERT(C_ASSOCIATED(model%c_ptr))
   key_c = TRIM(key)//C_NULL_CHAR

   num_items = -1 ! sentinel, so that it can be verified that the C side set the size
   CALL torch_c_model_get_attr_list_size(model=model%c_ptr, key=key_c, size=num_items)
   CPASSERT(num_items >= 0)

   ALLOCATE (dest(num_items))
   dest(:) = "" ! blank-fill so that characters not written by the C side are well-defined

   DO i = 1, num_items
      ! The C side uses zero-based indices.
      CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                          key=key_c, &
                                          index=i - 1_C_INT, &
                                          dest=dest(i))
   END DO
#else
   CPABORT("CP2K compiled without the Torch library.")
   MARK_USED(model)
   MARK_USED(key)
   MARK_USED(dest)
#endif

END SUBROUTINE torch_model_get_attr_strlist
717 :
718 0 : END MODULE torch_api
|