LCOV - code coverage report
Current view: top level - src - pao_methods.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:4c33f95) Lines: 341 348 98.0 %
Date: 2025-01-30 06:53:08 Functions: 19 19 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \brief Methods used by pao_main.F
      10             : !> \author Ole Schuett
      11             : ! **************************************************************************************************
      12             : MODULE pao_methods
      13             :    USE ao_util,                         ONLY: exp_radius
      14             :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      15             :                                               get_atomic_kind
      16             :    USE basis_set_types,                 ONLY: gto_basis_set_type
      17             :    USE bibliography,                    ONLY: Kolafa2004,&
      18             :                                               Kuhne2007,&
      19             :                                               cite_reference
      20             :    USE cp_control_types,                ONLY: dft_control_type
      21             :    USE cp_dbcsr_api,                    ONLY: &
      22             :         dbcsr_add, dbcsr_binary_read, dbcsr_checksum, dbcsr_complete_redistribute, dbcsr_copy, &
      23             :         dbcsr_create, dbcsr_desymmetrize, dbcsr_distribution_get, dbcsr_distribution_new, &
      24             :         dbcsr_distribution_type, dbcsr_dot, dbcsr_filter, dbcsr_get_block_p, dbcsr_get_info, &
      25             :         dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, dbcsr_iterator_start, &
      26             :         dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, dbcsr_release, &
      27             :         dbcsr_reserve_diag_blocks, dbcsr_scale, dbcsr_set, dbcsr_type
      28             :    USE cp_log_handling,                 ONLY: cp_get_default_logger,&
      29             :                                               cp_logger_type,&
      30             :                                               cp_to_string
      31             :    USE dm_ls_scf_methods,               ONLY: density_matrix_trs4,&
      32             :                                               ls_scf_init_matrix_S
      33             :    USE dm_ls_scf_qs,                    ONLY: ls_scf_dm_to_ks,&
      34             :                                               ls_scf_qs_atomic_guess,&
      35             :                                               matrix_ls_to_qs,&
      36             :                                               matrix_qs_to_ls
      37             :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      38             :                                               ls_scf_env_type
      39             :    USE iterate_matrix,                  ONLY: purify_mcweeny
      40             :    USE kinds,                           ONLY: default_path_length,&
      41             :                                               dp
      42             :    USE machine,                         ONLY: m_walltime
      43             :    USE mathlib,                         ONLY: binomial,&
      44             :                                               diamat_all
      45             :    USE message_passing,                 ONLY: mp_para_env_type
      46             :    USE pao_ml,                          ONLY: pao_ml_forces
      47             :    USE pao_model,                       ONLY: pao_model_forces,&
      48             :                                               pao_model_load
      49             :    USE pao_param,                       ONLY: pao_calc_AB,&
      50             :                                               pao_param_count
      51             :    USE pao_types,                       ONLY: pao_env_type
      52             :    USE particle_types,                  ONLY: particle_type
      53             :    USE qs_energy_types,                 ONLY: qs_energy_type
      54             :    USE qs_environment_types,            ONLY: get_qs_env,&
      55             :                                               qs_environment_type
      56             :    USE qs_initial_guess,                ONLY: calculate_atomic_fock_matrix
      57             :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      58             :                                               pao_descriptor_type,&
      59             :                                               pao_potential_type,&
      60             :                                               qs_kind_type,&
      61             :                                               set_qs_kind
      62             :    USE qs_ks_methods,                   ONLY: qs_ks_update_qs_env
      63             :    USE qs_ks_types,                     ONLY: qs_ks_did_change
      64             :    USE qs_rho_methods,                  ONLY: qs_rho_update_rho
      65             :    USE qs_rho_types,                    ONLY: qs_rho_get,&
      66             :                                               qs_rho_type
      67             : 
      68             : !$ USE OMP_LIB, ONLY: omp_get_level
      69             : #include "./base/base_uses.f90"
      70             : 
      71             :    IMPLICIT NONE
      72             : 
      73             :    PRIVATE
      74             : 
      75             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_methods'
      76             : 
      77             :    PUBLIC :: pao_print_atom_info, pao_init_kinds
      78             :    PUBLIC :: pao_build_orthogonalizer, pao_build_selector
      79             :    PUBLIC :: pao_build_diag_distribution
      80             :    PUBLIC :: pao_build_matrix_X, pao_build_core_hamiltonian
      81             :    PUBLIC :: pao_test_convergence
      82             :    PUBLIC :: pao_calc_energy, pao_check_trace_ps
      83             :    PUBLIC :: pao_store_P, pao_add_forces, pao_guess_initial_P
      84             :    PUBLIC :: pao_check_grad
      85             : 
      86             : CONTAINS
      87             : 
      88             : ! **************************************************************************************************
      89             : !> \brief Initialize qs kinds
      90             : !> \param pao ...
      91             : !> \param qs_env ...
      92             : ! **************************************************************************************************
      93          98 :    SUBROUTINE pao_init_kinds(pao, qs_env)
      94             :       TYPE(pao_env_type), POINTER                        :: pao
      95             :       TYPE(qs_environment_type), POINTER                 :: qs_env
      96             : 
      97             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init_kinds'
      98             : 
      99             :       CHARACTER(LEN=default_path_length)                 :: pao_model_file
     100             :       INTEGER                                            :: handle, i, ikind, pao_basis_size
     101             :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
     102          98 :       TYPE(pao_descriptor_type), DIMENSION(:), POINTER   :: pao_descriptors
     103          98 :       TYPE(pao_potential_type), DIMENSION(:), POINTER    :: pao_potentials
     104          98 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     105             : 
     106          98 :       CALL timeset(routineN, handle)
     107          98 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set)
     108             : 
     109         236 :       DO ikind = 1, SIZE(qs_kind_set)
     110             :          CALL get_qs_kind(qs_kind_set(ikind), &
     111             :                           basis_set=basis_set, &
     112             :                           pao_basis_size=pao_basis_size, &
     113             :                           pao_model_file=pao_model_file, &
     114             :                           pao_potentials=pao_potentials, &
     115         138 :                           pao_descriptors=pao_descriptors)
     116             : 
     117         138 :          IF (pao_basis_size < 1) THEN
     118             :             ! pao disabled for ikind, set pao_basis_size to size of primary basis
     119          12 :             CALL set_qs_kind(qs_kind_set(ikind), pao_basis_size=basis_set%nsgf)
     120             :          END IF
     121             : 
     122             :          ! initialize radii of Gaussians to speedup screeing later on
     123         200 :          DO i = 1, SIZE(pao_potentials)
     124         200 :             pao_potentials(i)%beta_radius = exp_radius(0, pao_potentials(i)%beta, pao%eps_pgf, 1.0_dp)
     125             :          END DO
     126         156 :          DO i = 1, SIZE(pao_descriptors)
     127          18 :             pao_descriptors(i)%beta_radius = exp_radius(0, pao_descriptors(i)%beta, pao%eps_pgf, 1.0_dp)
     128         156 :             pao_descriptors(i)%screening_radius = exp_radius(0, pao_descriptors(i)%screening, pao%eps_pgf, 1.0_dp)
     129             :          END DO
     130             : 
     131             :          ! Load torch model.
     132         374 :          IF (LEN_TRIM(pao_model_file) > 0) THEN
     133           8 :             IF (.NOT. ALLOCATED(pao%models)) &
     134          20 :                ALLOCATE (pao%models(SIZE(qs_kind_set)))
     135           8 :             CALL pao_model_load(pao, qs_env, ikind, pao_model_file, pao%models(ikind))
     136             :          END IF
     137             : 
     138             :       END DO
     139          98 :       CALL timestop(handle)
     140          98 :    END SUBROUTINE pao_init_kinds
     141             : 
     142             : ! **************************************************************************************************
     143             : !> \brief Prints a one line summary for each atom.
     144             : !> \param pao ...
     145             : ! **************************************************************************************************
     146          98 :    SUBROUTINE pao_print_atom_info(pao)
     147             :       TYPE(pao_env_type), POINTER                        :: pao
     148             : 
     149             :       INTEGER                                            :: iatom, natoms
     150          98 :       INTEGER, DIMENSION(:), POINTER                     :: pao_basis, param_cols, param_rows, &
     151          98 :                                                             pri_basis
     152             : 
     153          98 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=pri_basis, col_blk_size=pao_basis)
     154          98 :       CPASSERT(SIZE(pao_basis) == SIZE(pri_basis))
     155          98 :       natoms = SIZE(pao_basis)
     156             : 
     157          98 :       CALL dbcsr_get_info(pao%matrix_X, row_blk_size=param_rows, col_blk_size=param_cols)
     158          98 :       CPASSERT(SIZE(param_rows) == natoms .AND. SIZE(param_cols) == natoms)
     159             : 
     160          98 :       IF (pao%iw_atoms > 0) THEN
     161          12 :          DO iatom = 1, natoms
     162             :             WRITE (pao%iw_atoms, "(A,I7,T20,A,I3,T45,A,I3,T65,A,I3)") &
     163           9 :                " PAO| atom: ", iatom, &
     164           9 :                " prim_basis: ", pri_basis(iatom), &
     165           9 :                " pao_basis: ", pao_basis(iatom), &
     166          21 :                " pao_params: ", (param_cols(iatom)*param_rows(iatom))
     167             :          END DO
     168             :       END IF
     169          98 :    END SUBROUTINE pao_print_atom_info
     170             : 
     171             : ! **************************************************************************************************
     172             : !> \brief Constructs matrix_N and its inverse.
     173             : !> \param pao ...
     174             : !> \param qs_env ...
     175             : ! **************************************************************************************************
     176          98 :    SUBROUTINE pao_build_orthogonalizer(pao, qs_env)
     177             :       TYPE(pao_env_type), POINTER                        :: pao
     178             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     179             : 
     180             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_orthogonalizer'
     181             : 
     182             :       INTEGER                                            :: acol, arow, handle, i, iatom, j, k, N
     183             :       LOGICAL                                            :: found
     184             :       REAL(dp)                                           :: v, w
     185          98 :       REAL(dp), DIMENSION(:), POINTER                    :: evals
     186          98 :       REAL(dp), DIMENSION(:, :), POINTER                 :: A, block_N, block_N_inv, block_S
     187             :       TYPE(dbcsr_iterator_type)                          :: iter
     188          98 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     189             : 
     190          98 :       CALL timeset(routineN, handle)
     191             : 
     192          98 :       CALL get_qs_env(qs_env, matrix_s=matrix_s)
     193             : 
     194          98 :       CALL dbcsr_create(pao%matrix_N, template=matrix_s(1)%matrix, name="PAO matrix_N")
     195          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N)
     196             : 
     197          98 :       CALL dbcsr_create(pao%matrix_N_inv, template=matrix_s(1)%matrix, name="PAO matrix_N_inv")
     198          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N_inv)
     199             : 
     200             : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,matrix_s) &
     201          98 : !$OMP PRIVATE(iter,arow,acol,iatom,block_N,block_N_inv,block_S,found,N,A,evals,k,i,j,w,v)
     202             :       CALL dbcsr_iterator_start(iter, pao%matrix_N)
     203             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     204             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_N)
     205             :          iatom = arow; CPASSERT(arow == acol)
     206             : 
     207             :          CALL dbcsr_get_block_p(matrix=pao%matrix_N_inv, row=iatom, col=iatom, block=block_N_inv, found=found)
     208             :          CPASSERT(ASSOCIATED(block_N_inv))
     209             : 
     210             :          CALL dbcsr_get_block_p(matrix=matrix_s(1)%matrix, row=iatom, col=iatom, block=block_S, found=found)
     211             :          CPASSERT(ASSOCIATED(block_S))
     212             : 
     213             :          N = SIZE(block_S, 1); CPASSERT(SIZE(block_S, 1) == SIZE(block_S, 2)) ! primary basis size
     214             :          ALLOCATE (A(N, N), evals(N))
     215             : 
     216             :          ! take square root of atomic overlap matrix
     217             :          A = block_S
     218             :          CALL diamat_all(A, evals) !afterwards A contains the eigenvectors
     219             :          block_N = 0.0_dp
     220             :          block_N_inv = 0.0_dp
     221             :          DO k = 1, N
     222             :             ! NOTE: To maintain a consistent notation with the Berghold paper,
     223             :             ! the "_inv" is swapped: N^{-1}=sqrt(S); N=sqrt(S)^{-1}
     224             :             w = 1.0_dp/SQRT(evals(k))
     225             :             v = SQRT(evals(k))
     226             :             DO i = 1, N
     227             :                DO j = 1, N
     228             :                   block_N(i, j) = block_N(i, j) + w*A(i, k)*A(j, k)
     229             :                   block_N_inv(i, j) = block_N_inv(i, j) + v*A(i, k)*A(j, k)
     230             :                END DO
     231             :             END DO
     232             :          END DO
     233             :          DEALLOCATE (A, evals)
     234             :       END DO
     235             :       CALL dbcsr_iterator_stop(iter)
     236             : !$OMP END PARALLEL
     237             : 
     238             :       ! store a copies of N and N_inv that are distributed according to pao%diag_distribution
     239             :       CALL dbcsr_create(pao%matrix_N_diag, &
     240             :                         name="PAO matrix_N_diag", &
     241             :                         dist=pao%diag_distribution, &
     242          98 :                         template=matrix_s(1)%matrix)
     243          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N_diag)
     244          98 :       CALL dbcsr_complete_redistribute(pao%matrix_N, pao%matrix_N_diag)
     245             :       CALL dbcsr_create(pao%matrix_N_inv_diag, &
     246             :                         name="PAO matrix_N_inv_diag", &
     247             :                         dist=pao%diag_distribution, &
     248          98 :                         template=matrix_s(1)%matrix)
     249          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_N_inv_diag)
     250          98 :       CALL dbcsr_complete_redistribute(pao%matrix_N_inv, pao%matrix_N_inv_diag)
     251             : 
     252          98 :       CALL timestop(handle)
     253          98 :    END SUBROUTINE pao_build_orthogonalizer
     254             : 
     255             : ! **************************************************************************************************
      256             : !> \brief Build rectangular matrix to convert between primary and PAO basis.
     257             : !> \param pao ...
     258             : !> \param qs_env ...
     259             : ! **************************************************************************************************
     260          98 :    SUBROUTINE pao_build_selector(pao, qs_env)
     261             :       TYPE(pao_env_type), POINTER                        :: pao
     262             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     263             : 
     264             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_selector'
     265             : 
     266             :       INTEGER                                            :: acol, arow, handle, i, iatom, ikind, M, &
     267             :                                                             natoms
     268          98 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_aux, blk_sizes_pri
     269          98 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_Y
     270             :       TYPE(dbcsr_iterator_type)                          :: iter
     271          98 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     272          98 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     273          98 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     274             : 
     275          98 :       CALL timeset(routineN, handle)
     276             : 
     277             :       CALL get_qs_env(qs_env, &
     278             :                       natom=natoms, &
     279             :                       matrix_s=matrix_s, &
     280             :                       qs_kind_set=qs_kind_set, &
     281          98 :                       particle_set=particle_set)
     282             : 
     283          98 :       CALL dbcsr_get_info(matrix_s(1)%matrix, col_blk_size=blk_sizes_pri)
     284             : 
     285         294 :       ALLOCATE (blk_sizes_aux(natoms))
     286         336 :       DO iatom = 1, natoms
     287         238 :          CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     288         238 :          CALL get_qs_kind(qs_kind_set(ikind), pao_basis_size=M)
     289         238 :          CPASSERT(M > 0)
     290         238 :          IF (blk_sizes_pri(iatom) < M) &
     291           0 :             CPABORT("PAO basis size exceeds primary basis size.")
     292         574 :          blk_sizes_aux(iatom) = M
     293             :       END DO
     294             : 
     295             :       CALL dbcsr_create(pao%matrix_Y, &
     296             :                         template=matrix_s(1)%matrix, &
     297             :                         matrix_type="N", &
     298             :                         row_blk_size=blk_sizes_pri, &
     299             :                         col_blk_size=blk_sizes_aux, &
     300          98 :                         name="PAO matrix_Y")
     301          98 :       DEALLOCATE (blk_sizes_aux)
     302             : 
     303          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_Y)
     304             : 
     305             : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao) &
     306          98 : !$OMP PRIVATE(iter,arow,acol,block_Y,i,M)
     307             :       CALL dbcsr_iterator_start(iter, pao%matrix_Y)
     308             :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     309             :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_Y)
     310             :          M = SIZE(block_Y, 2) ! size of pao basis
     311             :          block_Y = 0.0_dp
     312             :          DO i = 1, M
     313             :             block_Y(i, i) = 1.0_dp
     314             :          END DO
     315             :       END DO
     316             :       CALL dbcsr_iterator_stop(iter)
     317             : !$OMP END PARALLEL
     318             : 
     319          98 :       CALL timestop(handle)
     320          98 :    END SUBROUTINE pao_build_selector
     321             : 
     322             : ! **************************************************************************************************
     323             : !> \brief Creates new DBCSR distribution which spreads diagonal blocks evenly across ranks
     324             : !> \param pao ...
     325             : !> \param qs_env ...
     326             : ! **************************************************************************************************
     327          98 :    SUBROUTINE pao_build_diag_distribution(pao, qs_env)
     328             :       TYPE(pao_env_type), POINTER                        :: pao
     329             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     330             : 
     331             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_diag_distribution'
     332             : 
     333             :       INTEGER                                            :: handle, iatom, natoms, pgrid_cols, &
     334             :                                                             pgrid_rows
     335          98 :       INTEGER, DIMENSION(:), POINTER                     :: diag_col_dist, diag_row_dist
     336             :       TYPE(dbcsr_distribution_type)                      :: main_dist
     337          98 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     338             : 
     339          98 :       CALL timeset(routineN, handle)
     340             : 
     341          98 :       CALL get_qs_env(qs_env, natom=natoms, matrix_s=matrix_s)
     342             : 
     343             :       ! get processor grid from matrix_s
     344          98 :       CALL dbcsr_get_info(matrix=matrix_s(1)%matrix, distribution=main_dist)
     345          98 :       CALL dbcsr_distribution_get(main_dist, nprows=pgrid_rows, npcols=pgrid_cols)
     346             : 
     347             :       ! create new mapping of matrix-grid to processor-grid
     348         490 :       ALLOCATE (diag_row_dist(natoms), diag_col_dist(natoms))
     349         336 :       DO iatom = 1, natoms
     350         238 :          diag_row_dist(iatom) = MOD(iatom - 1, pgrid_rows)
     351         336 :          diag_col_dist(iatom) = MOD((iatom - 1)/pgrid_rows, pgrid_cols)
     352             :       END DO
     353             : 
     354             :       ! instanciate distribution object
     355             :       CALL dbcsr_distribution_new(pao%diag_distribution, template=main_dist, &
     356          98 :                                   row_dist=diag_row_dist, col_dist=diag_col_dist)
     357             : 
     358          98 :       DEALLOCATE (diag_row_dist, diag_col_dist)
     359             : 
     360          98 :       CALL timestop(handle)
     361         196 :    END SUBROUTINE pao_build_diag_distribution
     362             : 
     363             : ! **************************************************************************************************
     364             : !> \brief Creates the matrix_X
     365             : !> \param pao ...
     366             : !> \param qs_env ...
     367             : ! **************************************************************************************************
     368          98 :    SUBROUTINE pao_build_matrix_X(pao, qs_env)
     369             :       TYPE(pao_env_type), POINTER                        :: pao
     370             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     371             : 
     372             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_build_matrix_X'
     373             : 
     374             :       INTEGER                                            :: handle, iatom, ikind, natoms
     375          98 :       INTEGER, DIMENSION(:), POINTER                     :: col_blk_size, row_blk_size
     376          98 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     377             : 
     378          98 :       CALL timeset(routineN, handle)
     379             : 
     380             :       CALL get_qs_env(qs_env, &
     381             :                       natom=natoms, &
     382          98 :                       particle_set=particle_set)
     383             : 
     384             :       ! determine block-sizes of matrix_X
     385         490 :       ALLOCATE (row_blk_size(natoms), col_blk_size(natoms))
     386         336 :       col_blk_size = 1
     387         336 :       DO iatom = 1, natoms
     388         238 :          CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     389         336 :          CALL pao_param_count(pao, qs_env, ikind, nparams=row_blk_size(iatom))
     390             :       END DO
     391             : 
     392             :       ! build actual matrix_X
     393             :       CALL dbcsr_create(pao%matrix_X, &
     394             :                         name="PAO matrix_X", &
     395             :                         dist=pao%diag_distribution, &
     396             :                         matrix_type="N", &
     397             :                         row_blk_size=row_blk_size, &
     398          98 :                         col_blk_size=col_blk_size)
     399          98 :       DEALLOCATE (row_blk_size, col_blk_size)
     400             : 
     401          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_X)
     402          98 :       CALL dbcsr_set(pao%matrix_X, 0.0_dp)
     403             : 
     404          98 :       CALL timestop(handle)
     405          98 :    END SUBROUTINE pao_build_matrix_X
     406             : 
     407             : ! **************************************************************************************************
     408             : !> \brief Creates the matrix_H0 which contains the core hamiltonian
     409             : !> \param pao ...
     410             : !> \param qs_env ...
     411             : ! **************************************************************************************************
     412          98 :    SUBROUTINE pao_build_core_hamiltonian(pao, qs_env)
     413             :       TYPE(pao_env_type), POINTER                        :: pao
     414             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     415             : 
     416          98 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     417          98 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     418          98 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     419             : 
     420             :       CALL get_qs_env(qs_env, &
     421             :                       matrix_s=matrix_s, &
     422             :                       atomic_kind_set=atomic_kind_set, &
     423          98 :                       qs_kind_set=qs_kind_set)
     424             : 
     425             :       ! allocate matrix_H0
     426             :       CALL dbcsr_create(pao%matrix_H0, &
     427             :                         name="PAO matrix_H0", &
     428             :                         dist=pao%diag_distribution, &
     429          98 :                         template=matrix_s(1)%matrix)
     430          98 :       CALL dbcsr_reserve_diag_blocks(pao%matrix_H0)
     431             : 
     432             :       ! calculate initial atomic fock matrix H0
     433             :       ! Can't use matrix_ks from ls_scf_qs_atomic_guess(), because it's not rotationally invariant.
     434             :       ! getting H0 directly from the atomic code
     435             :       CALL calculate_atomic_fock_matrix(pao%matrix_H0, &
     436             :                                         atomic_kind_set, &
     437             :                                         qs_kind_set, &
     438          98 :                                         ounit=pao%iw)
     439             : 
     440          98 :    END SUBROUTINE pao_build_core_hamiltonian
     441             : 
     442             : ! **************************************************************************************************
     443             : !> \brief Test whether the PAO optimization has reached convergence
     444             : !> \param pao ...
     445             : !> \param ls_scf_env ...
     446             : !> \param new_energy ...
     447             : !> \param is_converged ...
     448             : ! **************************************************************************************************
     449        2616 :    SUBROUTINE pao_test_convergence(pao, ls_scf_env, new_energy, is_converged)
     450             :       TYPE(pao_env_type), POINTER                        :: pao
     451             :       TYPE(ls_scf_env_type)                              :: ls_scf_env
     452             :       REAL(KIND=dp), INTENT(IN)                          :: new_energy
     453             :       LOGICAL, INTENT(OUT)                               :: is_converged
     454             : 
     455             :       REAL(KIND=dp)                                      :: energy_diff, loop_eps, now, time_diff
     456             : 
     457             :       ! calculate progress
     458        2616 :       energy_diff = new_energy - pao%energy_prev
     459        2616 :       pao%energy_prev = new_energy
     460        2616 :       now = m_walltime()
     461        2616 :       time_diff = now - pao%step_start_time
     462        2616 :       pao%step_start_time = now
     463             : 
     464             :       ! convergence criterion
     465        2616 :       loop_eps = pao%norm_G/ls_scf_env%nelectron_total
     466        2616 :       is_converged = loop_eps < pao%eps_pao
     467             : 
     468        2616 :       IF (pao%istep > 1) THEN
     469        2540 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| energy improvement:", energy_diff
     470             :          ! CPWARN_IF(energy_diff>0.0_dp, "PAO| energy increased")
     471             : 
     472             :          ! print one-liner
     473        2540 :          IF (pao%iw > 0) WRITE (pao%iw, '(A,I6,11X,F20.9,1X,E10.3,1X,E10.3,1X,F9.3)') &
     474        1270 :             " PAO| step ", &
     475        1270 :             pao%istep, &
     476        1270 :             new_energy, &
     477        1270 :             loop_eps, &
     478        1270 :             pao%linesearch%step_size, & !prev step, which let to the current energy
     479        2540 :             time_diff
     480             :       END IF
     481        2616 :    END SUBROUTINE pao_test_convergence
     482             : 
     483             : ! **************************************************************************************************
     484             : !> \brief Calculate the pao energy
     485             : !> \param pao ...
     486             : !> \param qs_env ...
     487             : !> \param ls_scf_env ...
     488             : !> \param energy ...
     489             : ! **************************************************************************************************
     490       11806 :    SUBROUTINE pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     491             :       TYPE(pao_env_type), POINTER                        :: pao
     492             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     493             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     494             :       REAL(KIND=dp), INTENT(OUT)                         :: energy
     495             : 
     496             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_calc_energy'
     497             : 
     498             :       INTEGER                                            :: handle, ispin
     499             :       REAL(KIND=dp)                                      :: penalty, trace_PH
     500             : 
     501       11806 :       CALL timeset(routineN, handle)
     502             : 
     503             :       ! calculate matrix U, which determines the pao basis
     504       11806 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE., penalty=penalty)
     505             : 
     506             :       ! calculat S, S_inv, S_sqrt, and S_sqrt_inv in the new pao basis
     507       11806 :       CALL pao_rebuild_S(qs_env, ls_scf_env)
     508             : 
     509             :       ! calculate the density matrix P in the pao basis
     510       11806 :       CALL pao_dm_trs4(qs_env, ls_scf_env)
     511             : 
     512             :       ! calculate the energy from the trace(PH) in the pao basis
     513       11806 :       energy = 0.0_dp
     514       23612 :       DO ispin = 1, ls_scf_env%nspins
     515       11806 :          CALL dbcsr_dot(ls_scf_env%matrix_p(ispin), ls_scf_env%matrix_ks(ispin), trace_PH)
     516       23612 :          energy = energy + trace_PH
     517             :       END DO
     518             : 
     519             :       ! add penalty term
     520       11806 :       energy = energy + penalty
     521             : 
     522       11806 :       IF (pao%iw > 0) THEN
     523        5903 :          WRITE (pao%iw, *) ""
     524        5903 :          WRITE (pao%iw, *) "PAO| energy:", energy, "penalty:", penalty
     525             :       END IF
     526       11806 :       CALL timestop(handle)
     527       11806 :    END SUBROUTINE pao_calc_energy
     528             : 
     529             : ! **************************************************************************************************
     530             : !> \brief Ensure that the number of electrons is correct.
     531             : !> \param ls_scf_env ...
     532             : ! **************************************************************************************************
     533       10326 :    SUBROUTINE pao_check_trace_PS(ls_scf_env)
     534             :       TYPE(ls_scf_env_type)                              :: ls_scf_env
     535             : 
     536             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_check_trace_PS'
     537             : 
     538             :       INTEGER                                            :: handle, ispin
     539             :       REAL(KIND=dp)                                      :: tmp, trace_PS
     540             :       TYPE(dbcsr_type)                                   :: matrix_S_desym
     541             : 
     542       10326 :       CALL timeset(routineN, handle)
     543       10326 :       CALL dbcsr_create(matrix_S_desym, template=ls_scf_env%matrix_s, matrix_type="N")
     544       10326 :       CALL dbcsr_desymmetrize(ls_scf_env%matrix_s, matrix_S_desym)
     545             : 
     546       10326 :       trace_PS = 0.0_dp
     547       20652 :       DO ispin = 1, ls_scf_env%nspins
     548       10326 :          CALL dbcsr_dot(ls_scf_env%matrix_p(ispin), matrix_S_desym, tmp)
     549       20652 :          trace_PS = trace_PS + tmp
     550             :       END DO
     551             : 
     552       10326 :       CALL dbcsr_release(matrix_S_desym)
     553             : 
     554       10326 :       IF (ABS(ls_scf_env%nelectron_total - trace_PS) > 0.5) &
     555           0 :          CPABORT("Number of electrons wrong. Trace(PS) ="//cp_to_string(trace_PS))
     556             : 
     557       10326 :       CALL timestop(handle)
     558       10326 :    END SUBROUTINE pao_check_trace_PS
     559             : 
     560             : ! **************************************************************************************************
     561             : !> \brief Read primary density matrix from file.
     562             : !> \param pao ...
     563             : !> \param qs_env ...
     564             : ! **************************************************************************************************
     565          56 :    SUBROUTINE pao_read_preopt_dm(pao, qs_env)
     566             :       TYPE(pao_env_type), POINTER                        :: pao
     567             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     568             : 
     569             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_read_preopt_dm'
     570             : 
     571             :       INTEGER                                            :: handle, ispin
     572             :       REAL(KIND=dp)                                      :: cs_pos
     573             :       TYPE(dbcsr_distribution_type)                      :: dist
     574          28 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s, rho_ao
     575             :       TYPE(dbcsr_type)                                   :: matrix_tmp
     576             :       TYPE(dft_control_type), POINTER                    :: dft_control
     577             :       TYPE(qs_energy_type), POINTER                      :: energy
     578             :       TYPE(qs_rho_type), POINTER                         :: rho
     579             : 
     580          28 :       CALL timeset(routineN, handle)
     581             : 
     582             :       CALL get_qs_env(qs_env, &
     583             :                       dft_control=dft_control, &
     584             :                       matrix_s=matrix_s, &
     585             :                       rho=rho, &
     586          28 :                       energy=energy)
     587             : 
     588          28 :       CALL qs_rho_get(rho, rho_ao=rho_ao)
     589             : 
     590          28 :       IF (dft_control%nspins /= 1) CPABORT("open shell not yet implemented")
     591             : 
     592          28 :       CALL dbcsr_get_info(matrix_s(1)%matrix, distribution=dist)
     593             : 
     594          56 :       DO ispin = 1, dft_control%nspins
     595          28 :          CALL dbcsr_binary_read(pao%preopt_dm_file, matrix_new=matrix_tmp, distribution=dist)
     596          28 :          cs_pos = dbcsr_checksum(matrix_tmp, pos=.TRUE.)
     597          28 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Read restart DM "// &
     598          14 :             TRIM(pao%preopt_dm_file)//" with checksum: ", cs_pos
     599          28 :          CALL dbcsr_copy(rho_ao(ispin)%matrix, matrix_tmp, keep_sparsity=.TRUE.)
     600          56 :          CALL dbcsr_release(matrix_tmp)
     601             :       END DO
     602             : 
     603             :       ! calculate corresponding ks matrix
     604          28 :       CALL qs_rho_update_rho(rho, qs_env=qs_env)
     605          28 :       CALL qs_ks_did_change(qs_env%ks_env, rho_changed=.TRUE.)
     606             :       CALL qs_ks_update_qs_env(qs_env, calculate_forces=.FALSE., &
     607          28 :                                just_energy=.FALSE., print_active=.TRUE.)
     608          28 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Quickstep energy from restart density:", energy%total
     609             : 
     610          28 :       CALL timestop(handle)
     611             : 
     612          28 :    END SUBROUTINE pao_read_preopt_dm
     613             : 
     614             : ! **************************************************************************************************
     615             : !> \brief Rebuilds S, S_inv, S_sqrt, and S_sqrt_inv in the pao basis
     616             : !> \param qs_env ...
     617             : !> \param ls_scf_env ...
     618             : ! **************************************************************************************************
     619       11806 :    SUBROUTINE pao_rebuild_S(qs_env, ls_scf_env)
     620             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     621             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     622             : 
     623             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_rebuild_S'
     624             : 
     625             :       INTEGER                                            :: handle
     626       11806 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     627             : 
     628       11806 :       CALL timeset(routineN, handle)
     629             : 
     630       11806 :       CALL dbcsr_release(ls_scf_env%matrix_s_inv)
     631       11806 :       CALL dbcsr_release(ls_scf_env%matrix_s_sqrt)
     632       11806 :       CALL dbcsr_release(ls_scf_env%matrix_s_sqrt_inv)
     633             : 
     634       11806 :       CALL get_qs_env(qs_env, matrix_s=matrix_s)
     635       11806 :       CALL ls_scf_init_matrix_s(matrix_s(1)%matrix, ls_scf_env)
     636             : 
     637       11806 :       CALL timestop(handle)
     638       11806 :    END SUBROUTINE pao_rebuild_S
     639             : 
     640             : ! **************************************************************************************************
     641             : !> \brief Calculate density matrix using TRS4 purification
     642             : !> \param qs_env ...
     643             : !> \param ls_scf_env ...
     644             : ! **************************************************************************************************
     645       11806 :    SUBROUTINE pao_dm_trs4(qs_env, ls_scf_env)
     646             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     647             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     648             : 
     649             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_dm_trs4'
     650             : 
     651             :       CHARACTER(LEN=default_path_length)                 :: project_name
     652             :       INTEGER                                            :: handle, ispin, nelectron_spin_real, nspin
     653             :       LOGICAL                                            :: converged
     654             :       REAL(KIND=dp)                                      :: homo_spin, lumo_spin, mu_spin
     655             :       TYPE(cp_logger_type), POINTER                      :: logger
     656       11806 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks
     657             : 
     658       11806 :       CALL timeset(routineN, handle)
     659       11806 :       logger => cp_get_default_logger()
     660       11806 :       project_name = logger%iter_info%project_name
     661       11806 :       nspin = ls_scf_env%nspins
     662             : 
     663       11806 :       CALL get_qs_env(qs_env, matrix_ks=matrix_ks)
     664       23612 :       DO ispin = 1, nspin
     665             :          CALL matrix_qs_to_ls(ls_scf_env%matrix_ks(ispin), matrix_ks(ispin)%matrix, &
     666       11806 :                               ls_scf_env%ls_mstruct, covariant=.TRUE.)
     667             : 
     668       11806 :          nelectron_spin_real = ls_scf_env%nelectron_spin(ispin)
     669       11806 :          IF (ls_scf_env%nspins == 1) nelectron_spin_real = nelectron_spin_real/2
     670             :          CALL density_matrix_trs4(ls_scf_env%matrix_p(ispin), ls_scf_env%matrix_ks(ispin), &
     671             :                                   ls_scf_env%matrix_s_sqrt_inv, &
     672             :                                   nelectron_spin_real, ls_scf_env%eps_filter, homo_spin, lumo_spin, mu_spin, &
     673             :                                   dynamic_threshold=.FALSE., converged=converged, &
     674             :                                   max_iter_lanczos=ls_scf_env%max_iter_lanczos, &
     675       11806 :                                   eps_lanczos=ls_scf_env%eps_lanczos)
     676       23612 :          IF (.NOT. converged) CPABORT("TRS4 did not converge")
     677             :       END DO
     678             : 
     679       11806 :       IF (nspin == 1) CALL dbcsr_scale(ls_scf_env%matrix_p(1), 2.0_dp)
     680             : 
     681       11806 :       CALL timestop(handle)
     682       11806 :    END SUBROUTINE pao_dm_trs4
     683             : 
     684             : ! **************************************************************************************************
     685             : !> \brief Debugging routine for checking the analytic gradient.
     686             : !> \param pao ...
     687             : !> \param qs_env ...
     688             : !> \param ls_scf_env ...
     689             : ! **************************************************************************************************
     690        2628 :    SUBROUTINE pao_check_grad(pao, qs_env, ls_scf_env)
     691             :       TYPE(pao_env_type), POINTER                        :: pao
     692             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     693             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     694             : 
     695             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_check_grad'
     696             : 
     697             :       INTEGER                                            :: handle, i, iatom, j, natoms
     698        2616 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_col, blk_sizes_row
     699             :       LOGICAL                                            :: found
     700             :       REAL(dp)                                           :: delta, delta_max, eps, Gij_num
     701        2616 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_G, block_X
     702             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     703             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     704             : 
     705        2604 :       IF (pao%check_grad_tol < 0.0_dp) RETURN ! no checking
     706             : 
     707          12 :       CALL timeset(routineN, handle)
     708             : 
     709          12 :       ls_mstruct => ls_scf_env%ls_mstruct
     710             : 
     711          12 :       CALL get_qs_env(qs_env, para_env=para_env, natom=natoms)
     712             : 
     713          12 :       eps = pao%num_grad_eps
     714          12 :       delta_max = 0.0_dp
     715             : 
     716          12 :       CALL dbcsr_get_info(pao%matrix_X, col_blk_size=blk_sizes_col, row_blk_size=blk_sizes_row)
     717             : 
     718             :       ! can not use an iterator here, because other DBCSR routines are called within loop.
     719          38 :       DO iatom = 1, natoms
     720          26 :          IF (pao%iw > 0) WRITE (pao%iw, *) 'PAO| checking gradient of atom ', iatom
     721          26 :          CALL dbcsr_get_block_p(matrix=pao%matrix_X, row=iatom, col=iatom, block=block_X, found=found)
     722             : 
     723          26 :          IF (ASSOCIATED(block_X)) THEN !only one node actually has the block
     724          13 :             CALL dbcsr_get_block_p(matrix=pao%matrix_G, row=iatom, col=iatom, block=block_G, found=found)
     725          13 :             CPASSERT(ASSOCIATED(block_G))
     726             :          END IF
     727             : 
     728         586 :          DO i = 1, blk_sizes_row(iatom)
     729        1070 :             DO j = 1, blk_sizes_col(iatom)
     730         828 :                SELECT CASE (pao%num_grad_order)
     731             :                CASE (2) ! calculate derivative to 2th order
     732         306 :                   Gij_num = -eval_point(block_X, i, j, -eps, pao, ls_scf_env, qs_env)
     733         306 :                   Gij_num = Gij_num + eval_point(block_X, i, j, +eps, pao, ls_scf_env, qs_env)
     734         306 :                   Gij_num = Gij_num/(2.0_dp*eps)
     735             : 
     736             :                CASE (4) ! calculate derivative to 4th order
     737         180 :                   Gij_num = eval_point(block_X, i, j, -2_dp*eps, pao, ls_scf_env, qs_env)
     738         180 :                   Gij_num = Gij_num - 8_dp*eval_point(block_X, i, j, -1_dp*eps, pao, ls_scf_env, qs_env)
     739         180 :                   Gij_num = Gij_num + 8_dp*eval_point(block_X, i, j, +1_dp*eps, pao, ls_scf_env, qs_env)
     740         180 :                   Gij_num = Gij_num - eval_point(block_X, i, j, +2_dp*eps, pao, ls_scf_env, qs_env)
     741         180 :                   Gij_num = Gij_num/(12.0_dp*eps)
     742             : 
     743             :                CASE (6) ! calculate derivative to 6th order
     744          36 :                   Gij_num = -1_dp*eval_point(block_X, i, j, -3_dp*eps, pao, ls_scf_env, qs_env)
     745          36 :                   Gij_num = Gij_num + 9_dp*eval_point(block_X, i, j, -2_dp*eps, pao, ls_scf_env, qs_env)
     746          36 :                   Gij_num = Gij_num - 45_dp*eval_point(block_X, i, j, -1_dp*eps, pao, ls_scf_env, qs_env)
     747          36 :                   Gij_num = Gij_num + 45_dp*eval_point(block_X, i, j, +1_dp*eps, pao, ls_scf_env, qs_env)
     748          36 :                   Gij_num = Gij_num - 9_dp*eval_point(block_X, i, j, +2_dp*eps, pao, ls_scf_env, qs_env)
     749          36 :                   Gij_num = Gij_num + 1_dp*eval_point(block_X, i, j, +3_dp*eps, pao, ls_scf_env, qs_env)
     750          36 :                   Gij_num = Gij_num/(60.0_dp*eps)
     751             : 
     752             :                CASE DEFAULT
     753         522 :                   CPABORT("Unsupported numerical derivative order: "//cp_to_string(pao%num_grad_order))
     754             :                END SELECT
     755             : 
     756        1044 :                IF (ASSOCIATED(block_X)) THEN
     757         261 :                   delta = ABS(Gij_num - block_G(i, j))
     758         261 :                   delta_max = MAX(delta_max, delta)
     759             :                   !WRITE (*,*) "gradient check", iatom, i, j, Gij_num, block_G(i,j), delta
     760             :                END IF
     761             :             END DO
     762             :          END DO
     763             :       END DO
     764             : 
     765          12 :       CALL para_env%max(delta_max)
     766          12 :       IF (pao%iw > 0) WRITE (pao%iw, *) 'PAO| checked gradient, max delta:', delta_max
     767          12 :       IF (delta_max > pao%check_grad_tol) CALL cp_abort(__LOCATION__, &
     768           0 :                                                         "Analytic and numeric gradients differ too much:"//cp_to_string(delta_max))
     769             : 
     770          12 :       CALL timestop(handle)
     771        2616 :    END SUBROUTINE pao_check_grad
     772             : 
     773             : ! **************************************************************************************************
     774             : !> \brief Helper routine for pao_check_grad()
     775             : !> \param block_X ...
     776             : !> \param i ...
     777             : !> \param j ...
     778             : !> \param eps ...
     779             : !> \param pao ...
     780             : !> \param ls_scf_env ...
     781             : !> \param qs_env ...
     782             : !> \return ...
     783             : ! **************************************************************************************************
     784        3096 :    FUNCTION eval_point(block_X, i, j, eps, pao, ls_scf_env, qs_env) RESULT(energy)
     785             :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
     786             :       INTEGER, INTENT(IN)                                :: i, j
     787             :       REAL(dp), INTENT(IN)                               :: eps
     788             :       TYPE(pao_env_type), POINTER                        :: pao
     789             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     790             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     791             :       REAL(dp)                                           :: energy
     792             : 
     793             :       REAL(dp)                                           :: old_Xij
     794             : 
     795        1548 :       IF (ASSOCIATED(block_X)) THEN
     796         774 :          old_Xij = block_X(i, j) ! backup old block_X
     797         774 :          block_X(i, j) = block_X(i, j) + eps ! add perturbation
     798             :       END IF
     799             : 
     800             :       ! calculate energy
     801        1548 :       CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     802             : 
     803             :       ! restore old block_X
     804        1548 :       IF (ASSOCIATED(block_X)) THEN
     805         774 :          block_X(i, j) = old_Xij
     806             :       END IF
     807             : 
     808        1548 :    END FUNCTION eval_point
     809             : 
     810             : ! **************************************************************************************************
     811             : !> \brief Stores density matrix as initial guess for next SCF optimization.
     812             : !> \param qs_env ...
     813             : !> \param ls_scf_env ...
     814             : ! **************************************************************************************************
     815         588 :    SUBROUTINE pao_store_P(qs_env, ls_scf_env)
     816             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     817             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     818             : 
     819             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_store_P'
     820             : 
     821             :       INTEGER                                            :: handle, ispin, istore
     822         294 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     823             :       TYPE(dft_control_type), POINTER                    :: dft_control
     824             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     825             :       TYPE(pao_env_type), POINTER                        :: pao
     826             : 
     827           0 :       IF (ls_scf_env%scf_history%nstore == 0) RETURN
     828         294 :       CALL timeset(routineN, handle)
     829         294 :       ls_mstruct => ls_scf_env%ls_mstruct
     830         294 :       pao => ls_scf_env%pao_env
     831         294 :       CALL get_qs_env(qs_env, dft_control=dft_control, matrix_s=matrix_s)
     832             : 
     833         294 :       ls_scf_env%scf_history%istore = ls_scf_env%scf_history%istore + 1
     834         294 :       istore = MOD(ls_scf_env%scf_history%istore - 1, ls_scf_env%scf_history%nstore) + 1
     835         294 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Storing density matrix for ASPC guess in slot:", istore
     836             : 
     837             :       ! initialize storage
     838         294 :       IF (ls_scf_env%scf_history%istore <= ls_scf_env%scf_history%nstore) THEN
     839         216 :          DO ispin = 1, dft_control%nspins
     840         216 :             CALL dbcsr_create(ls_scf_env%scf_history%matrix(ispin, istore), template=matrix_s(1)%matrix)
     841             :          END DO
     842             :       END IF
     843             : 
     844             :       ! We are storing the density matrix in the non-orthonormal primary basis.
     845             :       ! While the orthonormal basis would yield better extrapolations,
     846             :       ! we simply can not afford to calculat S_sqrt in the primary basis.
     847         588 :       DO ispin = 1, dft_control%nspins
     848             :          ! transform into primary basis
     849             :          CALL matrix_ls_to_qs(ls_scf_env%scf_history%matrix(ispin, istore), ls_scf_env%matrix_p(ispin), &
     850         588 :                               ls_scf_env%ls_mstruct, covariant=.FALSE., keep_sparsity=.FALSE.)
     851             :       END DO
     852             : 
     853         294 :       CALL timestop(handle)
     854         294 :    END SUBROUTINE pao_store_P
     855             : 
     856             : ! **************************************************************************************************
     857             : !> \brief Provide an initial guess for the density matrix
     858             : !> \param pao ...
     859             : !> \param qs_env ...
     860             : !> \param ls_scf_env ...
     861             : ! **************************************************************************************************
     862         294 :    SUBROUTINE pao_guess_initial_P(pao, qs_env, ls_scf_env)
     863             :       TYPE(pao_env_type), POINTER                        :: pao
     864             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     865             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     866             : 
     867             :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_guess_initial_P'
     868             : 
     869             :       INTEGER                                            :: handle
     870             : 
     871         294 :       CALL timeset(routineN, handle)
     872             : 
     873         294 :       IF (ls_scf_env%scf_history%istore > 0) THEN
     874         196 :          CALL pao_aspc_guess_P(pao, qs_env, ls_scf_env)
     875         196 :          pao%need_initial_scf = .TRUE.
     876             :       ELSE
     877          98 :          IF (LEN_TRIM(pao%preopt_dm_file) > 0) THEN
     878          28 :             CALL pao_read_preopt_dm(pao, qs_env)
     879          28 :             pao%need_initial_scf = .FALSE.
     880          28 :             pao%preopt_dm_file = "" ! load only for first MD step
     881             :          ELSE
     882          70 :             CALL ls_scf_qs_atomic_guess(qs_env, ls_scf_env, ls_scf_env%energy_init)
     883          70 :             IF (pao%iw > 0) WRITE (pao%iw, '(A,F20.9)') &
     884          35 :                " PAO| Energy from initial atomic guess:", ls_scf_env%energy_init
     885          70 :             pao%need_initial_scf = .TRUE.
     886             :          END IF
     887             :       END IF
     888             : 
     889         294 :       CALL timestop(handle)
     890             : 
     891         294 :    END SUBROUTINE pao_guess_initial_P
     892             : 
     893             : ! **************************************************************************************************
     894             : !> \brief Run the Always Stable Predictor-Corrector (ASPC) to guess an initial density matrix from the SCF history
     895             : !> \param pao PAO environment (only its output unit pao%iw is used here)
     896             : !> \param qs_env QS environment; provides dft_control and matrix_s and is passed to ls_scf_dm_to_ks
     897             : !> \param ls_scf_env LS-SCF environment; scf_history is read, matrix_p is overwritten, energy_init is computed
     898             : ! **************************************************************************************************
     899         196 :    SUBROUTINE pao_aspc_guess_P(pao, qs_env, ls_scf_env)
     900             :       TYPE(pao_env_type), POINTER                        :: pao
     901             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     902             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     903             : 
     904             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_aspc_guess_P'
     905             : 
     906             :       INTEGER                                            :: handle, iaspc, ispin, istore, naspc
     907             :       REAL(dp)                                           :: alpha
     908         196 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
     909             :       TYPE(dbcsr_type)                                   :: matrix_P
     910             :       TYPE(dft_control_type), POINTER                    :: dft_control
     911             :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     912             : 
     913         196 :       CALL timeset(routineN, handle)
     914         196 :       ls_mstruct => ls_scf_env%ls_mstruct ! NOTE(review): not referenced again below
     915         196 :       CPASSERT(ls_scf_env%scf_history%istore > 0)
     916         196 :       CALL cite_reference(Kolafa2004)
     917         196 :       CALL cite_reference(Kuhne2007)
     918         196 :       CALL get_qs_env(qs_env, dft_control=dft_control, matrix_s=matrix_s)
     919             : 
     920         196 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Calculating initial guess with ASPC"
     921             : 
     922         196 :       CALL dbcsr_create(matrix_P, template=matrix_s(1)%matrix) ! accumulator for the extrapolated P
     923             : 
     924         196 :       naspc = MIN(ls_scf_env%scf_history%istore, ls_scf_env%scf_history%nstore) ! number of history entries used
     925         392 :       DO ispin = 1, dft_control%nspins
     926             :          ! ASPC extrapolation: matrix_P = sum over iaspc of alpha * stored history matrix
     927         196 :          CALL dbcsr_set(matrix_P, 0.0_dp) ! reset accumulator for this spin
     928         416 :          DO iaspc = 1, naspc
     929             :             alpha = (-1.0_dp)**(iaspc + 1)*REAL(iaspc, KIND=dp)* &
     930         220 :                     binomial(2*naspc, naspc - iaspc)/binomial(2*naspc - 2, naspc - 1) ! ASPC coefficient
     931         220 :             istore = MOD(ls_scf_env%scf_history%istore - iaspc, ls_scf_env%scf_history%nstore) + 1 ! cyclic slot, presumably the iaspc-th most recent
     932         416 :             CALL dbcsr_add(matrix_P, ls_scf_env%scf_history%matrix(ispin, istore), 1.0_dp, alpha) ! matrix_P += alpha*P_hist
     933             :          END DO
     934             : 
     935             :          ! transform back from the primary basis into the PAO basis, writing ls_scf_env%matrix_p(ispin)
     936         392 :          CALL matrix_qs_to_ls(ls_scf_env%matrix_p(ispin), matrix_P, ls_scf_env%ls_mstruct, covariant=.FALSE.)
     937             :       END DO
     938             : 
     939         196 :       CALL dbcsr_release(matrix_P)
     940             : 
     941             :       ! A linear combination of P's is not idempotent; a few McWeeny steps are needed to restore idempotency
     942         392 :       DO ispin = 1, dft_control%nspins
     943         196 :          IF (dft_control%nspins == 1) CALL dbcsr_scale(ls_scf_env%matrix_p(ispin), 0.5_dp) ! single spin: halve before purification
     944             :          ! filter the guess a bit more (threshold eps_filter**(2/3)) so noisy blocks do not build up during MD (in particular with curvy)
     945         196 :          CALL dbcsr_filter(ls_scf_env%matrix_p(ispin), ls_scf_env%eps_filter**(2.0_dp/3.0_dp))
     946             :          ! we could go to the orthonormal basis, but it seems not worth the trouble
     947             :          ! TODO : the 10 McWeeny iterations are a conservative upper bound, figure out when it fails
     948         196 :          CALL purify_mcweeny(ls_scf_env%matrix_p(ispin:ispin), ls_scf_env%matrix_s, ls_scf_env%eps_filter, 10)
     949         392 :          IF (dft_control%nspins == 1) CALL dbcsr_scale(ls_scf_env%matrix_p(ispin), 2.0_dp) ! undo the halving above
     950             :       END DO
     951             : 
     952         196 :       CALL pao_check_trace_PS(ls_scf_env) ! sanity check of the purified guess
     953             : 
     954             :       ! compute the corresponding energy (ls_scf_env%energy_init) and KS matrix
     955         196 :       CALL ls_scf_dm_to_ks(qs_env, ls_scf_env, ls_scf_env%energy_init, iscf=0)
     956             : 
     957         196 :       CALL timestop(handle)
     958         196 :    END SUBROUTINE pao_aspc_guess_P
     959             : 
     960             : ! **************************************************************************************************
     961             : !> \brief Calculate the forces contributed by PAO and add them to the particle forces
     962             : !> \param qs_env QS environment; provides para_env, particle_set and natom
     963             : !> \param ls_scf_env LS-SCF environment holding the PAO environment pao_env
     964             : ! **************************************************************************************************
     965          44 :    SUBROUTINE pao_add_forces(qs_env, ls_scf_env)
     966             :       TYPE(qs_environment_type), POINTER                 :: qs_env
     967             :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     968             : 
     969             :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_add_forces'
     970             : 
     971             :       INTEGER                                            :: handle, iatom, natoms
     972          44 :       REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: forces
     973             :       TYPE(mp_para_env_type), POINTER                    :: para_env
     974             :       TYPE(pao_env_type), POINTER                        :: pao
     975          44 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     976             : 
     977          44 :       CALL timeset(routineN, handle)
     978          44 :       pao => ls_scf_env%pao_env
     979             : 
     980          44 :       IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Adding forces."
     981             : 
     982          44 :       IF (pao%max_pao /= 0) THEN ! with MAX_PAO /= 0 the penalty and regularization terms must be zero
     983          20 :          IF (pao%penalty_strength /= 0.0_dp) &
     984           0 :             CPABORT("PAO forces require PENALTY_STRENGTH or MAX_PAO set to zero")
     985          20 :          IF (pao%linpot_regu_strength /= 0.0_dp) &
     986           0 :             CPABORT("PAO forces require LINPOT_REGULARIZATION_STRENGTH or MAX_PAO set to zero")
     987          20 :          IF (pao%regularization /= 0.0_dp) &
     988           0 :             CPABORT("PAO forces require REGULARIZATION or MAX_PAO set to zero")
     989             :       END IF
     990             : 
     991             :       CALL get_qs_env(qs_env, &
     992             :                       para_env=para_env, &
     993             :                       particle_set=particle_set, &
     994          44 :                       natom=natoms)
     995             : 
     996         132 :       ALLOCATE (forces(natoms, 3)) ! per-atom force contributions (natoms x 3)
     997          44 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., forces=forces) ! without penalty terms
     998             : 
     999          44 :       IF (SIZE(pao%ml_training_set) > 0) &
    1000          18 :          CALL pao_ml_forces(pao, qs_env, pao%matrix_G, forces) ! ML contribution
    1001             : 
    1002          44 :       IF (ALLOCATED(pao%models)) &
    1003           2 :          CALL pao_model_forces(pao, qs_env, pao%matrix_G, forces) ! contribution of pao%models
    1004             : 
    1005          44 :       CALL para_env%sum(forces) ! sum the contributions over all ranks of para_env
    1006         150 :       DO iatom = 1, natoms
    1007         468 :          particle_set(iatom)%f = particle_set(iatom)%f + forces(iatom, :) ! accumulate onto existing forces
    1008             :       END DO
    1009             : 
    1010          44 :       DEALLOCATE (forces)
    1011             : 
    1012          44 :       CALL timestop(handle)
    1013             : 
    1014          44 :    END SUBROUTINE pao_add_forces
    1015             : 
    1016             : END MODULE pao_methods

Generated by: LCOV version 1.15