LCOV - code coverage report
Current view: top level - src - pao_methods.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4dc10b3) Lines: 340 348 97.7 %
Date: 2024-11-21 06:45:46 Functions: 19 19 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Methods used by pao_main.F
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_methods
      13             :    USE ao_util,                         ONLY: exp_radius
      14             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      15             :                                               get_atomic_kind
      16             :    USE basis_set_types,                 ONLY: gto_basis_set_type
      17             :    USE bibliography,                    ONLY: Kolafa2004,&
      18             :                                               Kuhne2007,&
      19             :                                               cite_reference
      20             :    USE cp_control_types,                ONLY: dft_control_type
      21             :    USE cp_dbcsr_api,                    ONLY: &
      22             :         dbcsr_add, dbcsr_binary_read, dbcsr_checksum, dbcsr_complete_redistribute, dbcsr_copy, &
      23             :         dbcsr_create, dbcsr_desymmetrize, dbcsr_distribution_get, dbcsr_distribution_new, &
      24             :         dbcsr_distribution_type, dbcsr_dot, dbcsr_filter, dbcsr_get_block_p, dbcsr_get_info, &
      25             :         dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, dbcsr_iterator_start, &
      26             :         dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, dbcsr_release, &
      27             :         dbcsr_reserve_diag_blocks, dbcsr_scale, dbcsr_set, dbcsr_type
      28             :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      29             :                                               cp_logger_type,&
      30             :                                               cp_to_string
      31             :    USE dm_ls_scf_methods,               ONLY: density_matrix_trs4,&
      32             :                                               ls_scf_init_matrix_S
      33             :    USE dm_ls_scf_qs,                    ONLY: ls_scf_dm_to_ks,&
      34             :                                               ls_scf_qs_atomic_guess,&
      35             :                                               matrix_ls_to_qs,&
      36             :                                               matrix_qs_to_ls
      37             :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      38             :                                               ls_scf_env_type
      39             :    USE iterate_matrix,                  ONLY: purify_mcweeny
      40             :    USE kinds,                           ONLY: default_path_length,&
      41             :                                               dp
      42             :    USE machine,                         ONLY: m_walltime
      43             :    USE mathlib,                         ONLY: binomial,&
      44             :                                               diamat_all
      45             :    USE message_passing,                 ONLY: mp_para_env_type
      46             :    USE pao_ml,                          ONLY: pao_ml_forces
      47             :    USE pao_model,                       ONLY: pao_model_load
      48             :    USE pao_param,                       ONLY: pao_calc_AB,&
      49             :                                               pao_param_count
      50             :    USE pao_types,                       ONLY: pao_env_type
      51             :    USE particle_types,                  ONLY: particle_type
      52             :    USE qs_energy_types,                 ONLY: qs_energy_type
      53             :    USE qs_environment_types,            ONLY: get_qs_env,&
      54             :                                               qs_environment_type
      55             :    USE qs_initial_guess,                ONLY: calculate_atomic_fock_matrix
      56             :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      57             :                                               pao_descriptor_type,&
      58             :                                               pao_potential_type,&
      59             :                                               qs_kind_type,&
      60             :                                               set_qs_kind
      61             :    USE qs_ks_methods,                   ONLY: qs_ks_update_qs_env
      62             :    USE qs_ks_types,                     ONLY: qs_ks_did_change
      63             :    USE qs_rho_methods,                  ONLY: qs_rho_update_rho
      64             :    USE qs_rho_types,                    ONLY: qs_rho_get,&
      65             :                                               qs_rho_type
      66             : 
      67             : !$ USE OMP_LIB, ONLY: omp_get_level
      68             : #include "./base/base_uses.f90"
      69             : 
      70             :    IMPLICIT NONE
      71             : 
      72             :    PRIVATE
      73             : 
      74             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_methods'
      75             : 
      76             :    PUBLIC :: pao_print_atom_info, pao_init_kinds
      77             :    PUBLIC :: pao_build_orthogonalizer, pao_build_selector
      78             :    PUBLIC :: pao_build_diag_distribution
      79             :    PUBLIC :: pao_build_matrix_X, pao_build_core_hamiltonian
      80             :    PUBLIC :: pao_test_convergence
      81             :    PUBLIC :: pao_calc_energy, pao_check_trace_ps
      82             :    PUBLIC :: pao_store_P, pao_add_forces, pao_guess_initial_P
      83             :    PUBLIC :: pao_check_grad
      84             : 
      85             : CONTAINS
      86             : 
      87             : ! **************************************************************************************************
      88             : !> \brief Initialize qs kinds
      89             : !> \param pao ...
      90             : !> \param qs_env ...
      91             : ! **************************************************************************************************
      92          96 :    SUBROUTINE pao_init_kinds(pao, qs_env)
      93             :       TYPE(pao_env_type), POINTER                        :: pao
      94             :       TYPE(qs_environment_type), POINTER                 :: qs_env
      95             : 
      96             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init_kinds'
      97             : 
      98             :       CHARACTER(LEN=default_path_length)                 :: pao_model_file
      99             :       INTEGER                                            :: handle, i, ikind, pao_basis_size
     100             :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
     101          96 :       TYPE(pao_descriptor_type), DIMENSION(:), POINTER   :: pao_descriptors
     102          96 :       TYPE(pao_potential_type), DIMENSION(:), POINTER    :: pao_potentials
     103          96 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     104             : 
     105          96 :       CALL timeset(routineN, handle)
     106          96 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set)
     107             : 
     108         230 :       DO ikind = 1, SIZE(qs_kind_set)
     109             :          CALL get_qs_kind(qs_kind_set(ikind), &
     110             :                           basis_set=basis_set, &
     111             :                           pao_basis_size=pao_basis_size, &
     112             :                           pao_model_file=pao_model_file, &
     113             :                           pao_potentials=pao_potentials, &
     114         134 :                           pao_descriptors=pao_descriptors)
     115             : 
     116         134 :          IF (pao_basis_size < 1) THEN
     117             :             ! pao disabled for ikind, set pao_basis_size to size of primary basis
     118          12 :             CALL set_qs_kind(qs_kind_set(ikind), pao_basis_size=basis_set%nsgf)
     119             :          END IF
     120             : 
     121             :          ! initialize radii of Gaussians to speedup screeing later on
     122         196 :          DO i = 1, SIZE(pao_potentials)
     123         196 :             pao_potentials(i)%beta_radius = exp_radius(0, pao_potentials(i)%beta, pao%eps_pgf, 1.0_dp)
     124             :          END DO
     125         152 :          DO i = 1, SIZE(pao_descriptors)
     126          18 :             pao_descriptors(i)%beta_radius = exp_radius(0, pao_descriptors(i)%beta, pao%eps_pgf, 1.0_dp)
     127         152 :             pao_descriptors(i)%screening_radius = exp_radius(0, pao_descriptors(i)%screening, pao%eps_pgf, 1.0_dp)
     128             :          END DO
     129             : 
     130             :          ! Load torch model.
     131         364 :          IF (LEN_TRIM(pao_model_file) > 0) THEN
     132           4 :             IF (.NOT. ALLOCATED(pao%models)) &
     133          10 :                ALLOCATE (pao%models(SIZE(qs_kind_set)))
     134           4 :             CALL pao_model_load(pao, qs_env, ikind, pao_model_file, pao%models(ikind))
     135             :          END IF
     136             : 
     137             :       END DO
     138          96 :       CALL timestop(handle)
     139          96 :    END SUBROUTINE pao_init_kinds
     140             : 
     141             : ! **************************************************************************************************
     142             : !> \brief Prints a one line summary for each atom.
     143             : !> \param pao ...
     144             : ! **************************************************************************************************
     145          96 :    SUBROUTINE pao_print_atom_info(pao)
     146             :       TYPE(pao_env_type), POINTER                        :: pao
     147             : 
     148             :       INTEGER                                            :: iatom, natoms
     149          96 :       INTEGER, DIMENSION(:), POINTER                     :: pao_basis, param_cols, param_rows, &
     150          96 :                                                             pri_basis
     151             : 
     152          96 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=pri_basis, col_blk_size=pao_basis)
     153          96 :       CPASSERT(SIZE(pao_basis) == SIZE(pri_basis))
     154          96 :       natoms = SIZE(pao_basis)
     155             : 
     156          96 :       CALL dbcsr_get_info(pao%matrix_X, row_blk_size=param_rows, col_blk_size=param_cols)
     157          96 :       CPASSERT(SIZE(param_rows) == natoms .AND. SIZE(param_cols) == natoms)
     158             : 
     159          96 :       IF (pao%iw_atoms > 0) THEN
     160          12 :          DO iatom = 1, natoms
     161             :             WRITE (pao%iw_atoms, "(A,I7,T20,A,I3,T45,A,I3,T65,A,I3)") &
     162           9 :                " PAO| atom: ", iatom, &
     163           9 :                " prim_basis: ", pri_basis(iatom), &
     164           9 :                " pao_basis: ", pao_basis(iatom), &
     165          21 :                " pao_params: ", (param_cols(iatom)*param_rows(iatom))
     166             :          END DO
     167             :       END IF
     168          96 :    END SUBROUTINE pao_print_atom_info
     169             : 
     170             : ! **************************************************************************************************
     171             : !> \brief Constructs matrix_N and its inverse.
     172             : !> \param pao ...
     173             : !> \param qs_env ...
     174             : ! **************************************************************************************************
     175          96 :    SUBROUTINE pao_build_orthogonalizer(pao, qs_env)
     176             :       TYPE(pao_env_type), POINTER                        :: pao
     177             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     178             : 
     179             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_orthogonalizer'
     180             : 
     181             :       INTEGER                                            :: acol, arow, handle, i, iatom, j, k, N
     182             :       LOGICAL                                            :: found
     183             :       REAL(dp)                                           :: v, w
     184          96 :       REAL(dp), DIMENSION(:), POINTER                    :: evals
     185          96 :       REAL(dp), DIMENSION(:, :), POINTER                 :: A, block_N, block_N_inv, block_S
     186             :       TYPE(dbcsr_iterator_type)                          :: iter
     187          96 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     188             : 
     189          96 :       CALL timeset(routineN, handle)
     190             : 
     191          96 :       CALL get_qs_env(qs_env, matrix_s=matrix_s)
     192             : 
     193          96 :       CALL dbcsr_create(pao%matrix_N, template=matrix_s(1)%matrix, name="PAO matrix_N")
     194          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N)
     195             : 
     196          96 :       CALL dbcsr_create(pao%matrix_N_inv, template=matrix_s(1)%matrix, name="PAO matrix_N_inv")
     197          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N_inv)
     198             : 
     199             : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,matrix_s) &
     200          96 : !$OMP PRIVATE(iter,arow,acol,iatom,block_N,block_N_inv,block_S,found,N,A,evals,k,i,j,w,v)
     201             :       CALL dbcsr_iterator_start(iter, pao%matrix_N)
     202             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     203             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_N)
     204             :          iatom = arow; CPASSERT(arow == acol)
     205             : 
     206             :          CALL dbcsr_get_block_p(matrix=pao%matrix_N_inv, row=iatom, col=iatom, block=block_N_inv, found=found)
     207             :          CPASSERT(ASSOCIATED(block_N_inv))
     208             : 
     209             :          CALL dbcsr_get_block_p(matrix=matrix_s(1)%matrix, row=iatom, col=iatom, block=block_S, found=found)
     210             :          CPASSERT(ASSOCIATED(block_S))
     211             : 
     212             :          N = SIZE(block_S, 1); CPASSERT(SIZE(block_S, 1) == SIZE(block_S, 2)) ! primary basis size
     213             :          ALLOCATE (A(N, N), evals(N))
     214             : 
     215             :          ! take square root of atomic overlap matrix
     216             :          A = block_S
     217             :          CALL diamat_all(A, evals) !afterwards A contains the eigenvectors
     218             :          block_N = 0.0_dp
     219             :          block_N_inv = 0.0_dp
     220             :          DO k = 1, N
     221             :             ! NOTE: To maintain a consistent notation with the Berghold paper,
     222             :             ! the "_inv" is swapped: N^{-1}=sqrt(S); N=sqrt(S)^{-1}
     223             :             w = 1.0_dp/SQRT(evals(k))
     224             :             v = SQRT(evals(k))
     225             :             DO i = 1, N
     226             :                DO j = 1, N
     227             :                   block_N(i, j) = block_N(i, j) + w*A(i, k)*A(j, k)
     228             :                   block_N_inv(i, j) = block_N_inv(i, j) + v*A(i, k)*A(j, k)
     229             :                END DO
     230             :             END DO
     231             :          END DO
     232             :          DEALLOCATE (A, evals)
     233             :       END DO
     234             :       CALL dbcsr_iterator_stop(iter)
     235             : !$OMP END PARALLEL
     236             : 
     237             :       ! store a copies of N and N_inv that are distributed according to pao%diag_distribution
     238             :       CALL dbcsr_create(pao%matrix_N_diag, &
     239             :                         name="PAO matrix_N_diag", &
     240             :                         dist=pao%diag_distribution, &
     241          96 :                         template=matrix_s(1)%matrix)
     242          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N_diag)
     243          96 :       CALL dbcsr_complete_redistribute(pao%matrix_N, pao%matrix_N_diag)
     244             :       CALL dbcsr_create(pao%matrix_N_inv_diag, &
     245             :                         name="PAO matrix_N_inv_diag", &
     246             :                         dist=pao%diag_distribution, &
     247          96 :                         template=matrix_s(1)%matrix)
     248          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N_inv_diag)
     249          96 :       CALL dbcsr_complete_redistribute(pao%matrix_N_inv, pao%matrix_N_inv_diag)
     250             : 
     251          96 :       CALL timestop(handle)
     252          96 :    END SUBROUTINE pao_build_orthogonalizer
     253             : 
     254             : ! **************************************************************************************************
     255             : !> \brief Build rectangular matrix to converert between primary and PAO basis.
     256             : !> \param pao ...
     257             : !> \param qs_env ...
     258             : ! **************************************************************************************************
     259          96 :    SUBROUTINE pao_build_selector(pao, qs_env)
     260             :       TYPE(pao_env_type), POINTER                        :: pao
     261             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     262             : 
     263             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_selector'
     264             : 
     265             :       INTEGER                                            :: acol, arow, handle, i, iatom, ikind, M, &
     266             :                                                             natoms
     267          96 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_aux, blk_sizes_pri
     268          96 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_Y
     269             :       TYPE(dbcsr_iterator_type)                          :: iter
     270          96 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     271          96 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     272          96 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     273             : 
     274          96 :       CALL timeset(routineN, handle)
     275             : 
     276             :       CALL get_qs_env(qs_env, &
     277             :                       natom=natoms, &
     278             :                       matrix_s=matrix_s, &
     279             :                       qs_kind_set=qs_kind_set, &
     280          96 :                       particle_set=particle_set)
     281             : 
     282          96 :       CALL dbcsr_get_info(matrix_s(1)%matrix, col_blk_size=blk_sizes_pri)
     283             : 
     284         288 :       ALLOCATE (blk_sizes_aux(natoms))
     285         322 :       DO iatom = 1, natoms
     286         226 :          CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     287         226 :          CALL get_qs_kind(qs_kind_set(ikind), pao_basis_size=M)
     288         226 :          CPASSERT(M > 0)
     289         226 :          IF (blk_sizes_pri(iatom) < M) &
     290           0 :             CPABORT("PAO basis size exceeds primary basis size.")
     291         548 :          blk_sizes_aux(iatom) = M
     292             :       END DO
     293             : 
     294             :       CALL dbcsr_create(pao%matrix_Y, &
     295             :                         template=matrix_s(1)%matrix, &
     296             :                         matrix_type="N", &
     297             :                         row_blk_size=blk_sizes_pri, &
     298             :                         col_blk_size=blk_sizes_aux, &
     299          96 :                         name="PAO matrix_Y")
     300          96 :       DEALLOCATE (blk_sizes_aux)
     301             : 
     302          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_Y)
     303             : 
     304             : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao) &
     305          96 : !$OMP PRIVATE(iter,arow,acol,block_Y,i,M)
     306             :       CALL dbcsr_iterator_start(iter, pao%matrix_Y)
     307             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     308             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_Y)
     309             :          M = SIZE(block_Y, 2) ! size of pao basis
     310             :          block_Y = 0.0_dp
     311             :          DO i = 1, M
     312             :             block_Y(i, i) = 1.0_dp
     313             :          END DO
     314             :       END DO
     315             :       CALL dbcsr_iterator_stop(iter)
     316             : !$OMP END PARALLEL
     317             : 
     318          96 :       CALL timestop(handle)
     319          96 :    END SUBROUTINE pao_build_selector
     320             : 
     321             : ! **************************************************************************************************
     322             : !> \brief Creates new DBCSR distribution which spreads diagonal blocks evenly across ranks
     323             : !> \param pao ...
     324             : !> \param qs_env ...
     325             : ! **************************************************************************************************
     326          96 :    SUBROUTINE pao_build_diag_distribution(pao, qs_env)
     327             :       TYPE(pao_env_type), POINTER                        :: pao
     328             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     329             : 
     330             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_diag_distribution'
     331             : 
     332             :       INTEGER                                            :: handle, iatom, natoms, pgrid_cols, &
     333             :                                                             pgrid_rows
     334          96 :       INTEGER, DIMENSION(:), POINTER                     :: diag_col_dist, diag_row_dist
     335             :       TYPE(dbcsr_distribution_type)                      :: main_dist
     336          96 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     337             : 
     338          96 :       CALL timeset(routineN, handle)
     339             : 
     340          96 :       CALL get_qs_env(qs_env, natom=natoms, matrix_s=matrix_s)
     341             : 
     342             :       ! get processor grid from matrix_s
     343          96 :       CALL dbcsr_get_info(matrix=matrix_s(1)%matrix, distribution=main_dist)
     344          96 :       CALL dbcsr_distribution_get(main_dist, nprows=pgrid_rows, npcols=pgrid_cols)
     345             : 
     346             :       ! create new mapping of matrix-grid to processor-grid
     347         480 :       ALLOCATE (diag_row_dist(natoms), diag_col_dist(natoms))
     348         322 :       DO iatom = 1, natoms
     349         226 :          diag_row_dist(iatom) = MOD(iatom - 1, pgrid_rows)
     350         322 :          diag_col_dist(iatom) = MOD((iatom - 1)/pgrid_rows, pgrid_cols)
     351             :       END DO
     352             : 
     353             :       ! instanciate distribution object
     354             :       CALL dbcsr_distribution_new(pao%diag_distribution, template=main_dist, &
     355          96 :                                   row_dist=diag_row_dist, col_dist=diag_col_dist)
     356             : 
     357          96 :       DEALLOCATE (diag_row_dist, diag_col_dist)
     358             : 
     359          96 :       CALL timestop(handle)
     360         192 :    END SUBROUTINE pao_build_diag_distribution
     361             : 
     362             : ! **************************************************************************************************
     363             : !> \brief Creates the matrix_X
     364             : !> \param pao ...
     365             : !> \param qs_env ...
     366             : ! **************************************************************************************************
     367          96 :    SUBROUTINE pao_build_matrix_X(pao, qs_env)
     368             :       TYPE(pao_env_type), POINTER                        :: pao
     369             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     370             : 
     371             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_matrix_X'
     372             : 
     373             :       INTEGER                                            :: handle, iatom, ikind, natoms
     374          96 :       INTEGER, DIMENSION(:), POINTER                     :: col_blk_size, row_blk_size
     375          96 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     376             : 
     377          96 :       CALL timeset(routineN, handle)
     378             : 
     379             :       CALL get_qs_env(qs_env, &
     380             :                       natom=natoms, &
     381          96 :                       particle_set=particle_set)
     382             : 
     383             :       ! determine block-sizes of matrix_X
     384         480 :       ALLOCATE (row_blk_size(natoms), col_blk_size(natoms))
     385         322 :       col_blk_size = 1
     386         322 :       DO iatom = 1, natoms
     387         226 :          CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     388         322 :          CALL pao_param_count(pao, qs_env, ikind, nparams=row_blk_size(iatom))
     389             :       END DO
     390             : 
     391             :       ! build actual matrix_X
     392             :       CALL dbcsr_create(pao%matrix_X, &
     393             :                         name="PAO matrix_X", &
     394             :                         dist=pao%diag_distribution, &
     395             :                         matrix_type="N", &
     396             :                         row_blk_size=row_blk_size, &
     397          96 :                         col_blk_size=col_blk_size)
     398          96 :       DEALLOCATE (row_blk_size, col_blk_size)
     399             : 
     400          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_X)
     401          96 :       CALL dbcsr_set(pao%matrix_X, 0.0_dp)
     402             : 
     403          96 :       CALL timestop(handle)
     404          96 :    END SUBROUTINE pao_build_matrix_X
     405             : 
     406             : ! **************************************************************************************************
     407             : !> \brief Creates the matrix_H0 which contains the core hamiltonian
     408             : !> \param pao ...
     409             : !> \param qs_env ...
     410             : ! **************************************************************************************************
     411          96 :    SUBROUTINE pao_build_core_hamiltonian(pao, qs_env)
     412             :       TYPE(pao_env_type), POINTER                        :: pao
     413             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     414             : 
     415          96 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     416          96 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     417          96 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     418             : 
     419             :       CALL get_qs_env(qs_env, &
     420             :                       matrix_s=matrix_s, &
     421             :                       atomic_kind_set=atomic_kind_set, &
     422          96 :                       qs_kind_set=qs_kind_set)
     423             : 
     424             :       ! allocate matrix_H0
     425             :       CALL dbcsr_create(pao%matrix_H0, &
     426             :                         name="PAO matrix_H0", &
     427             :                         dist=pao%diag_distribution, &
     428          96 :                         template=matrix_s(1)%matrix)
     429          96 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_H0)
     430             : 
     431             :       ! calculate initial atomic fock matrix H0
     432             :       ! Can't use matrix_ks from ls_scf_qs_atomic_guess(), because it's not rotationally invariant.
     433             :       ! getting H0 directly from the atomic code
     434             :       CALL calculate_atomic_fock_matrix(pao%matrix_H0, &
     435             :                                         atomic_kind_set, &
     436             :                                         qs_kind_set, &
     437          96 :                                         ounit=pao%iw)
     438             : 
     439          96 :    END SUBROUTINE pao_build_core_hamiltonian
     440             : 
     441             : ! **************************************************************************************************
     442             : !> \brief Test whether the PAO optimization has reached convergence
     443             : !> \param pao ...
     444             : !> \param ls_scf_env ...
     445             : !> \param new_energy ...
     446             : !> \param is_converged ...
     447             : ! **************************************************************************************************
     448        2620 :    SUBROUTINE pao_test_convergence(pao, ls_scf_env, new_energy, is_converged)
     449             :       TYPE(pao_env_type), POINTER                        :: pao
     450             :       TYPE(ls_scf_env_type)                              :: ls_scf_env
     451             :       REAL(KIND=dp), INTENT(IN)                          :: new_energy
     452             :       LOGICAL, INTENT(OUT)                               :: is_converged
     453             : 
     454             :       REAL(KIND=dp)                                      :: energy_diff, loop_eps, now, time_diff
     455             : 
     456             :       ! calculate progress
     457        2620 :       energy_diff = new_energy - pao%energy_prev
     458        2620 :       pao%energy_prev = new_energy
     459        2620 :       now = m_walltime()
     460        2620 :       time_diff = now - pao%step_start_time
     461        2620 :       pao%step_start_time = now
     462             : 
     463             :       ! convergence criterion
     464        2620 :       loop_eps = pao%norm_G/ls_scf_env%nelectron_total
     465        2620 :       is_converged = loop_eps < pao%eps_pao
     466             : 
     467        2620 :       IF (pao%istep > 1) THEN
     468        2544 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| energy improvement:", energy_diff
     469             :          ! CPWARN_IF(energy_diff>0.0_dp, "PAO| energy increased")
     470             : 
     471             :          ! print one-liner
     472        2544 :          IF (pao%iw > 0) WRITE (pao%iw, '(A,I6,11X,F20.9,1X,E10.3,1X,E10.3,1X,F9.3)') &
     473        1272 :             " PAO| step ", &
     474        1272 :             pao%istep, &
     475        1272 :             new_energy, &
     476        1272 :             loop_eps, &
     477        1272 :             pao%linesearch%step_size, & !prev step, which let to the current energy
     478        2544 :             time_diff
     479             :       END IF
     480        2620 :    END SUBROUTINE pao_test_convergence
     481             : 
     482             : ! **************************************************************************************************
     483             : !> \brief Calculate the pao energy
     484             : !> \param pao ...
     485             : !> \param qs_env ...
     486             : !> \param ls_scf_env ...
     487             : !> \param energy ...
     488             : ! **************************************************************************************************
     489       11818 :    SUBROUTINE pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     490             :       TYPE(pao_env_type), POINTER                        :: pao
     491             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     492             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     493             :       REAL(KIND=dp), INTENT(OUT)                         :: energy
     494             : 
     495             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_calc_energy'
     496             : 
     497             :       INTEGER                                            :: handle, ispin
     498             :       REAL(KIND=dp)                                      :: penalty, trace_PH
     499             : 
     500       11818 :       CALL timeset(routineN, handle)
     501             : 
     502             :       ! calculate matrix U, which determines the pao basis
     503       11818 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE., penalty=penalty)
     504             : 
     505             :       ! calculat S, S_inv, S_sqrt, and S_sqrt_inv in the new pao basis
     506       11818 :       CALL pao_rebuild_S(qs_env, ls_scf_env)
     507             : 
     508             :       ! calculate the density matrix P in the pao basis
     509       11818 :       CALL pao_dm_trs4(qs_env, ls_scf_env)
     510             : 
     511             :       ! calculate the energy from the trace(PH) in the pao basis
     512       11818 :       energy = 0.0_dp
     513       23636 :       DO ispin = 1, ls_scf_env%nspins
     514       11818 :          CALL dbcsr_dot(ls_scf_env%matrix_p(ispin), ls_scf_env%matrix_ks(ispin), trace_PH)
     515       23636 :          energy = energy + trace_PH
     516             :       END DO
     517             : 
     518             :       ! add penalty term
     519       11818 :       energy = energy + penalty
     520             : 
     521       11818 :       IF (pao%iw > 0) THEN
     522        5909 :          WRITE (pao%iw, *) ""
     523        5909 :          WRITE (pao%iw, *) "PAO| energy:", energy, "penalty:", penalty
     524             :       END IF
     525       11818 :       CALL timestop(handle)
     526       11818 :    END SUBROUTINE pao_calc_energy
     527             : 
     528             : ! **************************************************************************************************
     529             : !> \brief Ensure that the number of electrons is correct.
     530             : !> \param ls_scf_env ...
     531             : ! **************************************************************************************************
     532       10326 :    SUBROUTINE pao_check_trace_PS(ls_scf_env)
     533             :       TYPE(ls_scf_env_type)                              :: ls_scf_env
     534             : 
     535             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_check_trace_PS'
     536             : 
     537             :       INTEGER                                            :: handle, ispin
     538             :       REAL(KIND=dp)                                      :: tmp, trace_PS
     539             :       TYPE(dbcsr_type)                                   :: matrix_S_desym
     540             : 
     541       10326 :       CALL timeset(routineN, handle)
     542       10326 :       CALL dbcsr_create(matrix_S_desym, template=ls_scf_env%matrix_s, matrix_type="N")
     543       10326 :       CALL dbcsr_desymmetrize(ls_scf_env%matrix_s, matrix_S_desym)
     544             : 
     545       10326 :       trace_PS = 0.0_dp
     546       20652 :       DO ispin = 1, ls_scf_env%nspins
     547       10326 :          CALL dbcsr_dot(ls_scf_env%matrix_p(ispin), matrix_S_desym, tmp)
     548       20652 :          trace_PS = trace_PS + tmp
     549             :       END DO
     550             : 
     551       10326 :       CALL dbcsr_release(matrix_S_desym)
     552             : 
     553       10326 :       IF (ABS(ls_scf_env%nelectron_total - trace_PS) > 0.5) &
     554           0 :          CPABORT("Number of electrons wrong. Trace(PS) ="//cp_to_string(trace_PS))
     555             : 
     556       10326 :       CALL timestop(handle)
     557       10326 :    END SUBROUTINE pao_check_trace_PS
     558             : 
     559             : ! **************************************************************************************************
     560             : !> \brief Read primary density matrix from file.
     561             : !> \param pao ...
     562             : !> \param qs_env ...
     563             : ! **************************************************************************************************
     564          56 :    SUBROUTINE pao_read_preopt_dm(pao, qs_env)
     565             :       TYPE(pao_env_type), POINTER                        :: pao
     566             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     567             : 
     568             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_read_preopt_dm'
     569             : 
     570             :       INTEGER                                            :: handle, ispin
     571             :       REAL(KIND=dp)                                      :: cs_pos
     572             :       TYPE(dbcsr_distribution_type)                      :: dist
     573          28 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s, rho_ao
     574             :       TYPE(dbcsr_type)                                   :: matrix_tmp
     575             :       TYPE(dft_control_type), POINTER                    :: dft_control
     576             :       TYPE(qs_energy_type), POINTER                      :: energy
     577             :       TYPE(qs_rho_type), POINTER                         :: rho
     578             : 
     579          28 :       CALL timeset(routineN, handle)
     580             : 
     581             :       CALL get_qs_env(qs_env, &
     582             :                       dft_control=dft_control, &
     583             :                       matrix_s=matrix_s, &
     584             :                       rho=rho, &
     585          28 :                       energy=energy)
     586             : 
     587          28 :       CALL qs_rho_get(rho, rho_ao=rho_ao)
     588             : 
     589          28 :       IF (dft_control%nspins /= 1) CPABORT("open shell not yet implemented")
     590             : 
     591          28 :       CALL dbcsr_get_info(matrix_s(1)%matrix, distribution=dist)
     592             : 
     593          56 :       DO ispin = 1, dft_control%nspins
     594          28 :          CALL dbcsr_binary_read(pao%preopt_dm_file, matrix_new=matrix_tmp, distribution=dist)
     595          28 :          cs_pos = dbcsr_checksum(matrix_tmp, pos=.TRUE.)
     596          28 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Read restart DM "// &
     597          14 :             TRIM(pao%preopt_dm_file)//" with checksum: ", cs_pos
     598          28 :          CALL dbcsr_copy(rho_ao(ispin)%matrix, matrix_tmp, keep_sparsity=.TRUE.)
     599          56 :          CALL dbcsr_release(matrix_tmp)
     600             :       END DO
     601             : 
     602             :       ! calculate corresponding ks matrix
     603          28 :       CALL qs_rho_update_rho(rho, qs_env=qs_env)
     604          28 :       CALL qs_ks_did_change(qs_env%ks_env, rho_changed=.TRUE.)
     605             :       CALL qs_ks_update_qs_env(qs_env, calculate_forces=.FALSE., &
     606          28 :                                just_energy=.FALSE., print_active=.TRUE.)
     607          28 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Quickstep energy from restart density:", energy%total
     608             : 
     609          28 :       CALL timestop(handle)
     610             : 
     611          28 :    END SUBROUTINE pao_read_preopt_dm
     612             : 
     613             : ! **************************************************************************************************
     614             : !> \brief Rebuilds S, S_inv, S_sqrt, and S_sqrt_inv in the pao basis
     615             : !> \param qs_env ...
     616             : !> \param ls_scf_env ...
     617             : ! **************************************************************************************************
     618       11818 :    SUBROUTINE pao_rebuild_S(qs_env, ls_scf_env)
     619             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     620             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     621             : 
     622             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_rebuild_S'
     623             : 
     624             :       INTEGER                                            :: handle
     625       11818 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     626             : 
     627       11818 :       CALL timeset(routineN, handle)
     628             : 
     629       11818 :       CALL dbcsr_release(ls_scf_env%matrix_s_inv)
     630       11818 :       CALL dbcsr_release(ls_scf_env%matrix_s_sqrt)
     631       11818 :       CALL dbcsr_release(ls_scf_env%matrix_s_sqrt_inv)
     632             : 
     633       11818 :       CALL get_qs_env(qs_env, matrix_s=matrix_s)
     634       11818 :       CALL ls_scf_init_matrix_s(matrix_s(1)%matrix, ls_scf_env)
     635             : 
     636       11818 :       CALL timestop(handle)
     637       11818 :    END SUBROUTINE pao_rebuild_S
     638             : 
     639             : ! **************************************************************************************************
     640             : !> \brief Calculate density matrix using TRS4 purification
     641             : !> \param qs_env ...
     642             : !> \param ls_scf_env ...
     643             : ! **************************************************************************************************
     644       11818 :    SUBROUTINE pao_dm_trs4(qs_env, ls_scf_env)
     645             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     646             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     647             : 
     648             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_dm_trs4'
     649             : 
     650             :       CHARACTER(LEN=default_path_length)                 :: project_name
     651             :       INTEGER                                            :: handle, ispin, nelectron_spin_real, nspin
     652             :       LOGICAL                                            :: converged
     653             :       REAL(KIND=dp)                                      :: homo_spin, lumo_spin, mu_spin
     654             :       TYPE(cp_logger_type), POINTER                      :: logger
     655       11818 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks
     656             : 
     657       11818 :       CALL timeset(routineN, handle)
     658       11818 :       logger => cp_get_default_logger()
     659       11818 :       project_name = logger%iter_info%project_name
     660       11818 :       nspin = ls_scf_env%nspins
     661             : 
     662       11818 :       CALL get_qs_env(qs_env, matrix_ks=matrix_ks)
     663       23636 :       DO ispin = 1, nspin
     664             :          CALL matrix_qs_to_ls(ls_scf_env%matrix_ks(ispin), matrix_ks(ispin)%matrix, &
     665       11818 :                               ls_scf_env%ls_mstruct, covariant=.TRUE.)
     666             : 
     667       11818 :          nelectron_spin_real = ls_scf_env%nelectron_spin(ispin)
     668       11818 :          IF (ls_scf_env%nspins == 1) nelectron_spin_real = nelectron_spin_real/2
     669             :          CALL density_matrix_trs4(ls_scf_env%matrix_p(ispin), ls_scf_env%matrix_ks(ispin), &
     670             :                                   ls_scf_env%matrix_s_sqrt_inv, &
     671             :                                   nelectron_spin_real, ls_scf_env%eps_filter, homo_spin, lumo_spin, mu_spin, &
     672             :                                   dynamic_threshold=.FALSE., converged=converged, &
     673             :                                   max_iter_lanczos=ls_scf_env%max_iter_lanczos, &
     674       11818 :                                   eps_lanczos=ls_scf_env%eps_lanczos)
     675       23636 :          IF (.NOT. converged) CPABORT("TRS4 did not converge")
     676             :       END DO
     677             : 
     678       11818 :       IF (nspin == 1) CALL dbcsr_scale(ls_scf_env%matrix_p(1), 2.0_dp)
     679             : 
     680       11818 :       CALL timestop(handle)
     681       11818 :    END SUBROUTINE pao_dm_trs4
     682             : 
     683             : ! **************************************************************************************************
     684             : !> \brief Debugging routine for checking the analytic gradient.
     685             : !> \param pao ...
     686             : !> \param qs_env ...
     687             : !> \param ls_scf_env ...
     688             : ! **************************************************************************************************
     689        2632 :    SUBROUTINE pao_check_grad(pao, qs_env, ls_scf_env)
     690             :       TYPE(pao_env_type), POINTER                        :: pao
     691             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     692             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     693             : 
     694             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_check_grad'
     695             : 
     696             :       INTEGER                                            :: handle, i, iatom, j, natoms
     697        2620 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_col, blk_sizes_row
     698             :       LOGICAL                                            :: found
     699             :       REAL(dp)                                           :: delta, delta_max, eps, Gij_num
     700        2620 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_G, block_X
     701             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     702             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     703             : 
     704        2608 :       IF (pao%check_grad_tol < 0.0_dp) RETURN ! no checking
     705             : 
     706          12 :       CALL timeset(routineN, handle)
     707             : 
     708          12 :       ls_mstruct => ls_scf_env%ls_mstruct
     709             : 
     710          12 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natoms)
     711             : 
     712          12 :       eps = pao%num_grad_eps
     713          12 :       delta_max = 0.0_dp
     714             : 
     715          12 :       CALL dbcsr_get_info(pao%matrix_X, col_blk_size=blk_sizes_col, row_blk_size=blk_sizes_row)
     716             : 
     717             :       ! can not use an iterator here, because other DBCSR routines are called within loop.
     718          38 :       DO iatom = 1, natoms
     719          26 :          IF (pao%iw > 0) WRITE (pao%iw, *) 'PAO| checking gradient of atom ', iatom
     720          26 :          CALL dbcsr_get_block_p(matrix=pao%matrix_X, row=iatom, col=iatom, block=block_X, found=found)
     721             : 
     722          26 :          IF (ASSOCIATED(block_X)) THEN !only one node actually has the block
     723          13 :             CALL dbcsr_get_block_p(matrix=pao%matrix_G, row=iatom, col=iatom, block=block_G, found=found)
     724          13 :             CPASSERT(ASSOCIATED(block_G))
     725             :          END IF
     726             : 
     727         586 :          DO i = 1, blk_sizes_row(iatom)
     728        1070 :             DO j = 1, blk_sizes_col(iatom)
     729         828 :                SELECT CASE (pao%num_grad_order)
     730             :                CASE (2) ! calculate derivative to 2th order
     731         306 :                   Gij_num = -eval_point(block_X, i, j, -eps, pao, ls_scf_env, qs_env)
     732         306 :                   Gij_num = Gij_num + eval_point(block_X, i, j, +eps, pao, ls_scf_env, qs_env)
     733         306 :                   Gij_num = Gij_num/(2.0_dp*eps)
     734             : 
     735             :                CASE (4) ! calculate derivative to 4th order
     736         180 :                   Gij_num = eval_point(block_X, i, j, -2_dp*eps, pao, ls_scf_env, qs_env)
     737         180 :                   Gij_num = Gij_num - 8_dp*eval_point(block_X, i, j, -1_dp*eps, pao, ls_scf_env, qs_env)
     738         180 :                   Gij_num = Gij_num + 8_dp*eval_point(block_X, i, j, +1_dp*eps, pao, ls_scf_env, qs_env)
     739         180 :                   Gij_num = Gij_num - eval_point(block_X, i, j, +2_dp*eps, pao, ls_scf_env, qs_env)
     740         180 :                   Gij_num = Gij_num/(12.0_dp*eps)
     741             : 
     742             :                CASE (6) ! calculate derivative to 6th order
     743          36 :                   Gij_num = -1_dp*eval_point(block_X, i, j, -3_dp*eps, pao, ls_scf_env, qs_env)
     744          36 :                   Gij_num = Gij_num + 9_dp*eval_point(block_X, i, j, -2_dp*eps, pao, ls_scf_env, qs_env)
     745          36 :                   Gij_num = Gij_num - 45_dp*eval_point(block_X, i, j, -1_dp*eps, pao, ls_scf_env, qs_env)
     746          36 :                   Gij_num = Gij_num + 45_dp*eval_point(block_X, i, j, +1_dp*eps, pao, ls_scf_env, qs_env)
     747          36 :                   Gij_num = Gij_num - 9_dp*eval_point(block_X, i, j, +2_dp*eps, pao, ls_scf_env, qs_env)
     748          36 :                   Gij_num = Gij_num + 1_dp*eval_point(block_X, i, j, +3_dp*eps, pao, ls_scf_env, qs_env)
     749          36 :                   Gij_num = Gij_num/(60.0_dp*eps)
     750             : 
     751             :                CASE DEFAULT
     752         522 :                   CPABORT("Unsupported numerical derivative order: "//cp_to_string(pao%num_grad_order))
     753             :                END SELECT
     754             : 
     755        1044 :                IF (ASSOCIATED(block_X)) THEN
     756         261 :                   delta = ABS(Gij_num - block_G(i, j))
     757         261 :                   delta_max = MAX(delta_max, delta)
     758             :                   !WRITE (*,*) "gradient check", iatom, i, j, Gij_num, block_G(i,j), delta
     759             :                END IF
     760             :             END DO
     761             :          END DO
     762             :       END DO
     763             : 
     764          12 :       CALL para_env%max(delta_max)
     765          12 :       IF (pao%iw > 0) WRITE (pao%iw, *) 'PAO| checked gradient, max delta:', delta_max
     766          12 :       IF (delta_max > pao%check_grad_tol) CALL cp_abort(__LOCATION__, &
     767           0 :                                                         "Analytic and numeric gradients differ too much:"//cp_to_string(delta_max))
     768             : 
     769          12 :       CALL timestop(handle)
     770        2620 :    END SUBROUTINE pao_check_grad
     771             : 
     772             : ! **************************************************************************************************
     773             : !> \brief Helper routine for pao_check_grad()
     774             : !> \param block_X ...
     775             : !> \param i ...
     776             : !> \param j ...
     777             : !> \param eps ...
     778             : !> \param pao ...
     779             : !> \param ls_scf_env ...
     780             : !> \param qs_env ...
     781             : !> \return ...
     782             : ! **************************************************************************************************
     783        3096 :    FUNCTION eval_point(block_X, i, j, eps, pao, ls_scf_env, qs_env) RESULT(energy)
     784             :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
     785             :       INTEGER, INTENT(IN)                                :: i, j
     786             :       REAL(dp), INTENT(IN)                               :: eps
     787             :       TYPE(pao_env_type), POINTER                        :: pao
     788             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     789             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     790             :       REAL(dp)                                           :: energy
     791             : 
     792             :       REAL(dp)                                           :: old_Xij
     793             : 
     794        1548 :       IF (ASSOCIATED(block_X)) THEN
     795         774 :          old_Xij = block_X(i, j) ! backup old block_X
     796         774 :          block_X(i, j) = block_X(i, j) + eps ! add perturbation
     797             :       END IF
     798             : 
     799             :       ! calculate energy
     800        1548 :       CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     801             : 
     802             :       ! restore old block_X
     803        1548 :       IF (ASSOCIATED(block_X)) THEN
     804         774 :          block_X(i, j) = old_Xij
     805             :       END IF
     806             : 
     807        1548 :    END FUNCTION eval_point
     808             : 
     809             : ! **************************************************************************************************
     810             : !> \brief Stores density matrix as initial guess for next SCF optimization.
     811             : !> \param qs_env ...
     812             : !> \param ls_scf_env ...
     813             : ! **************************************************************************************************
     814         560 :    SUBROUTINE pao_store_P(qs_env, ls_scf_env)
     815             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     816             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     817             : 
     818             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_store_P'
     819             : 
     820             :       INTEGER                                            :: handle, ispin, istore
     821         280 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     822             :       TYPE(dft_control_type), POINTER                    :: dft_control
     823             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     824             :       TYPE(pao_env_type), POINTER                        :: pao
     825             : 
     826           0 :       IF (ls_scf_env%scf_history%nstore == 0) RETURN
     827         280 :       CALL timeset(routineN, handle)
     828         280 :       ls_mstruct => ls_scf_env%ls_mstruct
     829         280 :       pao => ls_scf_env%pao_env
     830         280 :       CALL get_qs_env(qs_env, dft_control=dft_control, matrix_s=matrix_s)
     831             : 
     832         280 :       ls_scf_env%scf_history%istore = ls_scf_env%scf_history%istore + 1
     833         280 :       istore = MOD(ls_scf_env%scf_history%istore - 1, ls_scf_env%scf_history%nstore) + 1
     834         280 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Storing density matrix for ASPC guess in slot:", istore
     835             : 
     836             :       ! initialize storage
     837         280 :       IF (ls_scf_env%scf_history%istore <= ls_scf_env%scf_history%nstore) THEN
     838         212 :          DO ispin = 1, dft_control%nspins
     839         212 :             CALL dbcsr_create(ls_scf_env%scf_history%matrix(ispin, istore), template=matrix_s(1)%matrix)
     840             :          END DO
     841             :       END IF
     842             : 
     843             :       ! We are storing the density matrix in the non-orthonormal primary basis.
     844             :       ! While the orthonormal basis would yield better extrapolations,
     845             :       ! we simply can not afford to calculat S_sqrt in the primary basis.
     846         560 :       DO ispin = 1, dft_control%nspins
     847             :          ! transform into primary basis
     848             :          CALL matrix_ls_to_qs(ls_scf_env%scf_history%matrix(ispin, istore), ls_scf_env%matrix_p(ispin), &
     849         560 :                               ls_scf_env%ls_mstruct, covariant=.FALSE., keep_sparsity=.FALSE.)
     850             :       END DO
     851             : 
     852         280 :       CALL timestop(handle)
     853         280 :    END SUBROUTINE pao_store_P
     854             : 
     855             : ! **************************************************************************************************
     856             : !> \brief Provide an initial guess for the density matrix
     857             : !> \param pao ...
     858             : !> \param qs_env ...
     859             : !> \param ls_scf_env ...
     860             : ! **************************************************************************************************
     861         280 :    SUBROUTINE pao_guess_initial_P(pao, qs_env, ls_scf_env)
     862             :       TYPE(pao_env_type), POINTER                        :: pao
     863             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     864             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     865             : 
     866             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_guess_initial_P'
     867             : 
     868             :       INTEGER                                            :: handle
     869             : 
     870         280 :       CALL timeset(routineN, handle)
     871             : 
     872         280 :       IF (ls_scf_env%scf_history%istore > 0) THEN
     873         184 :          CALL pao_aspc_guess_P(pao, qs_env, ls_scf_env)
     874         184 :          pao%need_initial_scf = .TRUE.
     875             :       ELSE
     876          96 :          IF (LEN_TRIM(pao%preopt_dm_file) > 0) THEN
     877          28 :             CALL pao_read_preopt_dm(pao, qs_env)
     878          28 :             pao%need_initial_scf = .FALSE.
     879          28 :             pao%preopt_dm_file = "" ! load only for first MD step
     880             :          ELSE
     881          68 :             CALL ls_scf_qs_atomic_guess(qs_env, ls_scf_env, ls_scf_env%energy_init)
     882          68 :             IF (pao%iw > 0) WRITE (pao%iw, '(A,F20.9)') &
     883          34 :                " PAO| Energy from initial atomic guess:", ls_scf_env%energy_init
     884          68 :             pao%need_initial_scf = .TRUE.
     885             :          END IF
     886             :       END IF
     887             : 
     888         280 :       CALL timestop(handle)
     889             : 
     890         280 :    END SUBROUTINE pao_guess_initial_P
     891             : 
     892             : ! **************************************************************************************************
     893             : !> \brief Run the Always Stable Predictor-Corrector to guess an initial density matrix
     894             : !> \param pao ...
     895             : !> \param qs_env ...
     896             : !> \param ls_scf_env ...
     897             : ! **************************************************************************************************
     898         184 :    SUBROUTINE pao_aspc_guess_P(pao, qs_env, ls_scf_env)
     899             :       TYPE(pao_env_type), POINTER                        :: pao
     900             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     901             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     902             : 
     903             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_aspc_guess_P'
     904             : 
     905             :       INTEGER                                            :: handle, iaspc, ispin, istore, naspc
     906             :       REAL(dp)                                           :: alpha
     907         184 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     908             :       TYPE(dbcsr_type)                                   :: matrix_P
     909             :       TYPE(dft_control_type), POINTER                    :: dft_control
     910             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     911             : 
     912         184 :       CALL timeset(routineN, handle)
     913         184 :       ls_mstruct => ls_scf_env%ls_mstruct
     914         184 :       CPASSERT(ls_scf_env%scf_history%istore > 0)
     915         184 :       CALL cite_reference(Kolafa2004)
     916         184 :       CALL cite_reference(Kuhne2007)
     917         184 :       CALL get_qs_env(qs_env, dft_control=dft_control, matrix_s=matrix_s)
     918             : 
     919         184 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Calculating initial guess with ASPC"
     920             : 
     921         184 :       CALL dbcsr_create(matrix_P, template=matrix_s(1)%matrix)
     922             : 
     923         184 :       naspc = MIN(ls_scf_env%scf_history%istore, ls_scf_env%scf_history%nstore)
     924         368 :       DO ispin = 1, dft_control%nspins
     925             :          ! actual extrapolation
     926         184 :          CALL dbcsr_set(matrix_P, 0.0_dp)
     927         392 :          DO iaspc = 1, naspc
     928             :             alpha = (-1.0_dp)**(iaspc + 1)*REAL(iaspc, KIND=dp)* &
     929         208 :                     binomial(2*naspc, naspc - iaspc)/binomial(2*naspc - 2, naspc - 1)
     930         208 :             istore = MOD(ls_scf_env%scf_history%istore - iaspc, ls_scf_env%scf_history%nstore) + 1
     931         392 :             CALL dbcsr_add(matrix_P, ls_scf_env%scf_history%matrix(ispin, istore), 1.0_dp, alpha)
     932             :          END DO
     933             : 
     934             :          ! transform back from primary basis into pao basis
     935         368 :          CALL matrix_qs_to_ls(ls_scf_env%matrix_p(ispin), matrix_P, ls_scf_env%ls_mstruct, covariant=.FALSE.)
     936             :       END DO
     937             : 
     938         184 :       CALL dbcsr_release(matrix_P)
     939             : 
     940             :       ! linear combination of P's is not idempotent. A bit of McWeeny is needed to ensure it is again
     941         368 :       DO ispin = 1, dft_control%nspins
     942         184 :          IF (dft_control%nspins == 1) CALL dbcsr_scale(ls_scf_env%matrix_p(ispin), 0.5_dp)
     943             :          ! to ensure that noisy blocks do not build up during MD (in particular with curvy) filter that guess a bit more
     944         184 :          CALL dbcsr_filter(ls_scf_env%matrix_p(ispin), ls_scf_env%eps_filter**(2.0_dp/3.0_dp))
     945             :          ! we could go to the orthonomal basis, but it seems not worth the trouble
     946             :          ! TODO : 10 iterations is a conservative upper bound, figure out when it fails
     947         184 :          CALL purify_mcweeny(ls_scf_env%matrix_p(ispin:ispin), ls_scf_env%matrix_s, ls_scf_env%eps_filter, 10)
     948         368 :          IF (dft_control%nspins == 1) CALL dbcsr_scale(ls_scf_env%matrix_p(ispin), 2.0_dp)
     949             :       END DO
     950             : 
     951         184 :       CALL pao_check_trace_PS(ls_scf_env) ! sanity check
     952             : 
     953             :       ! compute corresponding energy and ks matrix
     954         184 :       CALL ls_scf_dm_to_ks(qs_env, ls_scf_env, ls_scf_env%energy_init, iscf=0)
     955             : 
     956         184 :       CALL timestop(handle)
     957         184 :    END SUBROUTINE pao_aspc_guess_P
     958             : 
     959             : ! **************************************************************************************************
     960             : !> \brief Calculate the forces contributed by PAO
     961             : !> \param qs_env ...
     962             : !> \param ls_scf_env ...
     963             : ! **************************************************************************************************
     964          42 :    SUBROUTINE pao_add_forces(qs_env, ls_scf_env)
     965             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     966             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     967             : 
     968             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_add_forces'
     969             : 
     970             :       INTEGER                                            :: handle, iatom, natoms
     971          42 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: forces
     972             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     973             :       TYPE(pao_env_type), POINTER                        :: pao
     974          42 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     975             : 
     976          42 :       CALL timeset(routineN, handle)
     977          42 :       pao => ls_scf_env%pao_env
     978             : 
     979          42 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Adding forces."
     980             : 
     981          42 :       IF (pao%max_pao /= 0) THEN
     982          20 :          IF (pao%penalty_strength /= 0.0_dp) &
     983           0 :             CPABORT("PAO forces require PENALTY_STRENGTH or MAX_PAO set to zero")
     984          20 :          IF (pao%linpot_regu_strength /= 0.0_dp) &
     985           0 :             CPABORT("PAO forces require LINPOT_REGULARIZATION_STRENGTH or MAX_PAO set to zero")
     986          20 :          IF (pao%regularization /= 0.0_dp) &
     987           0 :             CPABORT("PAO forces require REGULARIZATION or MAX_PAO set to zero")
     988             :       END IF
     989             : 
     990             :       CALL get_qs_env(qs_env, &
     991             :                       para_env=para_env, &
     992             :                       particle_set=particle_set, &
     993          42 :                       natom=natoms)
     994             : 
     995         126 :       ALLOCATE (forces(natoms, 3))
     996          42 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., forces=forces) ! without penalty terms
     997             : 
     998          42 :       IF (SIZE(pao%ml_training_set) > 0) &
     999          18 :          CALL pao_ml_forces(pao, qs_env, pao%matrix_G, forces)
    1000             : 
    1001          42 :       IF (ALLOCATED(pao%models)) &
    1002           0 :          CPABORT("PAO forces for PyTorch models are not yet implemented.")
    1003             : 
    1004          42 :       CALL para_env%sum(forces)
    1005         136 :       DO iatom = 1, natoms
    1006         418 :          particle_set(iatom)%f = particle_set(iatom)%f + forces(iatom, :)
    1007             :       END DO
    1008             : 
    1009          42 :       DEALLOCATE (forces)
    1010             : 
    1011          42 :       CALL timestop(handle)
    1012             : 
    1013          42 :    END SUBROUTINE pao_add_forces
    1014             : 
    1015             : END MODULE pao_methods

Generated by: LCOV version 1.15