Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
fe1d82a
small format changes
Mar 10, 2026
1a725f6
refactor(esolver): extract charge density symmetrization to Symmetry_…
Mar 10, 2026
285ee1c
refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao mo…
Mar 10, 2026
2a520e3
refactor(esolver): complete DeltaSpin refactoring in LCAO
Mar 10, 2026
91be943
refactor(esolver): extract DFT+U code to dftu_lcao module
Mar 10, 2026
5a96483
refactor(esolver): extract diagonalization parameters setup to hsolve…
Mar 10, 2026
4936dc1
fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper
Mar 10, 2026
ea218f6
fix(deltaspin): add #ifdef __LCAO for conditional compilation
Mar 10, 2026
c365a3a
refactor(esolver): extract SDFT diagonalization parameters setup
Mar 10, 2026
cbd0ce7
refactor(hamilt): introduce HamiltBase non-template base class
Mar 11, 2026
6e0f43c
refactor(esolver): add static_cast for p_hamilt in esolver files
Mar 11, 2026
14c1b8a
refactor(esolver): remove psi member from ESolver_KS base class
Mar 11, 2026
0c2fa0f
refactor(esolver): remove template parameters from ESolver_KS base class
Mar 11, 2026
4b51e36
refactor(device): remove explicit template parameter from get_device_…
Mar 12, 2026
f74b4b0
refactor(esolver): remove device member variable from ESolver_KS_PW
Mar 12, 2026
dc9450a
style(esolver): explicitly initialize ctx to nullptr in constructor
Mar 12, 2026
3b108eb
feat(device): add runtime device type support to DeviceContext
Mar 12, 2026
77d081f
feat(device): add runtime device context overloads for gradual migration
Mar 12, 2026
6da8970
Merge branch 'develop' into remove_template4
mohanchen Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions source/source_base/math_chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Chebyshev<REAL, Device>::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord
}
coefr_cpu = new REAL[norder];
coefc_cpu = new std::complex<REAL>[norder];
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
resmem_var_op()(this->coef_real, norder);
resmem_complex_op()(this->coef_complex, norder);
Expand All @@ -82,7 +82,7 @@ template <typename REAL, typename Device>
Chebyshev<REAL, Device>::~Chebyshev()
{
delete[] polytrace;
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
delmem_var_op()(this->coef_real);
delmem_complex_op()(this->coef_complex);
Expand Down Expand Up @@ -209,7 +209,7 @@ void Chebyshev<REAL, Device>::calcoef_real(std::function<REAL(REAL)> fun)
}
}

if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
syncmem_var_h2d_op()(coef_real, coefr_cpu, norder);
}
Expand Down Expand Up @@ -299,7 +299,7 @@ void Chebyshev<REAL, Device>::calcoef_complex(std::function<std::complex<REAL>(s
coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3);
}
}
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
}
Expand Down Expand Up @@ -390,7 +390,7 @@ void Chebyshev<REAL, Device>::calcoef_pair(std::function<REAL(REAL)> fun1, std::
}
}

