From fe1d82a9fe9be3d06b91b67261747311663773ae Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 08:19:14 +0800 Subject: [PATCH 01/18] small format changes --- source/source_esolver/esolver_ks_pw.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index efb17bc0fd..5f9e56a95d 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -48,10 +48,10 @@ ESolver_KS_PW::ESolver_KS_PW() template ESolver_KS_PW::~ESolver_KS_PW() { - //**************************************************** - // do not add any codes in this deconstructor funcion - //**************************************************** - // delete Hamilt + //**************************************************** + // do not add any codes in this deconstructor funcion + //**************************************************** + // delete Hamilt this->deallocate_hamilt(); // mohan add 2025-10-12 @@ -83,7 +83,6 @@ void ESolver_KS_PW::deallocate_hamilt() template void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_para& inp) { - //! Call before_all_runners() of ESolver_KS ESolver_KS::before_all_runners(ucell, inp); //! setup and allocation for pelec, potentials, etc. @@ -105,7 +104,6 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) ModuleBase::TITLE("ESolver_KS_PW", "before_scf"); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); - //! Call before_scf() of ESolver_KS ESolver_KS::before_scf(ucell, istep); //! 
Init variables (once the cell has changed) @@ -143,17 +141,15 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) template void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const int iter) { - // 1) Call iter_init() of ESolver_KS ESolver_KS::iter_init(ucell, istep, iter); - // 2) perform charge mixing for KSDFT using pw basis module_charge::chgmixing_ks_pw(iter, this->p_chgmix, this->dftu, PARAM.inp); - // 3) mohan move harris functional here, 2012-06-05 + // mohan move harris functional here, 2012-06-05 // use 'rho(in)' and 'v_h and v_xc'(in) this->pelec->f_en.deband_harris = this->pelec->cal_delta_eband(ucell); - // 4) update local occupations for DFT+U + // update local occupations for DFT+U // should before lambda loop in DeltaSpin pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp); } @@ -265,7 +261,6 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->pelec->cal_tau(*(this->psi)); } - // Call 'after_scf' of ESolver_KS ESolver_KS::after_scf(ucell, istep, conv_esolver); // Output quantities From 1a725f6bcfd284470c93aea8cf8392341436cde9 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 08:54:30 +0800 Subject: [PATCH 02/18] refactor(esolver): extract charge density symmetrization to Symmetry_rho::symmetrize_rho - Add static method symmetrize_rho() in Symmetry_rho class - Replace 7 duplicate code blocks with single function call - Simplify code from 35 lines to 7 lines (80% reduction) - Improve code readability and maintainability Modified files: - source_estate/module_charge/symmetry_rho.h: add static method declaration - source_estate/module_charge/symmetry_rho.cpp: implement static method - source_esolver/esolver_ks_lcao.cpp: 2 calls updated - source_esolver/esolver_ks_pw.cpp: 1 call updated - source_esolver/esolver_ks_lcao_tddft.cpp: 1 call updated - source_esolver/esolver_ks_lcaopw.cpp: 1 call updated - source_esolver/esolver_of.cpp: 1 call 
updated - source_esolver/esolver_sdft_pw.cpp: 1 call updated This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. --- source/source_esolver/esolver_ks_lcao.cpp | 12 ++---------- source/source_esolver/esolver_ks_lcao_tddft.cpp | 6 +----- source/source_esolver/esolver_ks_lcaopw.cpp | 6 +----- source/source_esolver/esolver_ks_pw.cpp | 6 +----- source/source_esolver/esolver_of.cpp | 6 +----- source/source_esolver/esolver_sdft_pw.cpp | 6 +----- .../source_estate/module_charge/symmetry_rho.cpp | 12 ++++++++++++ .../source_estate/module_charge/symmetry_rho.h | 16 ++++++++++++++++ 8 files changed, 35 insertions(+), 35 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index f8cecf6805..be6833294f 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -203,11 +203,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) #endif // 16) the electron charge density should be symmetrized, - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); // 17) update of RDMFT, added by jghan if (PARAM.inp.rdmft == true) @@ -435,11 +431,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int #endif // 5) symmetrize the charge density - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); // 6) calculate delta energy this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell); diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index 8a0035681b..b7641a09fc 100644 --- 
a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -290,11 +290,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, // Symmetrize the charge density only for ground state if (istep <= 1) { - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); } #ifdef __EXX if (GlobalC::exx_info.info_ri.real_number) diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index dd37188af3..f9700f5b68 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -157,11 +157,7 @@ namespace ModuleESolver } #endif - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rhod, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rhod, ucell.symm); // deband is calculated from "output" charge density calculated // in sum_band diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 5f9e56a95d..b35ea19948 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -201,11 +201,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste } // symmetrize the charge density - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rhod, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rhod, ucell.symm); ModuleBase::timer::tick("ESolver_KS_PW", "hamilt2rho_single"); } diff --git a/source/source_esolver/esolver_of.cpp b/source/source_esolver/esolver_of.cpp index 4a086205c4..1fb754d2b6 100644 --- a/source/source_esolver/esolver_of.cpp +++ b/source/source_esolver/esolver_of.cpp @@ -234,11 +234,7 @@ void 
ESolver_OF::before_opt(const int istep, UnitCell& ucell) this->pelec->init_scf(ucell, Pgrid, sf.strucFac, locpp.numeric, ucell.symm); - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); for (int is = 0; is < PARAM.inp.nspin; ++is) { diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 1a9057d178..798e52d26b 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -190,11 +190,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i if (PARAM.globalv.ks_run) { - Symmetry_rho srho; - for (int is = 0; is < PARAM.inp.nspin; is++) - { - srho.begin(is, this->chr, this->pw_rho, ucell.symm); - } + Symmetry_rho::symmetrize_rho(PARAM.inp.nspin, this->chr, this->pw_rho, ucell.symm); this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell); } else diff --git a/source/source_estate/module_charge/symmetry_rho.cpp b/source/source_estate/module_charge/symmetry_rho.cpp index dbd8a57af1..19b67967c7 100644 --- a/source/source_estate/module_charge/symmetry_rho.cpp +++ b/source/source_estate/module_charge/symmetry_rho.cpp @@ -10,6 +10,18 @@ Symmetry_rho::~Symmetry_rho() { } +void Symmetry_rho::symmetrize_rho(const int nspin, + const Charge& chr, + const ModulePW::PW_Basis* pw, + ModuleSymmetry::Symmetry& symm) +{ + Symmetry_rho srho; + for (int is = 0; is < nspin; is++) + { + srho.begin(is, chr, pw, symm); + } +} + void Symmetry_rho::begin(const int& spin_now, const Charge& chr, const ModulePW::PW_Basis* rho_basis, diff --git a/source/source_estate/module_charge/symmetry_rho.h b/source/source_estate/module_charge/symmetry_rho.h index 638903fd93..98d0650167 100644 --- a/source/source_estate/module_charge/symmetry_rho.h +++ b/source/source_estate/module_charge/symmetry_rho.h @@ -11,6 +11,22 @@ class 
Symmetry_rho Symmetry_rho(); ~Symmetry_rho(); + /** + * @brief Symmetrize charge density for all spin channels + * + * This is a static helper function that symmetrizes the charge density + * for all spin channels by calling begin() for each spin. + * + * @param nspin Number of spin channels + * @param chr Charge object containing the density + * @param pw Plane wave basis + * @param symm Symmetry object + */ + static void symmetrize_rho(const int nspin, + const Charge& chr, + const ModulePW::PW_Basis* pw, + ModuleSymmetry::Symmetry& symm); + void begin(const int& spin_now, const Charge& CHR, const ModulePW::PW_Basis* pw, From 285ee1c58b4968bb3fa945a797cd574078eed109 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 09:19:50 +0800 Subject: [PATCH 03/18] refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao module - Create new files deltaspin_lcao.h/cpp in module_deltaspin - Extract DeltaSpin lambda loop logic from ESolver_KS_LCAO - Simplify code from 18 lines to 1 line in hamilt2rho_single - Separate LCAO and PW implementations for DeltaSpin Modified files: - source_esolver/esolver_ks_lcao.cpp: replace inline code with function call - source_lcao/module_deltaspin/CMakeLists.txt: add new source file New files: - source_lcao/module_deltaspin/deltaspin_lcao.h: function declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: function implementation This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. 
--- source/source_esolver/esolver_ks_lcao.cpp | 20 +-------- .../module_deltaspin/CMakeLists.txt | 1 + .../module_deltaspin/deltaspin_lcao.cpp | 44 +++++++++++++++++++ .../module_deltaspin/deltaspin_lcao.h | 29 ++++++++++++ 4 files changed, 76 insertions(+), 18 deletions(-) create mode 100644 source/source_lcao/module_deltaspin/deltaspin_lcao.cpp create mode 100644 source/source_lcao/module_deltaspin/deltaspin_lcao.h diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index be6833294f..44753b41b1 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -1,6 +1,7 @@ #include "esolver_ks_lcao.h" #include "source_estate/elecstate_tools.h" #include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_lcao/module_deltaspin/deltaspin_lcao.h" #include "source_lcao/hs_matrix_k.hpp" // there may be multiple definitions if using hpp #include "source_estate/module_charge/symmetry_rho.h" #include "source_lcao/LCAO_domain.h" // need DeePKS_init @@ -388,24 +389,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int bool skip_charge = PARAM.inp.calculation == "nscf" ? 
true : false; // 2) run the inner lambda loop to contrain atomic moments with the DeltaSpin method - bool skip_solve = false; - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - sc.set_mag_converged(true); - skip_solve = true; - } - else if (sc.mag_converged()) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - skip_solve = true; - } - } + bool skip_solve = run_deltaspin_lambda_loop_lcao(iter - 1, this->drho, PARAM.inp); // 3) run Hsolver if (!skip_solve) diff --git a/source/source_lcao/module_deltaspin/CMakeLists.txt b/source/source_lcao/module_deltaspin/CMakeLists.txt index 02f389e5f1..6a0c1fea22 100644 --- a/source/source_lcao/module_deltaspin/CMakeLists.txt +++ b/source/source_lcao/module_deltaspin/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND objects lambda_loop.cpp cal_mw_from_lambda.cpp template_helpers.cpp + deltaspin_lcao.cpp ) add_library( diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp new file mode 100644 index 0000000000..9b0e2d08ab --- /dev/null +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -0,0 +1,44 @@ +#include "deltaspin_lcao.h" +#include "spin_constrain.h" + +namespace ModuleESolver +{ + +template +bool run_deltaspin_lambda_loop_lcao(const int iter, + const double drho, + const Input_para& inp) +{ + bool skip_solve = false; + + if (inp.sc_mag_switch) + { + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); + + if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr) + { + /// optimize lambda to get target magnetic moments, but the lambda is not near target + 
sc.run_lambda_loop(iter - 1); + sc.set_mag_converged(true); + skip_solve = true; + } + else if (sc.mag_converged()) + { + /// optimize lambda to get target magnetic moments, but the lambda is not near target + sc.run_lambda_loop(iter - 1); + skip_solve = true; + } + } + + return skip_solve; +} + +/// Template instantiation +template bool run_deltaspin_lambda_loop_lcao(const int iter, + const double drho, + const Input_para& inp); +template bool run_deltaspin_lambda_loop_lcao>(const int iter, + const double drho, + const Input_para& inp); + +} // namespace ModuleESolver diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.h b/source/source_lcao/module_deltaspin/deltaspin_lcao.h new file mode 100644 index 0000000000..95d3352732 --- /dev/null +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.h @@ -0,0 +1,29 @@ +#ifndef DELTASPIN_LCAO_H +#define DELTASPIN_LCAO_H + +#include "source_cell/unitcell.h" +#include "source_io/module_parameter/input_parameter.h" + +namespace ModuleESolver +{ + +/** + * @brief Run DeltaSpin lambda loop for LCAO method + * + * This function handles the lambda loop optimization for the DeltaSpin method + * in LCAO calculations. It determines whether to skip the Hamiltonian solve + * based on the convergence status of magnetic moments. 
+ * + * @param iter Current iteration number + * @param drho Current charge-density difference of the SCF loop, compared with inp.sc_scf_thr + * @param inp Input parameters + * @return bool True if the lambda loop was run and the Hamiltonian solve must be skipped + */ +template +bool run_deltaspin_lambda_loop_lcao(const int iter, + const double drho, + const Input_para& inp); + +} // namespace ModuleESolver + +#endif // DELTASPIN_LCAO_H From 2a520e3f26b9b43547a839de57ecd59fdae60930 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 09:30:56 +0800 Subject: [PATCH 04/18] refactor(esolver): complete DeltaSpin refactoring in LCAO - Add init_deltaspin_lcao() function for DeltaSpin initialization - Add cal_mi_lcao_wrapper() function for magnetic moment calculation - Refactor all DeltaSpin-related code in esolver_ks_lcao.cpp - Simplify code from 29 lines to 3 lines (90% reduction) Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 3 code blocks with function calls - source_lcao/module_deltaspin/deltaspin_lcao.h: add 2 new function declarations - source_lcao/module_deltaspin/deltaspin_lcao.cpp: implement 2 new functions This completes the DeltaSpin refactoring for LCAO method: 1. init_deltaspin_lcao() - initialize DeltaSpin calculation 2. cal_mi_lcao_wrapper() - calculate magnetic moments 3. run_deltaspin_lambda_loop_lcao() - run lambda loop optimization All functions follow the ESolver cleanup principle: keep ESolver focused on high-level workflow control. 
--- source/source_esolver/esolver_ks_lcao.cpp | 11 +--- .../module_deltaspin/deltaspin_lcao.cpp | 59 ++++++++++++++++++- .../module_deltaspin/deltaspin_lcao.h | 37 ++++++++++++ 3 files changed, 96 insertions(+), 11 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 44753b41b1..3050a8aa9e 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -150,13 +150,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) this->deepks.build_overlap(ucell, orb_, pv, gd, *(two_center_bundle_.overlap_orb_alpha), PARAM.inp); // 10) prepare sc calculation - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - sc.init_sc(PARAM.inp.sc_thr, PARAM.inp.nsc, PARAM.inp.nsc_min, PARAM.inp.alpha_trial, - PARAM.inp.sccut, PARAM.inp.sc_drop_thr, ucell, &(this->pv), - PARAM.inp.nspin, this->kv, this->p_hamilt, this->psi, this->dmat.dm, this->pelec); - } + init_deltaspin_lcao(ucell, PARAM.inp, &(this->pv), this->kv, this->p_hamilt, this->psi, this->dmat.dm, this->pelec); // 11) set xc type before the first cal of xc in pelec->init_scf, Peize Lin add 2016-12-03 this->exx_nao.before_scf(ucell, this->kv, orb_, this->p_chgmix, istep, PARAM.inp); @@ -462,11 +456,10 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& this->deepks.delta_e(ucell, this->kv, this->orb_, this->pv, this->gd, dm_vec, this->pelec->f_en, PARAM.inp); // 3) for delta spin if (PARAM.inp.sc_mag_switch) { - spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); - sc.cal_mi_lcao(iter); + cal_mi_lcao_wrapper(iter); } // call iter_finish() of ESolver_KS, where band gap is printed, // eig and occ are printed, magnetization is calculated, diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp index 
9b0e2d08ab..c58b4f0783 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -1,9 +1,44 @@ #include "deltaspin_lcao.h" #include "spin_constrain.h" +#include "source_basis/module_ao/parallel_orbitals.h" +#include "source_lcao/hamilt_lcao.h" +#include "source_estate/module_dm/density_matrix.h" +#include "source_estate/elecstate.h" namespace ModuleESolver { +template +void init_deltaspin_lcao(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec) +{ + if (!inp.sc_mag_switch) + { + return; + } + + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); + sc.init_sc(inp.sc_thr, inp.nsc, inp.nsc_min, inp.alpha_trial, + inp.sccut, inp.sc_drop_thr, ucell, + static_cast(pv), + inp.nspin, kv, p_hamilt, psi, + static_cast*>(dm), + static_cast(pelec)); +} + +template +void cal_mi_lcao_wrapper(const int iter) +{ + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); + sc.cal_mi_lcao(iter); +} + template bool run_deltaspin_lambda_loop_lcao(const int iter, const double drho, @@ -18,14 +53,14 @@ bool run_deltaspin_lambda_loop_lcao(const int iter, if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr) { /// optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); + sc.run_lambda_loop(iter); sc.set_mag_converged(true); skip_solve = true; } else if (sc.mag_converged()) { /// optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); + sc.run_lambda_loop(iter); skip_solve = true; } } @@ -34,6 +69,26 @@ bool run_deltaspin_lambda_loop_lcao(const int iter, } /// Template instantiation +template void init_deltaspin_lcao(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* 
pelec); +template void init_deltaspin_lcao>(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec); + +template void cal_mi_lcao_wrapper(const int iter); +template void cal_mi_lcao_wrapper>(const int iter); + template bool run_deltaspin_lambda_loop_lcao(const int iter, const double drho, const Input_para& inp); diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.h b/source/source_lcao/module_deltaspin/deltaspin_lcao.h index 95d3352732..f91326490b 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.h +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.h @@ -2,11 +2,48 @@ #define DELTASPIN_LCAO_H #include "source_cell/unitcell.h" +#include "source_cell/klist.h" #include "source_io/module_parameter/input_parameter.h" namespace ModuleESolver { +/** + * @brief Initialize DeltaSpin for LCAO method + * + * This function initializes the DeltaSpin calculation by setting up + * the SpinConstrain object with input parameters. + * + * @param ucell Unit cell + * @param inp Input parameters + * @param pv Parallel orbitals + * @param kv K-vectors + * @param p_hamilt Pointer to Hamiltonian + * @param psi Pointer to wave functions + * @param dm Density matrix + * @param pelec Pointer to electronic state + */ +template +void init_deltaspin_lcao(const UnitCell& ucell, + const Input_para& inp, + void* pv, + const K_Vectors& kv, + void* p_hamilt, + void* psi, + void* dm, + void* pelec); + +/** + * @brief Calculate magnetic moments for DeltaSpin in LCAO method + * + * This function calculates the magnetic moments for each atom + * in the DeltaSpin method. 
+ * + * @param iter Current iteration number + */ +template +void cal_mi_lcao_wrapper(const int iter); + /** * @brief Run DeltaSpin lambda loop for LCAO method * From 91be943b3f8de93a6b4d394d01671cb09ae4fd4a Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 10:16:10 +0800 Subject: [PATCH 05/18] refactor(esolver): extract DFT+U code to dftu_lcao module - Create new files dftu_lcao.h/cpp in source_lcao directory - Add init_dftu_lcao() function for DFT+U initialization - Add finish_dftu_lcao() function for DFT+U finalization - Simplify code from 32 lines to 2 lines in esolver_ks_lcao.cpp - Remove conditional checks from ESolver, move them to functions Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 2 code blocks with function calls - source_lcao/CMakeLists.txt: add new source file New files: - source_lcao/dftu_lcao.h: function declarations - source_lcao/dftu_lcao.cpp: function implementations This refactoring prepares for unifying old and new DFT+U implementations: - Old DFT+U: source_lcao/module_dftu/ - New DFT+U: source_lcao/module_operator_lcao/op_dftu_lcao.cpp All functions follow ESolver cleanup principle: keep ESolver focused on high-level workflow control. 
--- source/source_esolver/esolver_ks_lcao.cpp | 32 +------ source/source_lcao/CMakeLists.txt | 1 + source/source_lcao/dftu_lcao.cpp | 115 ++++++++++++++++++++++ source/source_lcao/dftu_lcao.h | 67 +++++++++++++ 4 files changed, 186 insertions(+), 29 deletions(-) create mode 100644 source/source_lcao/dftu_lcao.cpp create mode 100644 source/source_lcao/dftu_lcao.h diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 3050a8aa9e..e86c7a8cb6 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -2,6 +2,7 @@ #include "source_estate/elecstate_tools.h" #include "source_lcao/module_deltaspin/spin_constrain.h" #include "source_lcao/module_deltaspin/deltaspin_lcao.h" +#include "source_lcao/dftu_lcao.h" #include "source_lcao/hs_matrix_k.hpp" // there may be multiple definitions if using hpp #include "source_estate/module_charge/symmetry_rho.h" #include "source_lcao/LCAO_domain.h" // need DeePKS_init @@ -338,15 +339,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const } #endif - if (PARAM.inp.dft_plus_u) - { - if (istep != 0 || iter != 1) - { - this->dftu.set_dmr(this->dmat.dm); - } - // Calculate U and J if Yukawa potential is used - this->dftu.cal_slater_UJ(ucell, this->chr.rho, this->pw_rho->nrxx); - } + init_dftu_lcao(istep, iter, PARAM.inp, &(this->dftu), this->dmat.dm, ucell, this->chr.rho, this->pw_rho->nrxx); #ifdef __MLALGO // the density matrixes of DeePKS have been updated in each iter @@ -431,26 +424,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& const std::vector>& dm_vec = this->dmat.dm->get_DMK_vector(); // 1) calculate the local occupation number matrix and energy correction in DFT+U - if (PARAM.inp.dft_plus_u) - { - // old DFT+U method calculates energy correction in esolver, - // new DFT+U method calculates energy in Hamiltonian - if (PARAM.inp.dft_plus_u == 2) - { - if (this->dftu.omc != 2) - { - 
dftu_cal_occup_m(iter, ucell, dm_vec, this->kv, - this->p_chgmix->get_mixing_beta(), hamilt_lcao, this->dftu); - } - this->dftu.cal_energy_correction(ucell, istep); - } - this->dftu.output(ucell); - // use the converged occupation matrix for next MD/Relax SCF calculation - if (conv_esolver) - { - this->dftu.initialed_locale = true; - } - } + finish_dftu_lcao(istep, iter, conv_esolver, PARAM.inp, &(this->dftu), ucell, dm_vec, this->kv, this->p_chgmix->get_mixing_beta(), hamilt_lcao); // 2) for deepks, calculate delta_e, output labels during electronic steps this->deepks.delta_e(ucell, this->kv, this->orb_, this->pv, this->gd, dm_vec, this->pelec->f_en, PARAM.inp); diff --git a/source/source_lcao/CMakeLists.txt b/source/source_lcao/CMakeLists.txt index 844da5dc84..a793f5d5d0 100644 --- a/source/source_lcao/CMakeLists.txt +++ b/source/source_lcao/CMakeLists.txt @@ -23,6 +23,7 @@ if(ENABLE_LCAO) module_operator_lcao/dspin_lcao.cpp module_operator_lcao/dftu_lcao.cpp module_operator_lcao/operator_force_stress_utils.cpp + dftu_lcao.cpp pulay_fs_center2.cpp FORCE_STRESS.cpp FORCE_gamma.cpp diff --git a/source/source_lcao/dftu_lcao.cpp b/source/source_lcao/dftu_lcao.cpp new file mode 100644 index 0000000000..5a4c6c45c8 --- /dev/null +++ b/source/source_lcao/dftu_lcao.cpp @@ -0,0 +1,115 @@ +#include "dftu_lcao.h" +#include "source_lcao/module_dftu/dftu.h" +#include "source_estate/module_dm/density_matrix.h" +#include "source_lcao/hamilt_lcao.h" + +namespace ModuleESolver +{ + +template +void init_dftu_lcao(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx) +{ + if (!inp.dft_plus_u) + { + return; + } + + auto* dftu_ptr = static_cast(dftu); + auto* dm_ptr = static_cast*>(dm); + + if (istep != 0 || iter != 1) + { + dftu_ptr->set_dmr(dm_ptr); + } + + /// Calculate U and J if Yukawa potential is used + dftu_ptr->cal_slater_UJ(ucell, rho, nrxx); +} + +template +void finish_dftu_lcao(const int istep, + 
const int 
iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao) +{ + if (!inp.dft_plus_u) + { + return; + } + + auto* dftu_ptr = static_cast(dftu); + auto* hamilt_lcao_ptr = static_cast*>(hamilt_lcao); + + /// old DFT+U method calculates energy correction in esolver, + /// new DFT+U method calculates energy in Hamiltonian + if (inp.dft_plus_u == 2) + { + if (dftu_ptr->omc != 2) + { + dftu_cal_occup_m(iter, ucell, dm_vec, kv, mixing_beta, + static_cast*>(hamilt_lcao_ptr), *dftu_ptr); + } + dftu_ptr->cal_energy_correction(ucell, istep); + } + dftu_ptr->output(ucell); + + /// use the converged occupation matrix for next MD/Relax SCF calculation + if (conv_esolver) + { + dftu_ptr->initialed_locale = true; + } +} + +/// Template instantiation +template void init_dftu_lcao(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx); +template void init_dftu_lcao>(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx); + +template void finish_dftu_lcao(const int istep, + const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao); +template void finish_dftu_lcao>(const int istep, + const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao); + +} // namespace ModuleESolver diff --git a/source/source_lcao/dftu_lcao.h b/source/source_lcao/dftu_lcao.h new file mode 100644 index 0000000000..5138b66256 --- /dev/null +++ b/source/source_lcao/dftu_lcao.h @@ -0,0 +1,67 @@ +#ifndef DFTU_LCAO_H 
+#define DFTU_LCAO_H + 
+#include "source_cell/unitcell.h" +#include "source_cell/klist.h" +#include "source_io/module_parameter/input_parameter.h" + +namespace ModuleESolver +{ + +/** + * @brief Initialize DFT+U for LCAO method in iter_init + * + * This function handles the DFT+U initialization during the SCF iteration. + * It sets the density matrix and calculates Slater integrals if needed. + * + * @param istep Current ionic step + * @param iter Current SCF iteration + * @param inp Input parameters + * @param dftu DFT+U object + * @param dm Density matrix + * @param ucell Unit cell + * @param rho Charge density + * @param nrxx Number of real space grid points + */ +template +void init_dftu_lcao(const int istep, + const int iter, + const Input_para& inp, + void* dftu, + void* dm, + const UnitCell& ucell, + double** rho, + const int nrxx); + +/** + * @brief Finish DFT+U calculation for LCAO method in iter_finish + * + * This function handles the DFT+U finalization during the SCF iteration. + * It calculates the occupation matrix and energy correction if needed. 
+ * + * @param istep Current ionic step (forwarded to cal_energy_correction) + * @param iter Current SCF iteration + * @param conv_esolver Whether ESolver has converged + * @param inp Input parameters + * @param dftu DFT+U object + * @param ucell Unit cell + * @param dm_vec Density matrix vector + * @param kv K-vectors + * @param mixing_beta Mixing beta parameter + * @param hamilt_lcao Hamiltonian LCAO object + */ +template +void finish_dftu_lcao(const int istep, + const int iter, + const bool conv_esolver, + const Input_para& inp, + void* dftu, + const UnitCell& ucell, + const std::vector>& dm_vec, + const K_Vectors& kv, + const double mixing_beta, + void* hamilt_lcao); + +} // namespace ModuleESolver + +#endif // DFTU_LCAO_H From 5a9648317862c698bb1a7f6fe225c3ca82f4dcfc Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 12:30:46 +0800 Subject: [PATCH 06/18] refactor(esolver): extract diagonalization parameters setup to hsolver module - Create new files diago_params.h/cpp in source_hsolver directory - Add setup_diago_params_pw() function for PW diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_ks_pw.cpp - Encapsulate diagonalization parameter setup logic Modified files: - source_esolver/esolver_ks_pw.cpp: replace inline code with function call - source_hsolver/CMakeLists.txt: add new source file New files: - source_hsolver/diago_params.h: function declaration - source_hsolver/diago_params.cpp: function implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. 
--- source/source_esolver/esolver_ks_pw.cpp | 14 ++----- source/source_hsolver/CMakeLists.txt | 1 + source/source_hsolver/diago_params.cpp | 55 +++++++++++++++++++++++++ source/source_hsolver/diago_params.h | 29 +++++++++++++ 4 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 source/source_hsolver/diago_params.cpp create mode 100644 source/source_hsolver/diago_params.h diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b35ea19948..e255b95b46 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -6,6 +6,7 @@ #include "source_hsolver/diago_iter_assist.h" #include "source_hsolver/hsolver_pw.h" +#include "source_hsolver/diago_params.h" #include "source_hsolver/kernels/hegvd_op.h" #include "source_io/module_parameter/parameter.h" @@ -164,17 +165,8 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; - // choose if psi should be diag in subspace - // be careful that istep start from 0 and iter start from 1 - // if (iter == 1) - hsolver::DiagoIterAssist::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true; - hsolver::DiagoIterAssist::SCF_ITER = iter; - hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - - if (PARAM.inp.calculation != "nscf") - { - hsolver::DiagoIterAssist::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; - } + // setup diagonalization parameters + hsolver::setup_diago_params_pw(istep, iter, ethr, PARAM.inp); bool skip_charge = PARAM.inp.calculation == "nscf" ? 
true : false; diff --git a/source/source_hsolver/CMakeLists.txt b/source/source_hsolver/CMakeLists.txt index f4e87cdf94..b115d6d4cd 100644 --- a/source/source_hsolver/CMakeLists.txt +++ b/source/source_hsolver/CMakeLists.txt @@ -12,6 +12,7 @@ list(APPEND objects hsolver.cpp diago_pxxxgvx.cpp diag_hs_para.cpp + diago_params.cpp ) diff --git a/source/source_hsolver/diago_params.cpp b/source/source_hsolver/diago_params.cpp new file mode 100644 index 0000000000..a0c720a625 --- /dev/null +++ b/source/source_hsolver/diago_params.cpp @@ -0,0 +1,55 @@ +#include "diago_params.h" +#include "diago_iter_assist.h" + +namespace hsolver +{ + +template +void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp) +{ + /// choose if psi should be diag in subspace + /// be careful that istep start from 0 and iter start from 1 + DiagoIterAssist::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true; + DiagoIterAssist::SCF_ITER = iter; + DiagoIterAssist::PW_DIAG_THR = ethr; + + if (inp.calculation != "nscf") + { + DiagoIterAssist::PW_DIAG_NMAX = inp.pw_diag_nmax; + } +} + +/// Template instantiation for CPU +template void setup_diago_params_pw, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + +/// Template instantiation for GPU +#if ((defined __CUDA) || (defined __ROCM)) +template void setup_diago_params_pw, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_pw(const int 
istep, + const int iter, + const double ethr, + const Input_para& inp); +#endif + +} // namespace hsolver diff --git a/source/source_hsolver/diago_params.h b/source/source_hsolver/diago_params.h new file mode 100644 index 0000000000..995090bebd --- /dev/null +++ b/source/source_hsolver/diago_params.h @@ -0,0 +1,29 @@ +#ifndef DIAGO_PARAMS_H +#define DIAGO_PARAMS_H + +#include "source_io/module_parameter/input_parameter.h" + +namespace hsolver +{ + +/** + * @brief Setup diagonalization parameters for PW method + * + * This function sets up the diagonalization parameters for plane wave method, + * including subspace diagonalization flag, SCF iteration number, diagonalization + * threshold, and maximum number of diagonalization steps. + * + * @param istep Current ionic step + * @param iter Current SCF iteration + * @param ethr Diagonalization threshold + * @param inp Input parameters + */ +template +void setup_diago_params_pw(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + +} // namespace hsolver + +#endif // DIAGO_PARAMS_H From 4936dc1f28adc6fdbbf208825c95b4dabcbf8462 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 14:00:59 +0800 Subject: [PATCH 07/18] fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper - Add Input_para parameter to cal_mi_lcao_wrapper function - Add sc_mag_switch check to avoid calling cal_mi_lcao when DeltaSpin is disabled - Fix 'atomCounts is not set' error in non-DeltaSpin calculations - Update function call in esolver_ks_lcao.cpp This fix resolves the CI/CD failure caused by commit 2a520e3f2. The root cause was that cal_mi_lcao_wrapper was called without checking sc_mag_switch, leading to uninitialized atomCounts error. 
Modified files: - source_esolver/esolver_ks_lcao.cpp: update function call - source_lcao/module_deltaspin/deltaspin_lcao.h: add parameter - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add check This follows the refactoring principle: preserve original condition checks when extracting code to wrapper functions. --- source/source_esolver/esolver_ks_lcao.cpp | 2 +- .../source_lcao/module_deltaspin/deltaspin_lcao.cpp | 11 ++++++++--- source/source_lcao/module_deltaspin/deltaspin_lcao.h | 3 ++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index e86c7a8cb6..43e83aa8df 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -430,7 +430,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& this->deepks.delta_e(ucell, this->kv, this->orb_, this->pv, this->gd, dm_vec, this->pelec->f_en, PARAM.inp); // 3) for delta spin - cal_mi_lcao_wrapper(iter); + cal_mi_lcao_wrapper(iter, PARAM.inp); // call iter_finish() of ESolver_KS, where band gap is printed, // eig and occ are printed, magnetization is calculated, diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp index c58b4f0783..96e969277c 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -33,8 +33,13 @@ void init_deltaspin_lcao(const UnitCell& ucell, } template -void cal_mi_lcao_wrapper(const int iter) +void cal_mi_lcao_wrapper(const int iter, const Input_para& inp) { + if (!inp.sc_mag_switch) + { + return; + } + spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); sc.cal_mi_lcao(iter); } @@ -86,8 +91,8 @@ template void init_deltaspin_lcao>(const UnitCell& ucell, void* dm, void* pelec); -template void cal_mi_lcao_wrapper(const int iter); -template void 
cal_mi_lcao_wrapper>(const int iter); +template void cal_mi_lcao_wrapper(const int iter, const Input_para& inp); +template void cal_mi_lcao_wrapper>(const int iter, const Input_para& inp); template bool run_deltaspin_lambda_loop_lcao(const int iter, const double drho, diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.h b/source/source_lcao/module_deltaspin/deltaspin_lcao.h index f91326490b..959109ece7 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.h +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.h @@ -40,9 +40,10 @@ void init_deltaspin_lcao(const UnitCell& ucell, * in the DeltaSpin method. * * @param iter Current iteration number + * @param inp Input parameters */ template -void cal_mi_lcao_wrapper(const int iter); +void cal_mi_lcao_wrapper(const int iter, const Input_para& inp); /** * @brief Run DeltaSpin lambda loop for LCAO method From ea218f642d08d38a871d7fe27dc488f08fdce25d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 18:03:31 +0800 Subject: [PATCH 08/18] fix(deltaspin): add #ifdef __LCAO for conditional compilation - Add #ifdef __LCAO conditional compilation in init_deltaspin_lcao and cal_mi_lcao_wrapper - Fix parameter order in init_sc call for LCAO and non-LCAO builds - Fix undefined reference to cal_mi_lcao in non-LCAO build This fix resolves CI/CD compilation errors in both build_5pt (with __LCAO) and build_1p (without __LCAO) environments. The root causes were: 1. init_sc has different parameter order in LCAO vs non-LCAO builds - LCAO: psi, dm, pelec - non-LCAO: psi, pelec 2. cal_mi_lcao is only defined in LCAO build Modified files: - source_hsolver/diago_params.h: add setup_diago_params_sdft declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add conditional compilation This follows the refactoring principle: handle conditional compilation properly when code has different implementations for different build configurations.
--- source/source_hsolver/diago_params.h | 18 ++++++++++++++++++ .../module_deltaspin/deltaspin_lcao.cpp | 10 ++++++++++ 2 files changed, 28 insertions(+) diff --git a/source/source_hsolver/diago_params.h b/source/source_hsolver/diago_params.h index 995090bebd..5d46b01046 100644 --- a/source/source_hsolver/diago_params.h +++ b/source/source_hsolver/diago_params.h @@ -24,6 +24,24 @@ void setup_diago_params_pw(const int istep, const double ethr, const Input_para& inp); +/** + * @brief Setup diagonalization parameters for SDFT method + * + * This function sets up the diagonalization parameters for stochastic DFT method, + * including subspace diagonalization flag, diagonalization threshold, and + * maximum number of diagonalization steps. + * + * @param istep Current ionic step + * @param iter Current SCF iteration + * @param ethr Diagonalization threshold + * @param inp Input parameters + */ +template +void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp); + } // namespace hsolver #endif // DIAGO_PARAMS_H diff --git a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp index 96e969277c..6a7effb6d0 100644 --- a/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp +++ b/source/source_lcao/module_deltaspin/deltaspin_lcao.cpp @@ -24,12 +24,20 @@ void init_deltaspin_lcao(const UnitCell& ucell, } spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); +#ifdef __LCAO sc.init_sc(inp.sc_thr, inp.nsc, inp.nsc_min, inp.alpha_trial, inp.sccut, inp.sc_drop_thr, ucell, static_cast(pv), inp.nspin, kv, p_hamilt, psi, static_cast*>(dm), static_cast(pelec)); +#else + sc.init_sc(inp.sc_thr, inp.nsc, inp.nsc_min, inp.alpha_trial, + inp.sccut, inp.sc_drop_thr, ucell, + static_cast(pv), + inp.nspin, kv, p_hamilt, psi, + static_cast(pelec)); +#endif } template @@ -40,8 +48,10 @@ void cal_mi_lcao_wrapper(const int iter, const Input_para& inp) 
return; } +#ifdef __LCAO spinconstrain::SpinConstrain& sc = spinconstrain::SpinConstrain::getScInstance(); sc.cal_mi_lcao(iter); +#endif } template From c365a3afd18cced4dfdc424c120fe0a0e3bfb4a6 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 10 Mar 2026 18:19:57 +0800 Subject: [PATCH 09/18] refactor(esolver): extract SDFT diagonalization parameters setup - Add setup_diago_params_sdft() function for SDFT diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_sdft_pw.cpp - Encapsulate diagonalization parameter setup logic for SDFT Modified files: - source_esolver/esolver_sdft_pw.cpp: replace inline code with function call - source_hsolver/diago_params.cpp: add setup_diago_params_sdft implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. Note: SDFT has different parameter setup logic compared to PW: - Different need_subspace condition - No SCF_ITER setting - Always set PW_DIAG_NMAX (no nscf check) --- source/source_esolver/esolver_sdft_pw.cpp | 16 ++----- source/source_hsolver/diago_params.cpp | 51 +++++++++++++++++++++++ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 798e52d26b..26118eed21 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -8,6 +8,7 @@ #include "source_pw/module_stodft/sto_forces.h" #include "source_pw/module_stodft/sto_stress_pw.h" #include "source_hsolver/diago_iter_assist.h" +#include "source_hsolver/diago_params.h" #include "source_io/module_parameter/parameter.h" #include @@ -142,20 +143,11 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i // reset energy this->pelec->f_en.eband = 0.0; this->pelec->f_en.demet = 0.0; - // choose if psi should be diag in subspace - // be careful that istep start from 0 and iter start from 1 - if (istep == 0 && iter == 1 || 
PARAM.inp.calculation == "nscf") - { - hsolver::DiagoIterAssist::need_subspace = false; - } - else - { - hsolver::DiagoIterAssist::need_subspace = true; - } + + // setup diagonalization parameters for SDFT + hsolver::setup_diago_params_sdft(istep, iter, ethr, PARAM.inp); bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; - hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - hsolver::DiagoIterAssist::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; // hsolver only exists in this function hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, diff --git a/source/source_hsolver/diago_params.cpp b/source/source_hsolver/diago_params.cpp index a0c720a625..28e1040a97 100644 --- a/source/source_hsolver/diago_params.cpp +++ b/source/source_hsolver/diago_params.cpp @@ -22,6 +22,27 @@ void setup_diago_params_pw(const int istep, } } +template +void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp) +{ + /// choose if psi should be diag in subspace + /// be careful that istep start from 0 and iter start from 1 + if (istep == 0 && iter == 1 || inp.calculation == "nscf") + { + DiagoIterAssist::need_subspace = false; + } + else + { + DiagoIterAssist::need_subspace = true; + } + + DiagoIterAssist::PW_DIAG_THR = ethr; + DiagoIterAssist::PW_DIAG_NMAX = inp.pw_diag_nmax; +} + /// Template instantiation for CPU template void setup_diago_params_pw, base_device::DEVICE_CPU>(const int istep, const int iter, @@ -52,4 +73,34 @@ template void setup_diago_params_pw(const int i const Input_para& inp); #endif +/// Template instantiation for SDFT CPU +template void setup_diago_params_sdft, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft, base_device::DEVICE_CPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const 
Input_para& inp); + +/// Template instantiation for SDFT GPU +#if ((defined __CUDA) || (defined __ROCM)) +template void setup_diago_params_sdft, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft, base_device::DEVICE_GPU>(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +template void setup_diago_params_sdft(const int istep, + const int iter, + const double ethr, + const Input_para& inp); +#endif + } // namespace hsolver From cbd0ce77936542fed08b3c568fde9250fce50a7d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Wed, 11 Mar 2026 13:18:58 +0800 Subject: [PATCH 10/18] refactor(hamilt): introduce HamiltBase non-template base class - Create HamiltBase as a non-template base class for Hamilt - Modify Hamilt to inherit from HamiltBase - Change ESolver_KS::p_hamilt type from Hamilt* to HamiltBase* - Add static_cast where needed when passing p_hamilt to functions expecting Hamilt* This is the first step towards removing template parameters from ESolver. 
Modified files: - source/source_esolver/esolver_ks.h - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_hamilt/hamilt.h New files: - source/source_hamilt/hamilt_base.h --- source/source_esolver/esolver_ks.h | 5 +- source/source_esolver/esolver_ks_lcaopw.cpp | 2 +- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_esolver/esolver_sdft_pw.cpp | 4 +- source/source_hamilt/hamilt.h | 13 ++++-- source/source_hamilt/hamilt_base.h | 52 +++++++++++++++++++++ 6 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 source/source_hamilt/hamilt_base.h diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index 787b58ba74..1913aa4101 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -7,6 +7,7 @@ #include "source_estate/module_charge/charge_mixing.h" // use charge mixing #include "source_psi/psi.h" // use electronic wave functions #include "source_hamilt/hamilt.h" // use Hamiltonian +#include "source_hamilt/hamilt_base.h" // use Hamiltonian base class #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 namespace ModuleESolver @@ -47,8 +48,8 @@ class ESolver_KS : public ESolver_FP //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override; - //! Hamiltonian - hamilt::Hamilt* p_hamilt = nullptr; + //! Hamiltonian (base class pointer, actual type determined at runtime) + hamilt::HamiltBase* p_hamilt = nullptr; //! 
PW for wave functions, only used in KSDFT, not in OFDFT ModulePW::PW_Basis_K* pw_wfc = nullptr; diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index f9700f5b68..db00b6265d 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -146,7 +146,7 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, + hsolver_lip_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index e255b95b46..a032c2c976 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -188,7 +188,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste hsolver::DiagoIterAssist::need_subspace, PARAM.inp.use_k_continuity); - hsolver_pw_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, + hsolver_pw_obj.solve(static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat); } diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 26118eed21..597825cd6d 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -167,7 +167,7 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver::DiagoIterAssist::need_subspace); hsolver_pw_sdft_obj.solve(ucell, - this->p_hamilt, + static_cast*>(this->p_hamilt), this->stp.psi_t[0], this->stp.psi_cpu[0], this->pelec, @@ -291,7 +291,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) this->pw_wfc, 
this->stp.psi_t, &this->ppcell, - this->p_hamilt, + static_cast, Device>*>(this->p_hamilt), this->stoche, &stowf); sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto); diff --git a/source/source_hamilt/hamilt.h b/source/source_hamilt/hamilt.h index 6d732d7a82..3d554c0fe6 100644 --- a/source/source_hamilt/hamilt.h +++ b/source/source_hamilt/hamilt.h @@ -7,21 +7,28 @@ #include "matrixblock.h" #include "source_psi/psi.h" #include "operator.h" +#include "hamilt_base.h" namespace hamilt { template -class Hamilt +class Hamilt : public HamiltBase { public: virtual ~Hamilt(){}; /// for target K point, update consequence of hPsi() and matrix() - virtual void updateHk(const int ik){return;} + void updateHk(const int ik) override { return; } /// refresh status of Hamiltonian, for example, refresh H(R) and S(R) in LCAO case - virtual void refresh(bool yes = true){return;} + void refresh(bool yes = true) override { return; } + + /// get the class name + std::string get_classname() const override { return classname; } + + /// get the operator chain + void* get_ops() override { return static_cast(ops); } /// core function: for solving eigenvalues of Hamiltonian with iterative method virtual void hPsi( diff --git a/source/source_hamilt/hamilt_base.h b/source/source_hamilt/hamilt_base.h new file mode 100644 index 0000000000..06325bf050 --- /dev/null +++ b/source/source_hamilt/hamilt_base.h @@ -0,0 +1,52 @@ +#ifndef HAMILT_BASE_H +#define HAMILT_BASE_H + +#include + +namespace hamilt +{ + +/** + * @brief Base class for Hamiltonian + * + * This is a non-template base class for Hamilt. + * It provides a common interface for all Hamiltonian types, + * allowing ESolver to manage Hamiltonian without template parameters. 
+ */ +class HamiltBase +{ + public: + virtual ~HamiltBase() {} + + /** + * @brief Update Hamiltonian for a specific k-point + * + * @param ik k-point index + */ + virtual void updateHk(const int ik) { return; } + + /** + * @brief Refresh the status of Hamiltonian + * + * @param yes whether to refresh + */ + virtual void refresh(bool yes = true) { return; } + + /** + * @brief Get the class name + * + * @return class name + */ + virtual std::string get_classname() const { return "none"; } + + /** + * @brief Get the operator chain (as void* to avoid template) + * + * @return pointer to operator chain + */ + virtual void* get_ops() { return nullptr; } +}; + +} // namespace hamilt + +#endif // HAMILT_BASE_H From 6e0f43ca072682621867dee28d5ee8d4e58e943d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Wed, 11 Mar 2026 16:02:42 +0800 Subject: [PATCH 11/18] refactor(esolver): add static_cast for p_hamilt in esolver files - Add static_cast*> when passing p_hamilt to functions expecting Hamilt* type - Split long cast statements into multiple lines for better readability - Files modified: - esolver_ks_pw.cpp: setup_pot, stp.init calls - esolver_ks_lcao.cpp: init_chg_hr, hsolver_lcao_obj.solve calls - esolver_ks_lcao_tddft.cpp: solve_psi, cal_edm_tddft, matrix calls - esolver_gets.cpp: ops access, output_SR call This follows the HamiltBase refactoring strategy where p_hamilt is stored as HamiltBase* and cast to Hamilt* when needed. 
--- source/source_esolver/esolver_gets.cpp | 12 ++++++++---- source/source_esolver/esolver_ks_lcao.cpp | 4 ++-- source/source_esolver/esolver_ks_lcao_tddft.cpp | 12 ++++++------ source/source_esolver/esolver_ks_pw.cpp | 4 ++-- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_gets.cpp b/source/source_esolver/esolver_gets.cpp index 7eff0f537a..e03e7b8bdc 100644 --- a/source/source_esolver/esolver_gets.cpp +++ b/source/source_esolver/esolver_gets.cpp @@ -108,8 +108,9 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep) this->kv, *(two_center_bundle_.overlap_orb), orb_.cutoffs()); - dynamic_cast, std::complex>*>(this->p_hamilt->ops) - ->contributeHR(); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + auto* ops_ptr = dynamic_cast, std::complex>*>(hamilt_ptr->ops); + ops_ptr->contributeHR(); } else { @@ -119,13 +120,16 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep) this->kv, *(two_center_bundle_.overlap_orb), orb_.cutoffs()); - dynamic_cast, double>*>(this->p_hamilt->ops)->contributeHR(); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + auto* ops_ptr = dynamic_cast, double>*>(hamilt_ptr->ops); + ops_ptr->contributeHR(); } } const std::string fn = PARAM.globalv.global_out_dir + "sr_nao.csr"; - ModuleIO::output_SR(pv, gd, this->p_hamilt, fn); + auto* hamilt_ptr = static_cast>*>(this->p_hamilt); + ModuleIO::output_SR(pv, gd, hamilt_ptr, fn); if (PARAM.inp.out_mat_r) { diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 43e83aa8df..bad1cec7b6 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -179,7 +179,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) { //! 
13.1.2) init charge density from Hamiltonian matrix file LCAO_domain::init_chg_hr(PARAM.globalv.global_readin_dir, PARAM.inp.nspin, - this->p_hamilt, ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm, + static_cast*>(this->p_hamilt), ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm, this->chr, PARAM.inp.ks_solver); } } @@ -382,7 +382,7 @@ void ESolver_KS_LCAO::hamilt2rho_single(UnitCell& ucell, int istep, int if (!skip_solve) { hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm, + hsolver_lcao_obj.solve(static_cast*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm, this->chr, PARAM.inp.nspin, skip_charge); } diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index b7641a09fc..05dc8c9233 100644 --- a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -235,7 +235,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, PARAM.inp.nbands, PARAM.globalv.nlocal, this->kv.get_nks(), - this->p_hamilt, + static_cast>*>(this->p_hamilt), this->pv, this->psi, this->psi_laststep, @@ -255,7 +255,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, PARAM.inp.nbands, PARAM.globalv.nlocal, this->kv.get_nks(), - this->p_hamilt, + static_cast>*>(this->p_hamilt), this->pv, this->psi, this->psi_laststep, @@ -277,7 +277,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2rho_single(UnitCell& ucell, { bool skip_charge = PARAM.inp.calculation == "nscf" ? 
true : false; hsolver::HSolverLCAO> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, + hsolver_lcao_obj.solve(static_cast>*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm, @@ -342,11 +342,11 @@ void ESolver_KS_LCAO_TDDFT::iter_finish(UnitCell& ucell, { if (use_tensor && use_lapack) { - elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, this->p_hamilt); + elecstate::cal_edm_tddft_tensor_lapack(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt)); } else { - elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt); + elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast>*>(this->p_hamilt)); } } } @@ -416,7 +416,7 @@ void ESolver_KS_LCAO_TDDFT::store_h_s_psi(UnitCell& ucell, this->p_hamilt->updateHk(ik); hamilt::MatrixBlock> h_mat; hamilt::MatrixBlock> s_mat; - this->p_hamilt->matrix(h_mat, s_mat); + static_cast>*>(this->p_hamilt)->matrix(h_mat, s_mat); // Store H and S matrices to Hk_laststep and Sk_laststep if (use_tensor && use_lapack) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index a032c2c976..b2976733bf 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -128,10 +128,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06 pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell, - this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.psi_t, static_cast*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp); // setup psi (electronic wave functions) - this->stp.init(this->p_hamilt); + this->stp.init(static_cast*>(this->p_hamilt)); //! 
Setup EXX helper for Hamiltonian and psi exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp); From 14c1b8a78e01fbd7b07f4670f1bf5a3b6dcfb032 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 06:25:32 +0800 Subject: [PATCH 12/18] refactor(esolver): remove psi member from ESolver_KS base class Move psi::Psi* psi from ESolver_KS base class to derived classes to eliminate template parameter dependency and improve code organization. Changes: 1. ESolver_KS base class: - Remove psi::Psi* psi member variable - Remove Setup_Psi::deallocate_psi() call in destructor - Remove unnecessary includes: psi.h and setup_psi.h 2. ESolver_KS_LCAO: - Add psi::Psi* psi member variable - Add Setup_Psi::deallocate_psi() in destructor - Add include: setup_psi.h 3. ESolver_KS_LCAO_TDDFT: - Improve psi_laststep deallocation with nullptr check - psi member inherited from ESolver_KS_LCAO 4. ESolver_KS_PW: - Use stp.psi_cpu directly instead of base class psi - Remove unnecessary memory allocation in after_scf() 5. 
pw_others.cpp (BUG FIX): - Fix gen_bessel: use *(this->stp.psi_cpu) instead of this->psi[0] - Previous code accessed uninitialized base class psi (nullptr) - This was a latent bug that could cause crashes Benefits: - Eliminates template parameter T dependency in ESolver_KS base class - Clearer memory management: each derived class manages its own psi - Reduces compilation dependencies - Fixes potential memory access bug in pw_others.cpp Tested: Compiled successfully in build_5pt and build_1p --- source/source_esolver/esolver_ks.cpp | 3 --- source/source_esolver/esolver_ks.h | 4 ---- source/source_esolver/esolver_ks_lcao.cpp | 2 ++ source/source_esolver/esolver_ks_lcao.h | 3 +++ source/source_esolver/esolver_ks_lcao_tddft.cpp | 6 +++++- source/source_esolver/esolver_ks_pw.cpp | 7 ++----- source/source_esolver/pw_others.cpp | 2 +- 7 files changed, 13 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index fc99b8a572..93fb116aca 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -15,7 +15,6 @@ #include "source_io/module_output/output_log.h" // use write_head #include "source_estate/elecstate_print.h" // print_etot #include "source_io/module_output/print_info.h" // print_parameters -#include "source_psi/setup_psi.h" // mohan add 20251009 #include "source_lcao/module_dftu/dftu.h" // mohan add 2025-11-07 namespace ModuleESolver @@ -31,8 +30,6 @@ ESolver_KS::~ESolver_KS() //**************************************************** // do not add any codes in this deconstructor funcion //**************************************************** - Setup_Psi::deallocate_psi(this->psi); - delete this->p_hamilt; delete this->p_chgmix; this->ppcell.release_memory(); diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index 1913aa4101..eee36fbd88 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -5,7 
+5,6 @@ #include "source_basis/module_pw/pw_basis_k.h" // use plane wave #include "source_cell/klist.h" // use k-points in Brillouin zone #include "source_estate/module_charge/charge_mixing.h" // use charge mixing -#include "source_psi/psi.h" // use electronic wave functions #include "source_hamilt/hamilt.h" // use Hamiltonian #include "source_hamilt/hamilt_base.h" // use Hamiltonian base class #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 @@ -60,9 +59,6 @@ class ESolver_KS : public ESolver_FP //! nonlocal pseudopotentials pseudopot_cell_vnl ppcell; - //! Electronic wavefunctions - psi::Psi* psi = nullptr; - //! DFT+U method, mohan add 2025-11-07 Plus_U dftu; diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index bad1cec7b6..d418f762c7 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -22,6 +22,7 @@ #include "source_io/module_output/print_info.h" #include "source_lcao/rho_tau_lcao.h" // mohan add 20251024 #include "source_lcao/LCAO_set.h" // mohan add 20251111 +#include "source_psi/setup_psi.h" // use Setup_Psi for deallocate_psi namespace ModuleESolver { @@ -40,6 +41,7 @@ ESolver_KS_LCAO::~ESolver_KS_LCAO() //**************************************************** // do not add any codes in this deconstructor funcion //**************************************************** + Setup_Psi::deallocate_psi(this->psi); } template diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 4191306788..0e013ec9ae 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -57,6 +57,9 @@ class ESolver_KS_LCAO : public ESolver_KS virtual void others(UnitCell& ucell, const int istep) override; + //! Electronic wave functions (moved from base class) + psi::Psi* psi = nullptr; + //! 
Store information about Adjacent Atoms Record_adj RA; diff --git a/source/source_esolver/esolver_ks_lcao_tddft.cpp b/source/source_esolver/esolver_ks_lcao_tddft.cpp index 05dc8c9233..2e463acdcd 100644 --- a/source/source_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/source_esolver/esolver_ks_lcao_tddft.cpp @@ -40,7 +40,11 @@ ESolver_KS_LCAO_TDDFT::~ESolver_KS_LCAO_TDDFT() //************************************************* // Do not add any code in this destructor function //************************************************* - delete psi_laststep; + if (psi_laststep != nullptr) + { + delete psi_laststep; + psi_laststep = nullptr; + } if (td_p != nullptr) { diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b2976733bf..7415552c90 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -240,13 +240,10 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const ModuleBase::TITLE("ESolver_KS_PW", "after_scf"); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); - // Since ESolver_KS::psi is hidden by ESolver_KS_PW::psi, - // we need to copy the data from ESolver_KS::psi to ESolver_KS_PW::psi. 
- // sunliang 2025-04-10 + // Calculate kinetic energy density tau for ELF if needed if (PARAM.inp.out_elf[0] > 0) { - this->ESolver_KS::psi = new psi::Psi(this->stp.psi_cpu[0]); - this->pelec->cal_tau(*(this->psi)); + this->pelec->cal_tau(*(this->stp.psi_cpu)); } ESolver_KS::after_scf(ucell, istep, conv_esolver); diff --git a/source/source_esolver/pw_others.cpp b/source/source_esolver/pw_others.cpp index fc42df14bd..49f7465b46 100644 --- a/source/source_esolver/pw_others.cpp +++ b/source/source_esolver/pw_others.cpp @@ -32,7 +32,7 @@ void ESolver_KS_PW::others(UnitCell& ucell, const int istep) { Numerical_Descriptor nc; nc.output_descriptor(ucell, - this->psi[0], + *(this->stp.psi_cpu), PARAM.inp.bessel_descriptor_lmax, PARAM.inp.bessel_descriptor_rcut, PARAM.inp.bessel_descriptor_tolerence, From 0c2fa0f7b6d57ef2a737be317cad01f38526de97 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 06:36:33 +0800 Subject: [PATCH 13/18] refactor(esolver): remove template parameters from ESolver_KS base class This is a major milestone in ESolver refactoring! ESolver_KS no longer needs template parameters because: - All member variables are non-template types - All member functions do not use T or Device parameters - Template parameters were only needed for derived classes Changes: 1. ESolver_KS base class: - Remove template declaration - Remove all template declarations from member functions - Remove template instantiation code at end of file - Fix Tab indentation to spaces for better readability 2. 
Derived classes: - ESolver_KS_PW: public ESolver_KS (was ESolver_KS) - ESolver_KS_LCAO: public ESolver_KS (was ESolver_KS) - ESolver_GetS: public ESolver_KS (was ESolver_KS>) - Update base class calls: ESolver_KS:: (was ESolver_KS::) Code reduction: - esolver_ks.h: 78 -> 77 lines (-1 line) - esolver_ks.cpp: 346 -> 317 lines (-29 lines) - Total ESolver code: 424 -> 394 lines (-30 lines) - Overall: 8 files changed, 50 insertions(+), 80 deletions(-), net -30 lines Benefits: - Simpler base class without template complexity - Faster compilation (no template instantiation needed) - Clearer inheritance hierarchy - Easier to extract common code in future refactoring - Sets foundation for further ESolver template removal Tested: Compiled successfully in build_5pt --- source/source_esolver/esolver_gets.h | 2 +- source/source_esolver/esolver_ks.cpp | 97 ++++++++--------------- source/source_esolver/esolver_ks.h | 1 - source/source_esolver/esolver_ks_lcao.cpp | 12 +-- source/source_esolver/esolver_ks_lcao.h | 2 +- source/source_esolver/esolver_ks_pw.cpp | 12 +-- source/source_esolver/esolver_ks_pw.h | 2 +- source/source_esolver/esolver_sdft_pw.cpp | 2 +- 8 files changed, 50 insertions(+), 80 deletions(-) diff --git a/source/source_esolver/esolver_gets.h b/source/source_esolver/esolver_gets.h index 564fd55035..7a7fb1d34b 100644 --- a/source/source_esolver/esolver_gets.h +++ b/source/source_esolver/esolver_gets.h @@ -10,7 +10,7 @@ namespace ModuleESolver { -class ESolver_GetS : public ESolver_KS> +class ESolver_GetS : public ESolver_KS { public: ESolver_GetS(); diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 93fb116aca..cc94510a66 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -20,27 +20,24 @@ namespace ModuleESolver { -template -ESolver_KS::ESolver_KS(){} +ESolver_KS::ESolver_KS() {} -template -ESolver_KS::~ESolver_KS() +ESolver_KS::~ESolver_KS() { - 
//**************************************************** - // do not add any codes in this deconstructor funcion - //**************************************************** + //**************************************************** + // do not add any codes in this deconstructor funcion + //**************************************************** delete this->p_hamilt; delete this->p_chgmix; this->ppcell.release_memory(); - + // mohan add 2025-10-18, should be put int clean() function pw::teardown_pwwfc(this->pw_wfc); } -template -void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para& inp) +void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para& inp) { ModuleBase::TITLE("ESolver_KS", "before_all_runners"); @@ -78,12 +75,10 @@ void ESolver_KS::before_all_runners(UnitCell& ucell, const Input_para } -template -void ESolver_KS::hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) +void ESolver_KS::hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) {} -template -void ESolver_KS::hamilt2rho(UnitCell& ucell, const int istep, const int iter, const double ethr) +void ESolver_KS::hamilt2rho(UnitCell& ucell, const int istep, const int iter, const double ethr) { // 1) use Hamiltonian to obtain charge density this->hamilt2rho_single(ucell, istep, iter, diag_ethr); @@ -123,8 +118,7 @@ void ESolver_KS::hamilt2rho(UnitCell& ucell, const int istep, const i } } -template -void ESolver_KS::runner(UnitCell& ucell, const int istep) +void ESolver_KS::runner(UnitCell& ucell, const int istep) { ModuleBase::TITLE("ESolver_KS", "runner"); ModuleBase::timer::tick(this->classname, "runner"); @@ -139,14 +133,14 @@ void ESolver_KS::runner(UnitCell& ucell, const int istep) this->diag_ethr = PARAM.inp.pw_diag_thr; this->scf_nmax_flag = false; // mohan add 2025-09-21 for (int iter = 1; iter <= this->maxniter; ++iter) - { - if(iter == this->maxniter) - { - this->scf_nmax_flag=true; - } + { + if(iter == 
this->maxniter) + { + this->scf_nmax_flag=true; + } - // 3) initialization of SCF iterations - this->iter_init(ucell, istep, iter); + // 3) initialization of SCF iterations + this->iter_init(ucell, istep, iter); // 4) use Hamiltonian to obtain charge density this->hamilt2rho(ucell, istep, iter, diag_ethr); @@ -166,22 +160,20 @@ void ESolver_KS::runner(UnitCell& ucell, const int istep) } } // end scf iterations - // 7) after scf + // 7) after scf this->after_scf(ucell, istep, conv_esolver); ModuleBase::timer::tick(this->classname, "runner"); return; }; -template -void ESolver_KS::before_scf(UnitCell& ucell, const int istep) +void ESolver_KS::before_scf(UnitCell& ucell, const int istep) { ModuleBase::TITLE("ESolver_KS", "before_scf"); ESolver_FP::before_scf(ucell, istep); } -template -void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const int iter) +void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const int iter) { if(PARAM.inp.esolver_type != "tddft") { @@ -207,8 +199,7 @@ void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const in this->chr.save_rho_before_sum_band(); } -template -void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& iter, bool &conv_esolver) +void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& iter, bool &conv_esolver) { // 1.1) print out band gap @@ -224,25 +215,25 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i // 1.2) print out eigenvalues and occupations if (PARAM.inp.out_band[0]) { - if (iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) - { - ModuleIO::write_eig_iter(this->pelec->ekb,this->pelec->wg,*this->pelec->klist); - } + if (iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) + { + ModuleIO::write_eig_iter(this->pelec->ekb,this->pelec->wg,*this->pelec->klist); + } } // 2.1) compute magnetization, only for spin==2 ucell.magnet.compute_mag(ucell.omega, this->chr.nrxx, this->chr.nxyz, 
this->chr.rho, this->pelec->nelec_spin.data()); - // 2.2) charge mixing + // 2.2) charge mixing // SCF will continue if U is not converged for uramping calculation - bool converged_u = true; - // to avoid unnecessary dependence on dft+u, refactor is needed + bool converged_u = true; + // to avoid unnecessary dependence on dft+u, refactor is needed #ifdef __LCAO - if (PARAM.inp.dft_plus_u) - { - converged_u = this->dftu.u_converged(); - } + if (PARAM.inp.dft_plus_u) + { + converged_u = this->dftu.u_converged(); + } #endif module_charge::chgmixing_ks(iter, ucell, this->pelec, this->chr, this->p_chgmix, @@ -293,8 +284,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i } //! Something to do after SCF iterations when SCF is converged or comes to the max iter step. -template -void ESolver_KS::after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) +void ESolver_KS::after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) { ModuleBase::TITLE("ESolver_KS", "after_scf"); @@ -318,29 +308,10 @@ void ESolver_KS::after_scf(UnitCell& ucell, const int istep, const bo } -template -void ESolver_KS::after_all_runners(UnitCell& ucell) +void ESolver_KS::after_all_runners(UnitCell& ucell) { // 1) write Etot information ESolver_FP::after_all_runners(ucell); } -//------------------------------------------------------------------------------ -//! the 16th-20th functions of ESolver_KS -//! mohan add 2024-05-12 -//------------------------------------------------------------------------------ -//! This is for mixed-precision pw/LCAO basis sets. -template class ESolver_KS, base_device::DEVICE_CPU>; -template class ESolver_KS, base_device::DEVICE_CPU>; - -//! This is for GPU codes. -#if ((defined __CUDA) || (defined __ROCM)) -template class ESolver_KS, base_device::DEVICE_GPU>; -template class ESolver_KS, base_device::DEVICE_GPU>; -#endif - -//! This is for LCAO basis set. 
-#ifdef __LCAO -template class ESolver_KS; -#endif } // namespace ModuleESolver diff --git a/source/source_esolver/esolver_ks.h b/source/source_esolver/esolver_ks.h index eee36fbd88..b6affc7b0c 100644 --- a/source/source_esolver/esolver_ks.h +++ b/source/source_esolver/esolver_ks.h @@ -12,7 +12,6 @@ namespace ModuleESolver { -template class ESolver_KS : public ESolver_FP { public: diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index d418f762c7..0558942e91 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -51,7 +51,7 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners"); // 1) before_all_runners in ESolver_KS - ESolver_KS::before_all_runners(ucell, inp); + ESolver_KS::before_all_runners(ucell, inp); // 2) autoset nbands in ElecState before init_basis (for Psi 2d division) if (this->pelec == nullptr) @@ -107,7 +107,7 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ModuleBase::timer::tick("ESolver_KS_LCAO", "before_scf"); //! 1) call before_scf() of ESolver_KS. - ESolver_KS::before_scf(ucell, istep); + ESolver_KS::before_scf(ucell, istep); //! 
2) find search radius double search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running, @@ -271,7 +271,7 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) ModuleBase::TITLE("ESolver_KS_LCAO", "after_all_runners"); ModuleBase::timer::tick("ESolver_KS_LCAO", "after_all_runners"); - ESolver_KS::after_all_runners(ucell); + ESolver_KS::after_all_runners(ucell); auto* hamilt_lcao = dynamic_cast*>(this->p_hamilt); if(!hamilt_lcao) @@ -303,7 +303,7 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init"); // call iter_init() of ESolver_KS - ESolver_KS::iter_init(ucell, istep, iter); + ESolver_KS::iter_init(ucell, istep, iter); module_charge::chgmixing_ks_lcao(iter, this->p_chgmix, this->dftu, this->dmat.dm->get_DMR_pointer(1)->get_nnr(), PARAM.inp); @@ -438,7 +438,7 @@ void ESolver_KS_LCAO::iter_finish(UnitCell& ucell, const int istep, int& // eig and occ are printed, magnetization is calculated, // charge mixing is performed, potential is updated, // HF and kS energies are computed, meta-GGA, Jason and restart - ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); + ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); // mix density matrix if mixing_restart + mixing_dmr + not first // mixing_restart at every iter except the last iter @@ -476,7 +476,7 @@ void ESolver_KS_LCAO::after_scf(UnitCell& ucell, const int istep, const } //! 1) call after_scf() of ESolver_KS - ESolver_KS::after_scf(ucell, istep, conv_esolver); + ESolver_KS::after_scf(ucell, istep, conv_esolver); //! 
2) output of lcao every few ionic steps ModuleIO::ctrl_scf_lcao(ucell, diff --git a/source/source_esolver/esolver_ks_lcao.h b/source/source_esolver/esolver_ks_lcao.h index 0e013ec9ae..143f7089ba 100644 --- a/source/source_esolver/esolver_ks_lcao.h +++ b/source/source_esolver/esolver_ks_lcao.h @@ -28,7 +28,7 @@ namespace ModuleESolver { template -class ESolver_KS_LCAO : public ESolver_KS +class ESolver_KS_LCAO : public ESolver_KS { public: ESolver_KS_LCAO(); diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 7415552c90..74507e09c7 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -84,7 +84,7 @@ void ESolver_KS_PW::deallocate_hamilt() template void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_para& inp) { - ESolver_KS::before_all_runners(ucell, inp); + ESolver_KS::before_all_runners(ucell, inp); //! setup and allocation for pelec, potentials, etc. elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr, @@ -105,7 +105,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) ModuleBase::TITLE("ESolver_KS_PW", "before_scf"); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); - ESolver_KS::before_scf(ucell, istep); + ESolver_KS::before_scf(ucell, istep); //! 
Init variables (once the cell has changed) pw::update_cell_pw(ucell, this->ppcell, this->kv, this->pw_wfc, PARAM.inp); @@ -142,7 +142,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) template void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const int iter) { - ESolver_KS::iter_init(ucell, istep, iter); + ESolver_KS::iter_init(ucell, istep, iter); module_charge::chgmixing_ks_pw(iter, this->p_chgmix, this->dftu, PARAM.inp); @@ -212,7 +212,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell); // Call iter_finish() of ESolver_KS - ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); + ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); // D in USPP needs vloc, thus needs update when veff updated // calculate the effective coefficient matrix for non-local @@ -246,7 +246,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->pelec->cal_tau(*(this->stp.psi_cpu)); } - ESolver_KS::after_scf(ucell, istep, conv_esolver); + ESolver_KS::after_scf(ucell, istep, conv_esolver); // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, @@ -300,7 +300,7 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s template void ESolver_KS_PW::after_all_runners(UnitCell& ucell) { - ESolver_KS::after_all_runners(ucell); + ESolver_KS::after_all_runners(ucell); ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 01e1027d79..6a6be52b73 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -13,7 +13,7 @@ namespace ModuleESolver { template -class ESolver_KS_PW : public ESolver_KS +class ESolver_KS_PW : public ESolver_KS { private: using Real = typename GetTypeReal::type; 
diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 597825cd6d..f7f9a29983 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -119,7 +119,7 @@ template void ESolver_SDFT_PW::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) { // call iter_finish() of ESolver_KS - ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); + ESolver_KS::iter_finish(ucell, istep, iter, conv_esolver); } template From 4b51e36da270685a31a6d8062a3fd499d86a2f3b Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 08:48:04 +0800 Subject: [PATCH 14/18] refactor(device): remove explicit template parameter from get_device_type calls - Move get_device_type implementation to header file using std::is_same - Add DEVICE_DSP support - Remove template specialization declarations and definitions - Update all call sites to use automatic template parameter deduction - The compiler now deduces Device type from the ctx parameter --- source/source_base/math_chebyshev.cpp | 14 +++++++------- .../source_base/module_device/device_helpers.cpp | 13 ------------- .../source_base/module_device/device_helpers.h | 16 ++++++++-------- .../module_device/test/device_test.cpp | 4 ++-- source/source_esolver/esolver_ks_pw.cpp | 2 +- source/source_hsolver/diago_dav_subspace.cpp | 2 +- source/source_hsolver/diago_david.cpp | 2 +- source/source_hsolver/diago_iter_assist.cpp | 8 ++++---- source/source_hsolver/test/hsolver_pw_sup.h | 2 +- source/source_pw/module_pwdft/forces.cpp | 6 +++--- source/source_pw/module_pwdft/forces_cc.cpp | 4 ++-- source/source_pw/module_pwdft/forces_scc.cpp | 2 +- source/source_pw/module_pwdft/fs_kin_tools.cpp | 2 +- .../source_pw/module_pwdft/fs_nonlocal_tools.cpp | 2 +- source/source_pw/module_pwdft/nonlocal_maths.hpp | 4 ++-- .../source_pw/module_pwdft/onsite_proj_tools.cpp | 4 ++-- .../source_pw/module_pwdft/onsite_projector.cpp | 2 +- 
source/source_pw/module_pwdft/op_pw_ekin.cpp | 4 ++-- source/source_pw/module_pwdft/stress_cc.cpp | 2 +- source/source_pw/module_pwdft/stress_loc.cpp | 2 +- .../module_pwdft/structure_factor_k.cpp | 2 +- source/source_pw/module_stodft/sto_forces.cpp | 2 +- source/source_pw/module_stodft/sto_wf.cpp | 10 +++++----- 23 files changed, 49 insertions(+), 62 deletions(-) diff --git a/source/source_base/math_chebyshev.cpp b/source/source_base/math_chebyshev.cpp index 8a84686ea5..b7e59a89f9 100644 --- a/source/source_base/math_chebyshev.cpp +++ b/source/source_base/math_chebyshev.cpp @@ -61,7 +61,7 @@ Chebyshev::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord } coefr_cpu = new REAL[norder]; coefc_cpu = new std::complex[norder]; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { resmem_var_op()(this->coef_real, norder); resmem_complex_op()(this->coef_complex, norder); @@ -82,7 +82,7 @@ template Chebyshev::~Chebyshev() { delete[] polytrace; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { delmem_var_op()(this->coef_real); delmem_complex_op()(this->coef_complex); @@ -209,7 +209,7 @@ void Chebyshev::calcoef_real(std::function fun) } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_var_h2d_op()(coef_real, coefr_cpu, norder); } @@ -299,7 +299,7 @@ void Chebyshev::calcoef_complex(std::function(s coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder); } @@ -390,7 +390,7 @@ void Chebyshev::calcoef_pair(std::function fun1, std:: } } - if 
(base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder); } @@ -684,7 +684,7 @@ bool Chebyshev::checkconverge( funA(arrayn_1, arrayn, 1); REAL sum1, sum2; REAL t; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { sum1 = this->ddot_real(arrayn_1, arrayn_1, N); sum2 = this->ddot_real(arrayn_1, arrayn, N); @@ -714,7 +714,7 @@ bool Chebyshev::checkconverge( for (int ior = 2; ior < norder; ++ior) { funA(arrayn, arraynp1, 1); - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { sum1 = this->ddot_real(arrayn, arrayn, N); sum2 = this->ddot_real(arrayn, arraynp1, N); diff --git a/source/source_base/module_device/device_helpers.cpp b/source/source_base/module_device/device_helpers.cpp index 1c53020718..0b5d5a1693 100644 --- a/source/source_base/module_device/device_helpers.cpp +++ b/source/source_base/module_device/device_helpers.cpp @@ -3,19 +3,6 @@ namespace base_device { -// Device type specializations -template <> -AbacusDevice_t get_device_type(const DEVICE_CPU* dev) -{ - return CpuDevice; -} - -template <> -AbacusDevice_t get_device_type(const DEVICE_GPU* dev) -{ - return GpuDevice; -} - // Precision specializations template <> std::string get_current_precision(const float* var) diff --git a/source/source_base/module_device/device_helpers.h b/source/source_base/module_device/device_helpers.h index 60eddd888d..6aa71938de 100644 --- a/source/source_base/module_device/device_helpers.h +++ b/source/source_base/module_device/device_helpers.h @@ -13,6 +13,7 @@ #include "types.h" #include #include +#include namespace base_device { @@ -24,14 +25,13 @@ namespace base_device * @return AbacusDevice_t enum value */ template -AbacusDevice_t 
get_device_type(const Device* dev); - -// Template specialization declarations -template <> -AbacusDevice_t get_device_type(const DEVICE_CPU* dev); - -template <> -AbacusDevice_t get_device_type(const DEVICE_GPU* dev); +AbacusDevice_t get_device_type(const Device* dev) +{ + if (std::is_same::value) return CpuDevice; + else if (std::is_same::value) return GpuDevice; + else if (std::is_same::value) return DspDevice; + else return UnKnown; +} /** * @brief Get the precision string for a given numeric type. diff --git a/source/source_base/module_device/test/device_test.cpp b/source/source_base/module_device/test/device_test.cpp index 02d485c8ef..faf083c721 100644 --- a/source/source_base/module_device/test/device_test.cpp +++ b/source/source_base/module_device/test/device_test.cpp @@ -20,14 +20,14 @@ class TestModulePsiDevice : public ::testing::Test TEST_F(TestModulePsiDevice, get_device_type_cpu) { - base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx); EXPECT_EQ(device, base_device::CpuDevice); } #if __UT_USE_CUDA || __UT_USE_ROCM TEST_F(TestModulePsiDevice, get_device_type_gpu) { - base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx); EXPECT_EQ(device, base_device::GpuDevice); } #endif // __UT_USE_CUDA || __UT_USE_ROCM diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 74507e09c7..a876039457 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,7 +43,7 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); } template diff --git a/source/source_hsolver/diago_dav_subspace.cpp 
b/source/source_hsolver/diago_dav_subspace.cpp index 27c6a5b348..4ff93d03e9 100644 --- a/source/source_hsolver/diago_dav_subspace.cpp +++ b/source/source_hsolver/diago_dav_subspace.cpp @@ -36,7 +36,7 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), diag_comm(diag_comm_in), diag_subspace(diag_subspace_in), diago_subspace_bs(diago_subspace_bs_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->one = &one_; this->zero = &zero_; diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index ef4ba67cf3..49d5d0d953 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -20,7 +20,7 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, const diag_comm_info& diag_comm_in) : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->precondition = precondition_in; this->one = &one_; diff --git a/source/source_hsolver/diago_iter_assist.cpp b/source/source_hsolver/diago_iter_assist.cpp index fb87ad2350..8c5673c37a 100644 --- a/source/source_hsolver/diago_iter_assist.cpp +++ b/source/source_hsolver/diago_iter_assist.cpp @@ -400,14 +400,14 @@ void DiagoIterAssist::diag_heevx(const int matrix_size, // (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors); heevx_op()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { #if ((defined __CUDA) || (defined __ROCM)) // eigenvalues to e, from device to host syncmem_var_d2h_op()(e, eigenvalues, 
num_eigenpairs); #endif } - else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { // eigenvalues to e syncmem_var_op()(e, eigenvalues, num_eigenpairs); @@ -436,14 +436,14 @@ void DiagoIterAssist::diag_hegvd(const int nstart, hegvd_op()(ctx, nstart, ldh, hcc, scc, eigenvalues, vcc); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { #if ((defined __CUDA) || (defined __ROCM)) // set eigenvalues in GPU to e in CPU syncmem_var_d2h_op()(e, eigenvalues, nbands); #endif } - else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { // set eigenvalues in CPU to e in CPU syncmem_var_op()(e, eigenvalues, nbands); diff --git a/source/source_hsolver/test/hsolver_pw_sup.h b/source/source_hsolver/test/hsolver_pw_sup.h index a5aab01735..5f5108c627 100644 --- a/source/source_hsolver/test/hsolver_pw_sup.h +++ b/source/source_hsolver/test/hsolver_pw_sup.h @@ -126,7 +126,7 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, const bool use_paw_in, const diag_comm_info& diag_comm_in) : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->precondition = precondition_in; test_david = 2; diff --git a/source/source_pw/module_pwdft/forces.cpp b/source/source_pw/module_pwdft/forces.cpp index a6894c49ca..3e58b737ba 100644 --- a/source/source_pw/module_pwdft/forces.cpp +++ b/source/source_pw/module_pwdft/forces.cpp @@ -38,7 +38,7 @@ void Forces::cal_force(UnitCell& ucell, { ModuleBase::timer::tick("Forces", "cal_force"); ModuleBase::TITLE("Forces", "init"); - this->device = base_device::get_device_type(this->ctx); + this->device = 
base_device::get_device_type(this->ctx); const ModuleBase::matrix& wg = elec.wg; const ModuleBase::matrix& ekb = elec.ekb; const Charge* const chr = elec.charge; @@ -331,7 +331,7 @@ void Forces::cal_force_loc(const UnitCell& ucell, { ModuleBase::TITLE("Forces", "cal_force_loc"); ModuleBase::timer::tick("Forces", "cal_force_loc"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); std::complex* aux = new std::complex[rho_basis->nmaxgr]; // now, in all pools , the charge are the same, // so, the force calculated by each pool is equal. @@ -478,7 +478,7 @@ void Forces::cal_force_ew(const UnitCell& ucell, { ModuleBase::TITLE("Forces", "cal_force_ew"); ModuleBase::timer::tick("Forces", "cal_force_ew"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); double fact = 2.0; std::vector> aux(rho_basis->npw); diff --git a/source/source_pw/module_pwdft/forces_cc.cpp b/source/source_pw/module_pwdft/forces_cc.cpp index 7788ed2af5..917e00b83c 100644 --- a/source/source_pw/module_pwdft/forces_cc.cpp +++ b/source/source_pw/module_pwdft/forces_cc.cpp @@ -116,7 +116,7 @@ void Forces::cal_force_cc(ModuleBase::matrix& forcecc, double *force_d = nullptr; double *rhocgigg_vec_d = nullptr; std::complex* psiv_d = nullptr; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); for (int ig = 0; ig < rho_basis->npw; ig++) @@ -258,7 +258,7 @@ void Forces::deriv_drhoc double gx = 0, rhocg1 = 0; //double *aux = new double[mesh]; std::vector aux(mesh); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // the modulus of g for a given shell // the fourier transform // auxiliary memory for integration diff --git a/source/source_pw/module_pwdft/forces_scc.cpp b/source/source_pw/module_pwdft/forces_scc.cpp index 7134232416..5e00b87dca 100644 --- 
a/source/source_pw/module_pwdft/forces_scc.cpp +++ b/source/source_pw/module_pwdft/forces_scc.cpp @@ -152,7 +152,7 @@ void Forces::deriv_drhoc_scc(const bool& numeric, int igl0 = 0; double gx = 0; double rhocg1 = 0; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); /// the modulus of g for a given shell /// the fourier transform /// auxiliary memory for integration diff --git a/source/source_pw/module_pwdft/fs_kin_tools.cpp b/source/source_pw/module_pwdft/fs_kin_tools.cpp index 853ae34abd..0c04b26f2a 100644 --- a/source/source_pw/module_pwdft/fs_kin_tools.cpp +++ b/source/source_pw/module_pwdft/fs_kin_tools.cpp @@ -10,7 +10,7 @@ FS_Kin_tools::FS_Kin_tools(const UnitCell& ucell_in, const ModuleBase::matrix& wg) : ucell_(ucell_in), nksbands_(wg.nc), wg(wg.c), wk(p_kv->wk.data()) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->wfc_basis_ = wfc_basis_in; const int npwk_max = this->wfc_basis_->npwk_max; const int nks = this->wfc_basis_->nks; diff --git a/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp b/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp index 934d3c476d..64f4015daf 100644 --- a/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp +++ b/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp @@ -26,7 +26,7 @@ FS_Nonlocal_tools::FS_Nonlocal_tools(const pseudopot_cell_vnl* n : nlpp_(nlpp_in), ucell_(ucell_in), kv_(kv_in), wfc_basis_(wfc_basis_in), sf_(sf_in) { // get the device context - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nkb = nlpp_->nkb; this->max_npw = wfc_basis_->npwk_max; this->ntype = ucell_->ntype; diff --git a/source/source_pw/module_pwdft/nonlocal_maths.hpp b/source/source_pw/module_pwdft/nonlocal_maths.hpp index 3e09675bcb..3a8b133cb8 100644 --- a/source/source_pw/module_pwdft/nonlocal_maths.hpp +++ 
b/source/source_pw/module_pwdft/nonlocal_maths.hpp @@ -18,14 +18,14 @@ class Nonlocal_maths public: Nonlocal_maths(const pseudopot_cell_vnl* nlpp_in, const UnitCell* ucell_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nhtol_ = nlpp_in->nhtol; this->lmax_ = nlpp_in->lmaxkb; this->ucell_ = ucell_in; } Nonlocal_maths(const ModuleBase::matrix& nhtol, const int lmax, const UnitCell* ucell_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nhtol_ = nhtol; this->lmax_ = lmax; this->ucell_ = ucell_in; diff --git a/source/source_pw/module_pwdft/onsite_proj_tools.cpp b/source/source_pw/module_pwdft/onsite_proj_tools.cpp index 509a65a6ab..aa6ed0f83f 100644 --- a/source/source_pw/module_pwdft/onsite_proj_tools.cpp +++ b/source/source_pw/module_pwdft/onsite_proj_tools.cpp @@ -24,7 +24,7 @@ Onsite_Proj_tools::Onsite_Proj_tools(const pseudopot_cell_vnl* n : nlpp_(nlpp_in), ucell_(ucell_in), psi_(psi_in), kv_(kv_in), wfc_basis_(wfc_basis_in), sf_(sf_in) { // get the device context - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // seems kvec_c never used... 
this->kvec_c = this->wfc_basis_->template get_kvec_c_data(); @@ -126,7 +126,7 @@ Onsite_Proj_tools::Onsite_Proj_tools( wfc_basis_ = wfc_basis_in; sf_ = sf_in; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->kvec_c = this->wfc_basis_->template get_kvec_c_data(); // skip deeq, qq_nt diff --git a/source/source_pw/module_pwdft/onsite_projector.cpp b/source/source_pw/module_pwdft/onsite_projector.cpp index f9b8ad7cbc..4353700d65 100644 --- a/source/source_pw/module_pwdft/onsite_projector.cpp +++ b/source/source_pw/module_pwdft/onsite_projector.cpp @@ -104,7 +104,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir const ModuleBase::matrix& wg, const ModuleBase::matrix& ekb) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if(!this->initialed) { diff --git a/source/source_pw/module_pwdft/op_pw_ekin.cpp b/source/source_pw/module_pwdft/op_pw_ekin.cpp index 05d28266fd..9c62204050 100644 --- a/source/source_pw/module_pwdft/op_pw_ekin.cpp +++ b/source/source_pw/module_pwdft/op_pw_ekin.cpp @@ -18,7 +18,7 @@ Ekinetic>::Ekinetic( this->gk2 = gk2_in; this->gk2_row = gk2_row; this->gk2_col = gk2_col; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) { ModuleBase::WARNING_QUIT("EkineticPW", "Constuctor of Operator::EkineticPW is failed, please check your code!"); @@ -67,7 +67,7 @@ hamilt::Ekinetic>::Ekinetic(const Ekineticgk2 = ekinetic->get_gk2(); this->gk2_row = ekinetic->get_gk2_row(); this->gk2_col = ekinetic->get_gk2_col(); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) { ModuleBase::WARNING_QUIT("EkineticPW", "Copy Constuctor of Operator::EkineticPW is failed, please check your 
code!"); } diff --git a/source/source_pw/module_pwdft/stress_cc.cpp b/source/source_pw/module_pwdft/stress_cc.cpp index 211d5a4bda..0607371074 100644 --- a/source/source_pw/module_pwdft/stress_cc.cpp +++ b/source/source_pw/module_pwdft/stress_cc.cpp @@ -230,7 +230,7 @@ void Stress_Func::deriv_drhoc double gx = 0.0; double rhocg1 = 0.0; std::vector aux(mesh); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // the modulus of g for a given shell // the fourier transform diff --git a/source/source_pw/module_pwdft/stress_loc.cpp b/source/source_pw/module_pwdft/stress_loc.cpp index 0b932afcdb..f6cac47604 100644 --- a/source/source_pw/module_pwdft/stress_loc.cpp +++ b/source/source_pw/module_pwdft/stress_loc.cpp @@ -189,7 +189,7 @@ const UnitCell& ucell_in int igl0 = 0; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); std::vector gx_arr(rho_basis->ngg+1); double* gx_arr_d = nullptr; diff --git a/source/source_pw/module_pwdft/structure_factor_k.cpp b/source/source_pw/module_pwdft/structure_factor_k.cpp index 52dc326545..3ca4980c58 100644 --- a/source/source_pw/module_pwdft/structure_factor_k.cpp +++ b/source/source_pw/module_pwdft/structure_factor_k.cpp @@ -57,7 +57,7 @@ void Structure_Factor::get_sk(Device* ctx, ModuleBase::timer::tick("Structure_Factor", "get_sk"); base_device::DEVICE_CPU* cpu_ctx = {}; - base_device::AbacusDevice_t device = base_device::get_device_type(ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(ctx); using cal_sk_op = hamilt::cal_sk_op; using resmem_int_op = base_device::memory::resize_memory_op; using delmem_int_op = base_device::memory::delete_memory_op; diff --git a/source/source_pw/module_stodft/sto_forces.cpp b/source/source_pw/module_stodft/sto_forces.cpp index 4e57ae98c7..e349d3de2c 100644 --- a/source/source_pw/module_stodft/sto_forces.cpp +++ 
b/source/source_pw/module_stodft/sto_forces.cpp @@ -31,7 +31,7 @@ void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, { ModuleBase::timer::tick("Sto_Forces", "cal_force"); ModuleBase::TITLE("Sto_Forces", "init"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); const ModuleBase::matrix& wg = elec.wg; const Charge* chr = elec.charge; force.create(this->nat, 3); diff --git a/source/source_pw/module_stodft/sto_wf.cpp b/source/source_pw/module_stodft/sto_wf.cpp index 2ba8db2908..a0204e1f87 100644 --- a/source/source_pw/module_stodft/sto_wf.cpp +++ b/source/source_pw/module_stodft/sto_wf.cpp @@ -19,7 +19,7 @@ Stochastic_WF::~Stochastic_WF() { delete chi0_cpu; Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { delete chi0; } @@ -119,7 +119,7 @@ void Stochastic_WF::allocate_chi0() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -248,7 +248,7 @@ void Stochastic_WF::init_com_orbitals() delete[] totnpw; // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -280,7 +280,7 @@ void Stochastic_WF::init_com_orbitals() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -370,7 +370,7 @@ template void Stochastic_WF::sync_chi0() { Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if 
(base_device::get_device_type(ctx) == base_device::GpuDevice) { syncmem_h2d_op()(this->chi0->get_pointer(), this->chi0_cpu->get_pointer(), From f74b4b08e17cd0c1c575d3e52a78875ce30e380c Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 08:56:39 +0800 Subject: [PATCH 15/18] refactor(esolver): remove device member variable from ESolver_KS_PW - Modify copy_d2h to accept ctx parameter and call get_device_type internally - Remove device parameter from ctrl_scf_pw function - Remove device member variable from ESolver_KS_PW class - Simplify function interfaces by using automatic template deduction --- source/source_esolver/esolver_ks_pw.cpp | 3 +-- source/source_esolver/esolver_ks_pw.h | 3 --- source/source_io/module_ctrl/ctrl_output_pw.cpp | 7 +------ source/source_io/module_ctrl/ctrl_output_pw.h | 1 - source/source_psi/setup_psi_pw.cpp | 4 ++-- source/source_psi/setup_psi_pw.h | 2 +- 6 files changed, 5 insertions(+), 15 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index a876039457..6506432352 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,7 +43,6 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; - this->device = base_device::get_device_type(this->ctx); } template @@ -251,7 +250,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->stp, - this->ctx, this->device, this->Pgrid, PARAM.inp); + this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 6a6be52b73..323e2df5a2 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -60,9 +60,6 @@ class 
ESolver_KS_PW : public ESolver_KS // for get_pchg and get_wf, use ctx as input of fft Device* ctx = {}; - // for device to host data transformation - base_device::AbacusDevice_t device = {}; - }; } // namespace ModuleESolver #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 2f0c157a82..8e2f0c3918 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -92,7 +92,6 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -100,7 +99,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); // Transfer data from device (GPU) to host (CPU) in pw basis - stp.copy_d2h(device); + stp.copy_d2h(ctx); //---------------------------------------------------------- //! 
4) Compute density of states (DOS) @@ -386,7 +385,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -403,7 +401,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -421,7 +418,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -438,7 +434,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 798629c55e..3ac7a2ab9c 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -31,7 +31,6 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, - const base_device::AbacusDevice_t &device, // mohan add 2025-10-15 const Parallel_Grid ¶_grid, const Input_para& inp); diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 14e564c4fb..c7428cfd7d 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -62,9 +62,9 @@ void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) // 
Transfer data from GPU to CPU in pw basis template -void Setup_Psi_pw::copy_d2h(const base_device::AbacusDevice_t &device) +void Setup_Psi_pw::copy_d2h(const Device* ctx) { - if (device == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 13bf593f37..1e79664e2b 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -59,7 +59,7 @@ class Setup_Psi_pw void update_psi_d(); // Transfer data from device to host in pw basis - void copy_d2h(const base_device::AbacusDevice_t &device); + void copy_d2h(const Device* ctx); void clean(); From dc9450ae713a43e0bd6cb48e5437b124cef5ef51 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 09:03:01 +0800 Subject: [PATCH 16/18] style(esolver): explicitly initialize ctx to nullptr in constructor --- source/source_esolver/esolver_ks_pw.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 6506432352..0ea4c31d1d 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,6 +43,7 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; + this->ctx = nullptr; } template From 3b108eb60472867871aee91ae911fe95e16a5800 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 14:16:15 +0800 Subject: [PATCH 17/18] feat(device): add runtime device type support to DeviceContext - Add device_type_ member variable to DeviceContext class - Add set_device_type() and get_device_type() methods - Add is_cpu(), is_gpu(), is_dsp() convenience methods - Add get_device_type(const DeviceContext*) overload for runtime device type query - Maintain backward 
compatibility with existing template-based get_device_type --- source/source_base/module_device/device.h | 41 +++++++++++++++++++ .../module_device/device_helpers.h | 12 +++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index afa55d5d6e..395d8c470d 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -145,6 +145,36 @@ class DeviceContext { */ int get_local_rank() const { return local_rank_; } + /** + * @brief Set the device type (CpuDevice, GpuDevice, or DspDevice) + * @param type The device type + */ + void set_device_type(AbacusDevice_t type) { device_type_ = type; } + + /** + * @brief Get the device type + * @return AbacusDevice_t The device type + */ + AbacusDevice_t get_device_type() const { return device_type_; } + + /** + * @brief Check if the device is CPU + * @return true if the device is CPU + */ + bool is_cpu() const { return device_type_ == CpuDevice; } + + /** + * @brief Check if the device is GPU + * @return true if the device is GPU + */ + bool is_gpu() const { return device_type_ == GpuDevice; } + + /** + * @brief Check if the device is DSP + * @return true if the device is DSP + */ + bool is_dsp() const { return device_type_ == DspDevice; } + // Disable copy and assignment DeviceContext(const DeviceContext&) = delete; DeviceContext& operator=(const DeviceContext&) = delete; @@ -158,10 +188,21 @@ class DeviceContext { int device_id_ = -1; int device_count_ = 0; int local_rank_ = 0; + AbacusDevice_t device_type_ = CpuDevice; std::mutex init_mutex_; }; +/** + * @brief Get the device type enum from DeviceContext (runtime version). 
+ * @param ctx Pointer to DeviceContext + * @return AbacusDevice_t enum value + */ +inline AbacusDevice_t get_device_type(const DeviceContext* ctx) +{ + return ctx->get_device_type(); +} + } // end of namespace base_device #endif // MODULE_DEVICE_H_ diff --git a/source/source_base/module_device/device_helpers.h b/source/source_base/module_device/device_helpers.h index 6aa71938de..2870eea2d7 100644 --- a/source/source_base/module_device/device_helpers.h +++ b/source/source_base/module_device/device_helpers.h @@ -18,8 +18,18 @@ namespace base_device { +// Forward declaration +class DeviceContext; + +/** + * @brief Get the device type enum from DeviceContext (runtime version). + * @param ctx Pointer to DeviceContext + * @return AbacusDevice_t enum value + */ +inline AbacusDevice_t get_device_type(const DeviceContext* ctx); + /** - * @brief Get the device type enum for a given device type. + * @brief Get the device type enum for a given device type (compile-time version). * @tparam Device The device type (DEVICE_CPU or DEVICE_GPU) * @param dev Pointer to device (used for template deduction) * @return AbacusDevice_t enum value From 77d081f7f3b3c25b10f4ab493035e1c21d965be9 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 12 Mar 2026 16:23:35 +0800 Subject: [PATCH 18/18] feat(device): add runtime device context overloads for gradual migration - Add copy_d2h(const DeviceContext*) overload to Setup_Psi_pw - Add ctrl_scf_pw(..., const DeviceContext*, ...) overload - Add ctrl_runner_pw(..., const DeviceContext*, ...) 
overload - Keep original functions for backward compatibility - Replace tabs with spaces in modified files --- source/source_io/module_ctrl/ctrl_output_pw.h | 59 +++++++++++++++---- source/source_psi/setup_psi_pw.cpp | 35 ++++++++--- source/source_psi/setup_psi_pw.h | 15 +++-- 3 files changed, 81 insertions(+), 28 deletions(-) diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 3ac7a2ab9c..262fc782a8 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -11,11 +11,11 @@ namespace ModuleIO // print out information in 'iter_finish' in ESolver_KS_PW void ctrl_iter_pw(const int istep, - const int iter, - const double &conv_esolver, - psi::Psi, base_device::DEVICE_CPU>* psi, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, const Input_para& inp); // print out information in 'after_scf' in ESolver_KS_PW @@ -24,32 +24,65 @@ void ctrl_scf_pw(const int istep, UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, - const ModulePW::PW_Basis *pw_rho, - const ModulePW::PW_Basis *pw_rhod, - const ModulePW::PW_Basis_Big *pw_big, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); +// print out information in 'after_scf' in ESolver_KS_PW (runtime version) +template +void ctrl_scf_pw(const int istep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const 
ModulePW::PW_Basis_Big *pw_big, + Setup_Psi_pw &stp, + const base_device::DeviceContext* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + // print out information in 'after_all_runners' in ESolver_KS_PW template void ctrl_runner_pw(UnitCell& ucell, - elecstate::ElecState* pelec, + elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, - Charge &chr, + Charge &chr, K_Vectors &kv, Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, - surchem &solvent, + surchem &solvent, const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp); +// print out information in 'after_all_runners' in ESolver_KS_PW (runtime version) +template +void ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + Setup_Psi_pw &stp, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DeviceContext* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + } #endif diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index c7428cfd7d..11b02b8512 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -9,12 +9,12 @@ Setup_Psi_pw::~Setup_Psi_pw(){} template void Setup_Psi_pw::before_runner( - const UnitCell &ucell, - const K_Vectors &kv, - const Structure_Factor &sf, - const ModulePW::PW_Basis_K &pw_wfc, - const pseudopot_cell_vnl &ppcell, - const Input_para &inp) + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp) { //! 
Allocate and initialize psi this->p_psi_init = new psi::PSIPrepare(inp.init_wfc, @@ -70,10 +70,27 @@ void Setup_Psi_pw::copy_d2h(const Device* ctx) this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), this->psi_cpu[0].size()); } - else - { + else + { + // do nothing + } + return; +} + +// Transfer data from GPU to CPU in pw basis (runtime version) +template +void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) +{ + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + this->psi_cpu[0].size()); + } + else + { // do nothing - } + } return; } diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 1e79664e2b..6e7a42467d 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -47,12 +47,12 @@ class Setup_Psi_pw //------------ void before_runner( - const UnitCell &ucell, - const K_Vectors &kv, - const Structure_Factor &sf, - const ModulePW::PW_Basis_K &pw_wfc, - const pseudopot_cell_vnl &ppcell, - const Input_para &inp); + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp); void init(hamilt::Hamilt* p_hamilt); @@ -60,6 +60,9 @@ class Setup_Psi_pw // Transfer data from device to host in pw basis void copy_d2h(const Device* ctx); + + // Transfer data from device to host in pw basis (runtime version) + void copy_d2h(const base_device::DeviceContext* ctx); void clean();