LCOV - code coverage report
Current view: top level - src - pao_main.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:2fce0f8) Lines: 138 139 99.3 %
Date: 2024-12-21 06:28:57 Functions: 5 5 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Main module for the PAO method
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_main
      13             :    USE bibliography,                    ONLY: Schuett2018,&
      14             :                                               cite_reference
      15             :    USE cp_dbcsr_api,                    ONLY: dbcsr_add,&
      16             :                                               dbcsr_copy,&
      17             :                                               dbcsr_create,&
      18             :                                               dbcsr_p_type,&
      19             :                                               dbcsr_release,&
      20             :                                               dbcsr_reserve_diag_blocks,&
      21             :                                               dbcsr_set,&
      22             :                                               dbcsr_type
      23             :    USE cp_external_control,             ONLY: external_control
      24             :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      25             :                                               ls_scf_env_type
      26             :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      27             :                                               section_vals_type
      28             :    USE kinds,                           ONLY: dp
      29             :    USE linesearch,                      ONLY: linesearch_finalize,&
      30             :                                               linesearch_init,&
      31             :                                               linesearch_reset,&
      32             :                                               linesearch_step
      33             :    USE machine,                         ONLY: m_walltime
      34             :    USE pao_input,                       ONLY: parse_pao_section
      35             :    USE pao_io,                          ONLY: pao_read_restart,&
      36             :                                               pao_write_ks_matrix_csr,&
      37             :                                               pao_write_restart,&
      38             :                                               pao_write_s_matrix_csr
      39             :    USE pao_methods,                     ONLY: &
      40             :         pao_add_forces, pao_build_core_hamiltonian, pao_build_diag_distribution, &
      41             :         pao_build_matrix_X, pao_build_orthogonalizer, pao_build_selector, pao_calc_energy, &
      42             :         pao_check_grad, pao_check_trace_ps, pao_guess_initial_P, pao_init_kinds, &
      43             :         pao_print_atom_info, pao_store_P, pao_test_convergence
      44             :    USE pao_ml,                          ONLY: pao_ml_init,&
      45             :                                               pao_ml_predict
      46             :    USE pao_model,                       ONLY: pao_model_predict
      47             :    USE pao_optimizer,                   ONLY: pao_opt_finalize,&
      48             :                                               pao_opt_init,&
      49             :                                               pao_opt_new_dir
      50             :    USE pao_param,                       ONLY: pao_calc_AB,&
      51             :                                               pao_param_finalize,&
      52             :                                               pao_param_init,&
      53             :                                               pao_param_initial_guess
      54             :    USE pao_types,                       ONLY: pao_env_type
      55             :    USE qs_environment_types,            ONLY: get_qs_env,&
      56             :                                               qs_environment_type
      57             : #include "./base/base_uses.f90"
      58             : 
      59             :    IMPLICIT NONE
      60             : 
      61             :    PRIVATE
      62             : 
      63             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_main'
      64             : 
      65             :    PUBLIC :: pao_init, pao_update, pao_post_scf, pao_optimization_start, pao_optimization_end
      66             : 
      67             : CONTAINS
      68             : 
      69             : ! **************************************************************************************************
      70             : !> \brief Initialize the PAO environment
      71             : !> \param qs_env ...
      72             : !> \param ls_scf_env ...
      73             : ! **************************************************************************************************
      74         438 :    SUBROUTINE pao_init(qs_env, ls_scf_env)
      75             :       TYPE(qs_environment_type), POINTER                 :: qs_env
      76             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
      77             : 
      78             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init'
      79             : 
      80             :       INTEGER                                            :: handle
      81         342 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
      82             :       TYPE(pao_env_type), POINTER                        :: pao
      83             :       TYPE(section_vals_type), POINTER                   :: input
      84             : 
      85         246 :       IF (.NOT. ls_scf_env%do_pao) RETURN
      86             : 
      87          96 :       CALL timeset(routineN, handle)
      88          96 :       CALL cite_reference(Schuett2018)
      89          96 :       pao => ls_scf_env%pao_env
      90          96 :       CALL get_qs_env(qs_env=qs_env, input=input, matrix_s=matrix_s)
      91             : 
      92             :       ! parse input
      93          96 :       CALL parse_pao_section(pao, input)
      94             : 
      95          96 :       CALL pao_init_kinds(pao, qs_env)
      96             : 
      97             :       ! train machine learning
      98          96 :       CALL pao_ml_init(pao, qs_env)
      99             : 
     100          96 :       CALL timestop(handle)
     101         342 :    END SUBROUTINE pao_init
     102             : 
     103             : ! **************************************************************************************************
     104             : !> \brief Start a PAO optimization run.
     105             : !> \param qs_env ...
     106             : !> \param ls_scf_env ...
     107             : ! **************************************************************************************************
     108         878 :    SUBROUTINE pao_optimization_start(qs_env, ls_scf_env)
     109             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     110             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     111             : 
     112             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_start'
     113             : 
     114             :       INTEGER                                            :: handle
     115             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     116             :       TYPE(pao_env_type), POINTER                        :: pao
     117             :       TYPE(section_vals_type), POINTER                   :: input, section
     118             : 
     119         598 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     120             : 
     121         280 :       CALL timeset(routineN, handle)
     122         280 :       CALL get_qs_env(qs_env, input=input)
     123         280 :       pao => ls_scf_env%pao_env
     124         280 :       ls_mstruct => ls_scf_env%ls_mstruct
     125             : 
     126             :       ! reset state
     127         280 :       pao%step_start_time = m_walltime()
     128         280 :       pao%istep = 0
     129         280 :       pao%matrix_P_ready = .FALSE.
     130             : 
     131             :       ! ready stuff that does not depend on atom positions
     132         280 :       IF (.NOT. pao%constants_ready) THEN
     133          96 :          CALL pao_build_diag_distribution(pao, qs_env)
     134          96 :          CALL pao_build_orthogonalizer(pao, qs_env)
     135          96 :          CALL pao_build_selector(pao, qs_env)
     136          96 :          CALL pao_build_core_hamiltonian(pao, qs_env)
     137          96 :          pao%constants_ready = .TRUE.
     138             :       END IF
     139             : 
     140         280 :       CALL pao_param_init(pao, qs_env)
     141             : 
     142             :       ! ready PAO parameter matrix_X
     143         280 :       IF (.NOT. pao%matrix_X_ready) THEN
     144          96 :          CALL pao_build_matrix_X(pao, qs_env)
     145          96 :          CALL pao_print_atom_info(pao)
     146          96 :          IF (LEN_TRIM(pao%restart_file) > 0) THEN
     147           8 :             CALL pao_read_restart(pao, qs_env)
     148          88 :          ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     149          18 :             CALL pao_ml_predict(pao, qs_env)
     150          70 :          ELSE IF (ALLOCATED(pao%models)) THEN
     151           2 :             CALL pao_model_predict(pao, qs_env)
     152             :          ELSE
     153          68 :             CALL pao_param_initial_guess(pao, qs_env)
     154             :          END IF
     155          96 :          pao%matrix_X_ready = .TRUE.
     156         184 :       ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     157         120 :          CALL pao_ml_predict(pao, qs_env)
     158          64 :       ELSE IF (ALLOCATED(pao%models)) THEN
     159           0 :          CALL pao_model_predict(pao, qs_env)
     160             :       ELSE
     161          64 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| reusing matrix_X from previous optimization"
     162             :       END IF
     163             : 
     164             :       ! init line-search
     165         280 :       section => section_vals_get_subs_vals(input, "DFT%LS_SCF%PAO%LINE_SEARCH")
     166         280 :       CALL linesearch_init(pao%linesearch, section, "PAO|")
     167             : 
     168             :       ! create some more matrices
     169         280 :       CALL dbcsr_copy(pao%matrix_G, pao%matrix_X)
     170         280 :       CALL dbcsr_set(pao%matrix_G, 0.0_dp)
     171             : 
     172         280 :       CALL dbcsr_create(ls_mstruct%matrix_A, template=pao%matrix_Y)
     173         280 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_A)
     174         280 :       CALL dbcsr_create(ls_mstruct%matrix_B, template=pao%matrix_Y)
     175         280 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_B)
     176             : 
     177             :       ! fill PAO transformation matrices
     178         280 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE.)
     179             : 
     180         280 :       CALL timestop(handle)
     181             :    END SUBROUTINE pao_optimization_start
     182             : 
     183             : ! **************************************************************************************************
     184             : !> \brief Called after the SCF optimization, updates the PAO basis.
     185             : !> \param qs_env ...
     186             : !> \param ls_scf_env ...
     187             : !> \param pao_is_done ...
     188             : ! **************************************************************************************************
     189        1062 :    SUBROUTINE pao_update(qs_env, ls_scf_env, pao_is_done)
     190             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     191             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     192             :       LOGICAL, INTENT(OUT)                               :: pao_is_done
     193             : 
     194             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_update'
     195             : 
     196             :       INTEGER                                            :: handle, icycle
     197             :       LOGICAL                                            :: cycle_converged, do_mixing, should_stop
     198             :       REAL(KIND=dp)                                      :: energy, penalty
     199             :       TYPE(dbcsr_type)                                   :: matrix_X_mixing
     200             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     201             :       TYPE(pao_env_type), POINTER                        :: pao
     202             : 
     203         816 :       IF (.NOT. ls_scf_env%do_pao) THEN
     204         318 :          pao_is_done = .TRUE.
     205         570 :          RETURN
     206             :       END IF
     207             : 
     208         498 :       ls_mstruct => ls_scf_env%ls_mstruct
     209         498 :       pao => ls_scf_env%pao_env
     210             : 
     211         498 :       IF (.NOT. pao%matrix_P_ready) THEN
     212         280 :          CALL pao_guess_initial_P(pao, qs_env, ls_scf_env)
     213         280 :          pao%matrix_P_ready = .TRUE.
     214             :       END IF
     215             : 
     216         498 :       IF (pao%max_pao == 0) THEN
     217         204 :          pao_is_done = .TRUE.
     218         204 :          RETURN
     219             :       END IF
     220             : 
     221         294 :       IF (pao%need_initial_scf) THEN
     222          48 :          pao_is_done = .FALSE.
     223          48 :          pao%need_initial_scf = .FALSE.
     224          48 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Performing initial SCF optimization."
     225          48 :          RETURN
     226             :       END IF
     227             : 
     228         246 :       CALL timeset(routineN, handle)
     229             : 
     230             :       ! perform mixing once we are well into the optimization
     231         246 :       do_mixing = pao%mixing /= 1.0_dp .AND. pao%istep > 1
     232             :       IF (do_mixing) THEN
     233         128 :          CALL dbcsr_copy(matrix_X_mixing, pao%matrix_X)
     234             :       END IF
     235             : 
     236         246 :       cycle_converged = .FALSE.
     237         246 :       icycle = 0
     238         246 :       CALL linesearch_reset(pao%linesearch)
     239         246 :       CALL pao_opt_init(pao)
     240             : 
     241       20056 :       DO WHILE (.TRUE.)
     242       10142 :          pao%istep = pao%istep + 1
     243             : 
     244       15213 :          IF (pao%iw > 0) WRITE (pao%iw, "(A,I9,A)") " PAO| ======================= Iteration: ", &
     245       10142 :             pao%istep, " ============================="
     246             : 
     247             :          ! calc energy and check trace_PS
     248       10142 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     249       10142 :          CALL pao_check_trace_PS(ls_scf_env)
     250             : 
     251       10142 :          IF (pao%linesearch%starts) THEN
     252        2620 :             icycle = icycle + 1
     253             :             ! calc new gradient including penalty terms
     254        2620 :             CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., penalty=penalty)
     255        2620 :             CALL pao_check_grad(pao, qs_env, ls_scf_env)
     256             : 
     257             :             ! calculate new direction for line-search
     258        2620 :             CALL pao_opt_new_dir(pao, icycle)
     259             : 
     260             :             !backup X
     261        2620 :             CALL dbcsr_copy(pao%matrix_X_orig, pao%matrix_X)
     262             : 
     263             :             ! print info and convergence test
     264        2620 :             CALL pao_test_convergence(pao, ls_scf_env, energy, cycle_converged)
     265        2620 :             IF (cycle_converged) THEN
     266         212 :                pao_is_done = icycle < 3
     267         212 :                IF (pao_is_done .AND. pao%iw > 0) WRITE (pao%iw, *) "PAO| converged after ", pao%istep, " steps :-)"
     268             :                EXIT
     269             :             END IF
     270             : 
     271             :             ! if we have reached the maximum number of cycles exit in order
     272             :             ! to restart with a fresh hamiltonian
     273        2408 :             IF (icycle >= pao%max_cycles) THEN
     274          16 :                IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| CG not yet converged after ", icycle, " cylces."
     275          16 :                pao_is_done = .FALSE.
     276          16 :                EXIT
     277             :             END IF
     278             : 
     279        2392 :             IF (MOD(icycle, pao%write_cycles) == 0) &
     280           8 :                CALL pao_write_restart(pao, qs_env, energy) ! write an intermediate restart file
     281             :          END IF
     282             : 
     283             :          ! check for early abort without convergence?
     284        9914 :          CALL external_control(should_stop, "PAO", start_time=qs_env%start_time, target_time=qs_env%target_time)
     285        9914 :          IF (should_stop .OR. pao%istep >= pao%max_pao) THEN
     286          18 :             CPWARN("PAO not converged!")
     287          18 :             pao_is_done = .TRUE.
     288          18 :             EXIT
     289             :          END IF
     290             : 
     291             :          ! perform line-search step
     292        9896 :          CALL linesearch_step(pao%linesearch, energy=energy, slope=pao%norm_G**2)
     293             : 
     294        9896 :          IF (pao%linesearch%step_size < 1e-9_dp) CPABORT("PAO gradient is wrong.")
     295             : 
     296        9896 :          CALL dbcsr_copy(pao%matrix_X, pao%matrix_X_orig) !restore X
     297        9896 :          CALL dbcsr_add(pao%matrix_X, pao%matrix_D, 1.0_dp, pao%linesearch%step_size)
     298             :       END DO
     299             : 
     300             :       ! perform mixing of matrix_X
     301         246 :       IF (do_mixing) THEN
     302         128 :          CALL dbcsr_add(pao%matrix_X, matrix_X_mixing, pao%mixing, 1.0_dp - pao%mixing)
     303         128 :          CALL dbcsr_release(matrix_X_mixing)
     304         128 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Recalculating energy after mixing."
     305         128 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     306             :       END IF
     307             : 
     308         246 :       CALL pao_write_restart(pao, qs_env, energy)
     309         246 :       CALL pao_opt_finalize(pao)
     310             : 
     311         246 :       CALL timestop(handle)
     312         816 :    END SUBROUTINE pao_update
     313             : 
     314             : ! **************************************************************************************************
     315             : !> \brief Calculate PAO forces and store density matrix for future ASPC extrapolations
     316             : !> \param qs_env ...
     317             : !> \param ls_scf_env ...
     318             : !> \param pao_is_done ...
     319             : ! **************************************************************************************************
     320        1096 :    SUBROUTINE pao_post_scf(qs_env, ls_scf_env, pao_is_done)
     321             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     322             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     323             :       LOGICAL, INTENT(IN)                                :: pao_is_done
     324             : 
     325             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_post_scf'
     326             : 
     327             :       INTEGER                                            :: handle
     328             : 
     329        1034 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     330         498 :       IF (.NOT. pao_is_done) RETURN
     331             : 
     332         280 :       CALL timeset(routineN, handle)
     333             : 
     334             :       ! print out the matrices here before pao_store_P converts them back into matrices in
     335             :       ! terms of the primary basis
     336         280 :       CALL pao_write_ks_matrix_csr(qs_env, ls_scf_env)
     337         280 :       CALL pao_write_s_matrix_csr(qs_env, ls_scf_env)
     338             : 
     339         280 :       CALL pao_store_P(qs_env, ls_scf_env)
     340         280 :       IF (ls_scf_env%calculate_forces) CALL pao_add_forces(qs_env, ls_scf_env)
     341             : 
     342         280 :       CALL timestop(handle)
     343             :    END SUBROUTINE
     344             : 
     345             : ! **************************************************************************************************
     346             : !> \brief Finish a PAO optimization run.
     347             : !> \param ls_scf_env ...
     348             : ! **************************************************************************************************
     349         878 :    SUBROUTINE pao_optimization_end(ls_scf_env)
     350             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     351             : 
     352             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_end'
     353             : 
     354             :       INTEGER                                            :: handle
     355             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     356             :       TYPE(pao_env_type), POINTER                        :: pao
     357             : 
     358         598 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     359             : 
     360         280 :       pao => ls_scf_env%pao_env
     361         280 :       ls_mstruct => ls_scf_env%ls_mstruct
     362             : 
     363         280 :       CALL timeset(routineN, handle)
     364             : 
     365         280 :       CALL pao_param_finalize(pao)
     366             : 
     367             :       ! We keep pao%matrix_X for next scf-run, e.g. during MD or GEO-OPT
     368         280 :       CALL dbcsr_release(pao%matrix_X_orig)
     369         280 :       CALL dbcsr_release(pao%matrix_G)
     370         280 :       CALL dbcsr_release(ls_mstruct%matrix_A)
     371         280 :       CALL dbcsr_release(ls_mstruct%matrix_B)
     372             : 
     373         280 :       CALL linesearch_finalize(pao%linesearch)
     374             : 
     375         280 :       CALL timestop(handle)
     376             :    END SUBROUTINE pao_optimization_end
     377             : 
     378             : END MODULE pao_main

Generated by: LCOV version 1.15