if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
}
Expand Down Expand Up @@ -684,7 +684,7 @@ bool Chebyshev<REAL, Device>::checkconverge(
funA(arrayn_1, arrayn, 1);
REAL sum1, sum2;
REAL t;
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
sum1 = this->ddot_real(arrayn_1, arrayn_1, N);
sum2 = this->ddot_real(arrayn_1, arrayn, N);
Expand Down Expand Up @@ -714,7 +714,7 @@ bool Chebyshev<REAL, Device>::checkconverge(
for (int ior = 2; ior < norder; ++ior)
{
funA(arrayn, arraynp1, 1);
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
{
sum1 = this->ddot_real(arrayn, arrayn, N);
sum2 = this->ddot_real(arrayn, arraynp1, N);
Expand Down
41 changes: 41 additions & 0 deletions source/source_base/module_device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,36 @@ class DeviceContext {
*/
int get_local_rank() const { return local_rank_; }

/**
 * @brief Store the device type for this context.
 * @param type Device type to record (CpuDevice, GpuDevice, or DspDevice)
 */
void set_device_type(AbacusDevice_t type)
{
    device_type_ = type;
}

/**
 * @brief Report the device type currently stored in this context.
 * @return AbacusDevice_t The stored device type
 */
AbacusDevice_t get_device_type() const
{
    return device_type_;
}

/**
 * @brief Tell whether the stored device type is CpuDevice.
 * @return true when the stored device type equals CpuDevice, false otherwise
 */
bool is_cpu() const
{
    return CpuDevice == device_type_;
}

/**
 * @brief Tell whether the stored device type is GpuDevice.
 * @return true when the stored device type equals GpuDevice, false otherwise
 */
bool is_gpu() const
{
    return GpuDevice == device_type_;
}

/**
 * @brief Tell whether the stored device type is DspDevice.
 * @return true when the stored device type equals DspDevice, false otherwise
 */
bool is_dsp() const
{
    return DspDevice == device_type_;
}

// Disable copy and assignment
DeviceContext(const DeviceContext&) = delete;
DeviceContext& operator=(const DeviceContext&) = delete;
Expand All @@ -158,10 +188,21 @@ class DeviceContext {
int device_id_ = -1;
int device_count_ = 0;
int local_rank_ = 0;
AbacusDevice_t device_type_ = CpuDevice;

std::mutex init_mutex_;
};

/**
 * @brief Runtime lookup of the device type held by a DeviceContext.
 *
 * Forwards to DeviceContext::get_device_type(). The pointer is dereferenced
 * unconditionally, so @p ctx must not be null.
 *
 * @param ctx Pointer to the DeviceContext to query
 * @return AbacusDevice_t The value returned by ctx->get_device_type()
 */
inline AbacusDevice_t get_device_type(const DeviceContext* ctx)
{
    const AbacusDevice_t type = ctx->get_device_type();
    return type;
}

} // end of namespace base_device

#endif // MODULE_DEVICE_H_
13 changes: 0 additions & 13 deletions source/source_base/module_device/device_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,6 @@
namespace base_device
{

// Device type specializations
template <>
AbacusDevice_t get_device_type<DEVICE_CPU>(const DEVICE_CPU* dev)
{
return CpuDevice;
}

template <>
AbacusDevice_t get_device_type<DEVICE_GPU>(const DEVICE_GPU* dev)
{
return GpuDevice;
}

// Precision specializations
template <>
std::string get_current_precision<float>(const float* var)
Expand Down
28 changes: 19 additions & 9 deletions source/source_base/module_device/device_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,35 @@
#include "types.h"
#include <complex>
#include <string>
#include <type_traits>

namespace base_device
{

// Forward declaration
class DeviceContext;

/**
 * @brief Get the device type enum from DeviceContext (runtime version).
 *
 * Declaration only: DeviceContext is merely forward-declared in this header,
 * and the inline definition of this overload appears in
 * module_device/device.h.
 * NOTE(review): because the function is declared inline, a translation unit
 * that calls this overload presumably has to include device.h as well so the
 * definition is visible — confirm.
 *
 * @param ctx Pointer to DeviceContext
 * @return AbacusDevice_t enum value
 */
inline AbacusDevice_t get_device_type(const DeviceContext* ctx);

/**
* @brief Get the device type enum for a given device type.
* @brief Get the device type enum for a given device type (compile-time version).
* @tparam Device The device type (DEVICE_CPU or DEVICE_GPU)
* @param dev Pointer to device (used for template deduction)
* @return AbacusDevice_t enum value
*/
template <typename Device>
AbacusDevice_t get_device_type(const Device* dev);

// Template specialization declarations
template <>
AbacusDevice_t get_device_type<DEVICE_CPU>(const DEVICE_CPU* dev);

template <>
AbacusDevice_t get_device_type<DEVICE_GPU>(const DEVICE_GPU* dev);
AbacusDevice_t get_device_type(const Device* dev)
{
    // `dev` is never read: the result is selected purely from the Device
    // template argument, falling back to UnKnown for any other type.
    return std::is_same<Device, DEVICE_CPU>::value
               ? CpuDevice
               : (std::is_same<Device, DEVICE_GPU>::value
                      ? GpuDevice
                      : (std::is_same<Device, DEVICE_DSP>::value ? DspDevice : UnKnown));
}

/**
* @brief Get the precision string for a given numeric type.
Expand Down
4 changes: 2 additions & 2 deletions source/source_base/module_device/test/device_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ class TestModulePsiDevice : public ::testing::Test

TEST_F(TestModulePsiDevice, get_device_type_cpu)
{
base_device::AbacusDevice_t device = base_device::get_device_type<base_device::DEVICE_CPU>(cpu_ctx);
base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx);
EXPECT_EQ(device, base_device::CpuDevice);
}

#if __UT_USE_CUDA || __UT_USE_ROCM
TEST_F(TestModulePsiDevice, get_device_type_gpu)
{
base_device::AbacusDevice_t device = base_device::get_device_type<base_device::DEVICE_GPU>(gpu_ctx);
base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx);
EXPECT_EQ(device, base_device::GpuDevice);
}
#endif // __UT_USE_CUDA || __UT_USE_ROCM
4 changes: 2 additions & 2 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
{
this->classname = "ESolver_KS_PW";
this->basisname = "PW";
this->device = base_device::get_device_type<Device>(this->ctx);
this->ctx = nullptr;
}

template <typename T, typename Device>
Expand Down Expand Up @@ -251,7 +251,7 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
// Output quantities
ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc,
this->pw_rho, this->pw_rhod, this->pw_big, this->stp,
this->ctx, this->device, this->Pgrid, PARAM.inp);
this->ctx, this->Pgrid, PARAM.inp);

ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
}
Expand Down
3 changes: 0 additions & 3 deletions source/source_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ class ESolver_KS_PW : public ESolver_KS
// for get_pchg and get_wf, use ctx as input of fft
Device* ctx = {};

// for device to host data transformation
base_device::AbacusDevice_t device = {};

};
} // namespace ModuleESolver
#endif
2 changes: 1 addition & 1 deletion source/source_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), diag_comm(diag_comm_in),
diag_subspace(diag_subspace_in), diago_subspace_bs(diago_subspace_bs_in)
{
this->device = base_device::get_device_type<Device>(this->ctx);
this->device = base_device::get_device_type(this->ctx);

this->one = &one_;
this->zero = &zero_;
Expand Down
2 changes: 1 addition & 1 deletion source/source_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
const diag_comm_info& diag_comm_in)
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in)
{
this->device = base_device::get_device_type<Device>(this->ctx);
this->device = base_device::get_device_type(this->ctx);
this->precondition = precondition_in;

this->one = &one_;
Expand Down
8 changes: 4 additions & 4 deletions source/source_hsolver/diago_iter_assist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,14 @@ void DiagoIterAssist<T, Device>::diag_heevx(const int matrix_size,
// (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors);
heevx_op<T, Device>()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v);

if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
{
#if ((defined __CUDA) || (defined __ROCM))
// eigenvalues to e, from device to host
syncmem_var_d2h_op()(e, eigenvalues, num_eigenpairs);
#endif
}
else if (base_device::get_device_type<Device>(ctx) == base_device::CpuDevice)
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
{
// eigenvalues to e
syncmem_var_op()(e, eigenvalues, num_eigenpairs);
Expand Down Expand Up @@ -436,14 +436,14 @@ void DiagoIterAssist<T, Device>::diag_hegvd(const int nstart,

hegvd_op<T, Device>()(ctx, nstart, ldh, hcc, scc, eigenvalues, vcc);

if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
{
#if ((defined __CUDA) || (defined __ROCM))
// set eigenvalues in GPU to e in CPU
syncmem_var_d2h_op()(e, eigenvalues, nbands);
#endif
}
else if (base_device::get_device_type<Device>(ctx) == base_device::CpuDevice)
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
{
// set eigenvalues in CPU to e in CPU
syncmem_var_op()(e, eigenvalues, nbands);
Expand Down
2 changes: 1 addition & 1 deletion source/source_hsolver/test/hsolver_pw_sup.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
const bool use_paw_in,
const diag_comm_info& diag_comm_in)
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) {
this->device = base_device::get_device_type<Device>(this->ctx);
this->device = base_device::get_device_type(this->ctx);
this->precondition = precondition_in;

test_david = 2;
Expand Down
7 changes: 1 addition & 6 deletions source/source_io/module_ctrl/ctrl_output_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,14 @@ void ModuleIO::ctrl_scf_pw(const int istep,
const ModulePW::PW_Basis_Big *pw_big,
Setup_Psi_pw<T, Device> &stp,
const Device* ctx,
const base_device::AbacusDevice_t &device,
const Parallel_Grid &para_grid,
const Input_para& inp)
{
ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw");
ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw");

// Transfer data from device (GPU) to host (CPU) in pw basis
stp.copy_d2h(device);
stp.copy_d2h(ctx);

//----------------------------------------------------------
//! 4) Compute density of states (DOS)
Expand Down Expand Up @@ -386,7 +385,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_CPU
const ModulePW::PW_Basis_Big *pw_big,
Setup_Psi_pw<std::complex<float>, base_device::DEVICE_CPU> &stp,
const base_device::DEVICE_CPU* ctx,
const base_device::AbacusDevice_t &device,
const Parallel_Grid &para_grid,
const Input_para& inp);

Expand All @@ -403,7 +401,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_CP
const ModulePW::PW_Basis_Big *pw_big,
Setup_Psi_pw<std::complex<double>, base_device::DEVICE_CPU> &stp,
const base_device::DEVICE_CPU* ctx,
const base_device::AbacusDevice_t &device,
const Parallel_Grid &para_grid,
const Input_para& inp);

Expand All @@ -421,7 +418,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_GPU
const ModulePW::PW_Basis_Big *pw_big,
Setup_Psi_pw<std::complex<float>, base_device::DEVICE_GPU> &stp,
const base_device::DEVICE_GPU* ctx,
const base_device::AbacusDevice_t &device,
const Parallel_Grid &para_grid,
const Input_para& inp);

Expand All @@ -438,7 +434,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_GP
const ModulePW::PW_Basis_Big *pw_big,
Setup_Psi_pw<std::complex<double>, base_device::DEVICE_GPU> &stp,
const base_device::DEVICE_GPU* ctx,
const base_device::AbacusDevice_t &device,
const Parallel_Grid &para_grid,
const Input_para& inp);
#endif
Expand Down
Loading
Loading