diff --git a/source/source_base/math_chebyshev.cpp b/source/source_base/math_chebyshev.cpp index 8a84686ea5..b7e59a89f9 100644 --- a/source/source_base/math_chebyshev.cpp +++ b/source/source_base/math_chebyshev.cpp @@ -61,7 +61,7 @@ Chebyshev::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord } coefr_cpu = new REAL[norder]; coefc_cpu = new std::complex[norder]; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { resmem_var_op()(this->coef_real, norder); resmem_complex_op()(this->coef_complex, norder); @@ -82,7 +82,7 @@ template Chebyshev::~Chebyshev() { delete[] polytrace; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { delmem_var_op()(this->coef_real); delmem_complex_op()(this->coef_complex); @@ -209,7 +209,7 @@ void Chebyshev::calcoef_real(std::function fun) } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_var_h2d_op()(coef_real, coefr_cpu, norder); } @@ -299,7 +299,7 @@ void Chebyshev::calcoef_complex(std::function(s coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3); } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder); } @@ -390,7 +390,7 @@ void Chebyshev::calcoef_pair(std::function fun1, std:: } } - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder); } @@ -684,7 +684,7 @@ bool Chebyshev::checkconverge( funA(arrayn_1, arrayn, 1); REAL sum1, sum2; REAL t; - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { sum1 = this->ddot_real(arrayn_1, arrayn_1, N); sum2 = this->ddot_real(arrayn_1, arrayn, N); @@ -714,7 +714,7 @@ bool Chebyshev::checkconverge( for (int ior = 2; ior < norder; ++ior) { funA(arrayn, arraynp1, 1); - if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) + if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { sum1 = this->ddot_real(arrayn, arrayn, N); sum2 = this->ddot_real(arrayn, arraynp1, N); diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index afa55d5d6e..395d8c470d 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -145,6 +145,36 @@ class DeviceContext { */ int get_local_rank() const { return local_rank_; } + /** + * @brief Set the device type (CpuDevice, GpuDevice, or DspDevice) + * @param type The device type + */ + void set_device_type(AbacusDevice_t type) { device_type_ = type; } + + /** + * @brief Get the device type + * @return AbacusDevice_t The device type + */ + AbacusDevice_t get_device_type() const { return device_type_; } + + /** + * @brief Check if the device is CPU + * @return true if the device is CPU + */ + bool is_cpu() const { return device_type_ == CpuDevice; } + + /** + * @brief Check if the device is GPU + * @return true if the device is GPU + */ + bool is_gpu() const { return device_type_ == GpuDevice; } + + /** + * @brief Check if the device is DSP + * @return true if the device is DSP + */ + bool is_dsp() const { return device_type_ == DspDevice; } + // Disable copy and assignment DeviceContext(const DeviceContext&) = delete; DeviceContext& operator=(const DeviceContext&) = delete; @@ -158,10 +188,21 @@ class DeviceContext { int device_id_ = -1; int device_count_ = 0; int local_rank_ = 0; + AbacusDevice_t device_type_ = CpuDevice; std::mutex init_mutex_; }; +/** + * @brief Get the device type enum from DeviceContext (runtime version). + * @param ctx Pointer to DeviceContext + * @return AbacusDevice_t enum value + */ +inline AbacusDevice_t get_device_type(const DeviceContext* ctx) +{ + return ctx->get_device_type(); +} + } // end of namespace base_device #endif // MODULE_DEVICE_H_ diff --git a/source/source_base/module_device/device_helpers.cpp b/source/source_base/module_device/device_helpers.cpp index 1c53020718..0b5d5a1693 100644 --- a/source/source_base/module_device/device_helpers.cpp +++ b/source/source_base/module_device/device_helpers.cpp @@ -3,19 +3,6 @@ namespace base_device { -// Device type specializations -template <> -AbacusDevice_t get_device_type(const DEVICE_CPU* dev) -{ - return CpuDevice; -} - -template <> -AbacusDevice_t get_device_type(const DEVICE_GPU* dev) -{ - return GpuDevice; -} - // Precision specializations template <> std::string get_current_precision(const float* var) diff --git a/source/source_base/module_device/device_helpers.h b/source/source_base/module_device/device_helpers.h index 60eddd888d..2870eea2d7 100644 --- a/source/source_base/module_device/device_helpers.h +++ b/source/source_base/module_device/device_helpers.h @@ -13,25 +13,35 @@ #include "types.h" #include #include +#include namespace base_device { +// Forward declaration +class DeviceContext; + +/** + * @brief Get the device type enum from DeviceContext (runtime version). + * @param ctx Pointer to DeviceContext + * @return AbacusDevice_t enum value + */ +inline AbacusDevice_t get_device_type(const DeviceContext* ctx); + /** - * @brief Get the device type enum for a given device type. + * @brief Get the device type enum for a given device type (compile-time version). * @tparam Device The device type (DEVICE_CPU or DEVICE_GPU) * @param dev Pointer to device (used for template deduction) * @return AbacusDevice_t enum value */ template -AbacusDevice_t get_device_type(const Device* dev); - -// Template specialization declarations -template <> -AbacusDevice_t get_device_type(const DEVICE_CPU* dev); - -template <> -AbacusDevice_t get_device_type(const DEVICE_GPU* dev); +AbacusDevice_t get_device_type(const Device* dev) +{ + if (std::is_same::value) return CpuDevice; + else if (std::is_same::value) return GpuDevice; + else if (std::is_same::value) return DspDevice; + else return UnKnown; +} /** * @brief Get the precision string for a given numeric type. diff --git a/source/source_base/module_device/test/device_test.cpp b/source/source_base/module_device/test/device_test.cpp index 02d485c8ef..faf083c721 100644 --- a/source/source_base/module_device/test/device_test.cpp +++ b/source/source_base/module_device/test/device_test.cpp @@ -20,14 +20,14 @@ class TestModulePsiDevice : public ::testing::Test TEST_F(TestModulePsiDevice, get_device_type_cpu) { - base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx); EXPECT_EQ(device, base_device::CpuDevice); } #if __UT_USE_CUDA || __UT_USE_ROCM TEST_F(TestModulePsiDevice, get_device_type_gpu) { - base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx); EXPECT_EQ(device, base_device::GpuDevice); } #endif // __UT_USE_CUDA || __UT_USE_ROCM diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 74507e09c7..0ea4c31d1d 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -43,7 +43,7 @@ ESolver_KS_PW::ESolver_KS_PW() { this->classname = "ESolver_KS_PW"; this->basisname = "PW"; - this->device = base_device::get_device_type(this->ctx); + this->ctx = nullptr; } template @@ -251,7 +251,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->stp, - this->ctx, this->device, this->Pgrid, PARAM.inp); + this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 6a6be52b73..323e2df5a2 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -60,9 +60,6 @@ class ESolver_KS_PW : public ESolver_KS // for get_pchg and get_wf, use ctx as input of fft Device* ctx = {}; - // for device to host data transformation - base_device::AbacusDevice_t device = {}; - }; } // namespace ModuleESolver #endif diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp index 27c6a5b348..4ff93d03e9 100644 --- a/source/source_hsolver/diago_dav_subspace.cpp +++ b/source/source_hsolver/diago_dav_subspace.cpp @@ -36,7 +36,7 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), diag_comm(diag_comm_in), diag_subspace(diag_subspace_in), diago_subspace_bs(diago_subspace_bs_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->one = &one_; this->zero = &zero_; diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index ef4ba67cf3..49d5d0d953 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -20,7 +20,7 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, const diag_comm_info& diag_comm_in) : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->precondition = precondition_in; this->one = &one_; diff --git a/source/source_hsolver/diago_iter_assist.cpp b/source/source_hsolver/diago_iter_assist.cpp index fb87ad2350..8c5673c37a 100644 --- a/source/source_hsolver/diago_iter_assist.cpp +++ b/source/source_hsolver/diago_iter_assist.cpp @@ -400,14 +400,14 @@ void DiagoIterAssist::diag_heevx(const int matrix_size, // (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors); heevx_op()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { #if ((defined __CUDA) || (defined __ROCM)) // eigenvalues to e, from device to host syncmem_var_d2h_op()(e, eigenvalues, num_eigenpairs); #endif } - else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { // eigenvalues to e syncmem_var_op()(e, eigenvalues, num_eigenpairs); @@ -436,14 +436,14 @@ void DiagoIterAssist::diag_hegvd(const int nstart, hegvd_op()(ctx, nstart, ldh, hcc, scc, eigenvalues, vcc); - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { #if ((defined __CUDA) || (defined __ROCM)) // set eigenvalues in GPU to e in CPU syncmem_var_d2h_op()(e, eigenvalues, nbands); #endif } - else if (base_device::get_device_type(ctx) == base_device::CpuDevice) + else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { // set eigenvalues in CPU to e in CPU syncmem_var_op()(e, eigenvalues, nbands); diff --git a/source/source_hsolver/test/hsolver_pw_sup.h b/source/source_hsolver/test/hsolver_pw_sup.h index a5aab01735..5f5108c627 100644 --- a/source/source_hsolver/test/hsolver_pw_sup.h +++ b/source/source_hsolver/test/hsolver_pw_sup.h @@ -126,7 +126,7 @@ DiagoDavid::DiagoDavid(const Real* precondition_in, const bool use_paw_in, const diag_comm_info& diag_comm_in) : nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->precondition = precondition_in; test_david = 2; diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index 2f0c157a82..8e2f0c3918 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -92,7 +92,6 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -100,7 +99,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); // Transfer data from device (GPU) to host (CPU) in pw basis - stp.copy_d2h(device); + stp.copy_d2h(ctx); //---------------------------------------------------------- //! 4) Compute density of states (DOS) @@ -386,7 +385,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -403,7 +401,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -421,7 +418,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -438,7 +434,6 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, - const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); #endif diff --git a/source/source_io/module_ctrl/ctrl_output_pw.h b/source/source_io/module_ctrl/ctrl_output_pw.h index 798629c55e..262fc782a8 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.h +++ b/source/source_io/module_ctrl/ctrl_output_pw.h @@ -11,11 +11,11 @@ namespace ModuleIO // print out information in 'iter_finish' in ESolver_KS_PW void ctrl_iter_pw(const int istep, - const int iter, - const double &conv_esolver, - psi::Psi, base_device::DEVICE_CPU>* psi, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, const Input_para& inp); // print out information in 'after_scf' in ESolver_KS_PW @@ -24,33 +24,65 @@ void ctrl_scf_pw(const int istep, UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, - const K_Vectors &kv, - const ModulePW::PW_Basis_K *pw_wfc, - const ModulePW::PW_Basis *pw_rho, - const ModulePW::PW_Basis *pw_rhod, - const ModulePW::PW_Basis_Big *pw_big, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, Setup_Psi_pw &stp, const Device* ctx, - const base_device::AbacusDevice_t &device, // mohan add 2025-10-15 + const Parallel_Grid ¶_grid, + const Input_para& inp); + +// print out information in 'after_scf' in ESolver_KS_PW (runtime version) +template +void ctrl_scf_pw(const int istep, + UnitCell& ucell, + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + Setup_Psi_pw &stp, + const base_device::DeviceContext* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); // print out information in 'after_all_runners' in ESolver_KS_PW template void ctrl_runner_pw(UnitCell& ucell, - elecstate::ElecState* pelec, + elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, - Charge &chr, + Charge &chr, K_Vectors &kv, Setup_Psi_pw &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, - surchem &solvent, + surchem &solvent, const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp); +// print out information in 'after_all_runners' in ESolver_KS_PW (runtime version) +template +void ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + K_Vectors &kv, + Setup_Psi_pw &stp, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, + const base_device::DeviceContext* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + } #endif diff --git a/source/source_psi/setup_psi_pw.cpp b/source/source_psi/setup_psi_pw.cpp index 14e564c4fb..11b02b8512 100644 --- a/source/source_psi/setup_psi_pw.cpp +++ b/source/source_psi/setup_psi_pw.cpp @@ -9,12 +9,12 @@ Setup_Psi_pw::~Setup_Psi_pw(){} template void Setup_Psi_pw::before_runner( - const UnitCell &ucell, - const K_Vectors &kv, - const Structure_Factor &sf, - const ModulePW::PW_Basis_K &pw_wfc, - const pseudopot_cell_vnl &ppcell, - const Input_para &inp) + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp) { //! Allocate and initialize psi this->p_psi_init = new psi::PSIPrepare(inp.init_wfc, @@ -62,18 +62,35 @@ void Setup_Psi_pw::init(hamilt::Hamilt* p_hamilt) // Transfer data from GPU to CPU in pw basis template -void Setup_Psi_pw::copy_d2h(const base_device::AbacusDevice_t &device) +void Setup_Psi_pw::copy_d2h(const Device* ctx) { - if (device == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), this->psi_cpu[0].size()); } - else - { + else + { // do nothing - } + } + return; +} + +// Transfer data from GPU to CPU in pw basis (runtime version) +template +void Setup_Psi_pw::copy_d2h(const base_device::DeviceContext* ctx) +{ + if (base_device::get_device_type(ctx) == base_device::GpuDevice) + { + castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + this->psi_cpu[0].size()); + } + else + { + // do nothing + } return; } diff --git a/source/source_psi/setup_psi_pw.h b/source/source_psi/setup_psi_pw.h index 13bf593f37..6e7a42467d 100644 --- a/source/source_psi/setup_psi_pw.h +++ b/source/source_psi/setup_psi_pw.h @@ -47,19 +47,22 @@ class Setup_Psi_pw //------------ void before_runner( - const UnitCell &ucell, - const K_Vectors &kv, - const Structure_Factor &sf, - const ModulePW::PW_Basis_K &pw_wfc, - const pseudopot_cell_vnl &ppcell, - const Input_para &inp); + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp); void init(hamilt::Hamilt* p_hamilt); void update_psi_d(); // Transfer data from device to host in pw basis - void copy_d2h(const base_device::AbacusDevice_t &device); + void copy_d2h(const Device* ctx); + + // Transfer data from device to host in pw basis (runtime version) + void copy_d2h(const base_device::DeviceContext* ctx); void clean(); diff --git a/source/source_pw/module_pwdft/forces.cpp b/source/source_pw/module_pwdft/forces.cpp index a6894c49ca..3e58b737ba 100644 --- a/source/source_pw/module_pwdft/forces.cpp +++ b/source/source_pw/module_pwdft/forces.cpp @@ -38,7 +38,7 @@ void Forces::cal_force(UnitCell& ucell, { ModuleBase::timer::tick("Forces", "cal_force"); ModuleBase::TITLE("Forces", "init"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); const ModuleBase::matrix& wg = elec.wg; const ModuleBase::matrix& ekb = elec.ekb; const Charge* const chr = elec.charge; @@ -331,7 +331,7 @@ void Forces::cal_force_loc(const UnitCell& ucell, { ModuleBase::TITLE("Forces", "cal_force_loc"); ModuleBase::timer::tick("Forces", "cal_force_loc"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); std::complex* aux = new std::complex[rho_basis->nmaxgr]; // now, in all pools , the charge are the same, // so, the force calculated by each pool is equal. @@ -478,7 +478,7 @@ void Forces::cal_force_ew(const UnitCell& ucell, { ModuleBase::TITLE("Forces", "cal_force_ew"); ModuleBase::timer::tick("Forces", "cal_force_ew"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); double fact = 2.0; std::vector> aux(rho_basis->npw); diff --git a/source/source_pw/module_pwdft/forces_cc.cpp b/source/source_pw/module_pwdft/forces_cc.cpp index 7788ed2af5..917e00b83c 100644 --- a/source/source_pw/module_pwdft/forces_cc.cpp +++ b/source/source_pw/module_pwdft/forces_cc.cpp @@ -116,7 +116,7 @@ void Forces::cal_force_cc(ModuleBase::matrix& forcecc, double *force_d = nullptr; double *rhocgigg_vec_d = nullptr; std::complex* psiv_d = nullptr; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); for (int ig = 0; ig < rho_basis->npw; ig++) @@ -258,7 +258,7 @@ void Forces::deriv_drhoc double gx = 0, rhocg1 = 0; //double *aux = new double[mesh]; std::vector aux(mesh); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // the modulus of g for a given shell // the fourier transform // auxiliary memory for integration diff --git a/source/source_pw/module_pwdft/forces_scc.cpp b/source/source_pw/module_pwdft/forces_scc.cpp index 7134232416..5e00b87dca 100644 --- a/source/source_pw/module_pwdft/forces_scc.cpp +++ b/source/source_pw/module_pwdft/forces_scc.cpp @@ -152,7 +152,7 @@ void Forces::deriv_drhoc_scc(const bool& numeric, int igl0 = 0; double gx = 0; double rhocg1 = 0; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); /// the modulus of g for a given shell /// the fourier transform /// auxiliary memory for integration diff --git a/source/source_pw/module_pwdft/fs_kin_tools.cpp b/source/source_pw/module_pwdft/fs_kin_tools.cpp index 853ae34abd..0c04b26f2a 100644 --- a/source/source_pw/module_pwdft/fs_kin_tools.cpp +++ b/source/source_pw/module_pwdft/fs_kin_tools.cpp @@ -10,7 +10,7 @@ FS_Kin_tools::FS_Kin_tools(const UnitCell& ucell_in, const ModuleBase::matrix& wg) : ucell_(ucell_in), nksbands_(wg.nc), wg(wg.c), wk(p_kv->wk.data()) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->wfc_basis_ = wfc_basis_in; const int npwk_max = this->wfc_basis_->npwk_max; const int nks = this->wfc_basis_->nks; diff --git a/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp b/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp index 934d3c476d..64f4015daf 100644 --- a/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp +++ b/source/source_pw/module_pwdft/fs_nonlocal_tools.cpp @@ -26,7 +26,7 @@ FS_Nonlocal_tools::FS_Nonlocal_tools(const pseudopot_cell_vnl* n : nlpp_(nlpp_in), ucell_(ucell_in), kv_(kv_in), wfc_basis_(wfc_basis_in), sf_(sf_in) { // get the device context - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nkb = nlpp_->nkb; this->max_npw = wfc_basis_->npwk_max; this->ntype = ucell_->ntype; diff --git a/source/source_pw/module_pwdft/nonlocal_maths.hpp b/source/source_pw/module_pwdft/nonlocal_maths.hpp index 3e09675bcb..3a8b133cb8 100644 --- a/source/source_pw/module_pwdft/nonlocal_maths.hpp +++ b/source/source_pw/module_pwdft/nonlocal_maths.hpp @@ -18,14 +18,14 @@ class Nonlocal_maths public: Nonlocal_maths(const pseudopot_cell_vnl* nlpp_in, const UnitCell* ucell_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nhtol_ = nlpp_in->nhtol; this->lmax_ = nlpp_in->lmaxkb; this->ucell_ = ucell_in; } Nonlocal_maths(const ModuleBase::matrix& nhtol, const int lmax, const UnitCell* ucell_in) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->nhtol_ = nhtol; this->lmax_ = lmax; this->ucell_ = ucell_in; diff --git a/source/source_pw/module_pwdft/onsite_proj_tools.cpp b/source/source_pw/module_pwdft/onsite_proj_tools.cpp index 509a65a6ab..aa6ed0f83f 100644 --- a/source/source_pw/module_pwdft/onsite_proj_tools.cpp +++ b/source/source_pw/module_pwdft/onsite_proj_tools.cpp @@ -24,7 +24,7 @@ Onsite_Proj_tools::Onsite_Proj_tools(const pseudopot_cell_vnl* n : nlpp_(nlpp_in), ucell_(ucell_in), psi_(psi_in), kv_(kv_in), wfc_basis_(wfc_basis_in), sf_(sf_in) { // get the device context - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // seems kvec_c never used... this->kvec_c = this->wfc_basis_->template get_kvec_c_data(); @@ -126,7 +126,7 @@ Onsite_Proj_tools::Onsite_Proj_tools( wfc_basis_ = wfc_basis_in; sf_ = sf_in; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); this->kvec_c = this->wfc_basis_->template get_kvec_c_data(); // skip deeq, qq_nt diff --git a/source/source_pw/module_pwdft/onsite_projector.cpp b/source/source_pw/module_pwdft/onsite_projector.cpp index f9b8ad7cbc..4353700d65 100644 --- a/source/source_pw/module_pwdft/onsite_projector.cpp +++ b/source/source_pw/module_pwdft/onsite_projector.cpp @@ -104,7 +104,7 @@ void projectors::OnsiteProjector::init(const std::string& orbital_dir const ModuleBase::matrix& wg, const ModuleBase::matrix& ekb) { - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if(!this->initialed) { diff --git a/source/source_pw/module_pwdft/op_pw_ekin.cpp b/source/source_pw/module_pwdft/op_pw_ekin.cpp index 05d28266fd..9c62204050 100644 --- a/source/source_pw/module_pwdft/op_pw_ekin.cpp +++ b/source/source_pw/module_pwdft/op_pw_ekin.cpp @@ -18,7 +18,7 @@ Ekinetic>::Ekinetic( this->gk2 = gk2_in; this->gk2_row = gk2_row; this->gk2_col = gk2_col; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) { ModuleBase::WARNING_QUIT("EkineticPW", "Constuctor of Operator::EkineticPW is failed, please check your code!"); @@ -67,7 +67,7 @@ hamilt::Ekinetic>::Ekinetic(const Ekineticgk2 = ekinetic->get_gk2(); this->gk2_row = ekinetic->get_gk2_row(); this->gk2_col = ekinetic->get_gk2_col(); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); if( this->tpiba2 < 1e-10 || this->gk2 == nullptr) { ModuleBase::WARNING_QUIT("EkineticPW", "Copy Constuctor of Operator::EkineticPW is failed, please check your code!"); } diff --git a/source/source_pw/module_pwdft/stress_cc.cpp b/source/source_pw/module_pwdft/stress_cc.cpp index 211d5a4bda..0607371074 100644 --- a/source/source_pw/module_pwdft/stress_cc.cpp +++ b/source/source_pw/module_pwdft/stress_cc.cpp @@ -230,7 +230,7 @@ void Stress_Func::deriv_drhoc double gx = 0.0; double rhocg1 = 0.0; std::vector aux(mesh); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); // the modulus of g for a given shell // the fourier transform diff --git a/source/source_pw/module_pwdft/stress_loc.cpp b/source/source_pw/module_pwdft/stress_loc.cpp index 0b932afcdb..f6cac47604 100644 --- a/source/source_pw/module_pwdft/stress_loc.cpp +++ b/source/source_pw/module_pwdft/stress_loc.cpp @@ -189,7 +189,7 @@ const UnitCell& ucell_in int igl0 = 0; - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); std::vector gx_arr(rho_basis->ngg+1); double* gx_arr_d = nullptr; diff --git a/source/source_pw/module_pwdft/structure_factor_k.cpp b/source/source_pw/module_pwdft/structure_factor_k.cpp index 52dc326545..3ca4980c58 100644 --- a/source/source_pw/module_pwdft/structure_factor_k.cpp +++ b/source/source_pw/module_pwdft/structure_factor_k.cpp @@ -57,7 +57,7 @@ void Structure_Factor::get_sk(Device* ctx, ModuleBase::timer::tick("Structure_Factor", "get_sk"); base_device::DEVICE_CPU* cpu_ctx = {}; - base_device::AbacusDevice_t device = base_device::get_device_type(ctx); + base_device::AbacusDevice_t device = base_device::get_device_type(ctx); using cal_sk_op = hamilt::cal_sk_op; using resmem_int_op = base_device::memory::resize_memory_op; using delmem_int_op = base_device::memory::delete_memory_op; diff --git a/source/source_pw/module_stodft/sto_forces.cpp b/source/source_pw/module_stodft/sto_forces.cpp index 4e57ae98c7..e349d3de2c 100644 --- a/source/source_pw/module_stodft/sto_forces.cpp +++ b/source/source_pw/module_stodft/sto_forces.cpp @@ -31,7 +31,7 @@ void Sto_Forces::cal_stoforce(ModuleBase::matrix& force, { ModuleBase::timer::tick("Sto_Forces", "cal_force"); ModuleBase::TITLE("Sto_Forces", "init"); - this->device = base_device::get_device_type(this->ctx); + this->device = base_device::get_device_type(this->ctx); const ModuleBase::matrix& wg = elec.wg; const Charge* chr = elec.charge; force.create(this->nat, 3); diff --git a/source/source_pw/module_stodft/sto_wf.cpp b/source/source_pw/module_stodft/sto_wf.cpp index 2ba8db2908..a0204e1f87 100644 --- a/source/source_pw/module_stodft/sto_wf.cpp +++ b/source/source_pw/module_stodft/sto_wf.cpp @@ -19,7 +19,7 @@ Stochastic_WF::~Stochastic_WF() { delete chi0_cpu; Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { delete chi0; } @@ -119,7 +119,7 @@ void Stochastic_WF::allocate_chi0() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -248,7 +248,7 @@ void Stochastic_WF::init_com_orbitals() delete[] totnpw; // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -280,7 +280,7 @@ void Stochastic_WF::init_com_orbitals() // allocate chi0 Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { this->chi0 = new psi::Psi(nks, this->nchip_max, npwx, this->ngk, true); } @@ -370,7 +370,7 @@ template void Stochastic_WF::sync_chi0() { Device* ctx = {}; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) + if (base_device::get_device_type(ctx) == base_device::GpuDevice) { syncmem_h2d_op()(this->chi0->get_pointer(), this->chi0_cpu->get_pointer(),