Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bitsandbytes/backends/triton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def optimizer_update_8bit_blockwise(
# lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
# )

with torch_accelerator_module.device(state1.device):
# Use g.device for device context: paged state tensors appear as CPU tensors
# but are backed by USM shared memory and accessible from the accelerator.
with torch_accelerator_module.device(g.device):
optimizer_update_8bit_blockwise_impl(
Comment on lines +238 to 241
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining this with the comment here. It's a good point!

optimizer_name=optimizer_name,
g=g,
Expand Down Expand Up @@ -279,7 +281,9 @@ def optimizer_update_32bit(
gnorm_scale: float,
skip_zeros=False,
) -> None:
with torch_accelerator_module.device(state1.device):
# Use g.device for device context: paged state tensors appear as CPU tensors
# but are backed by USM shared memory and accessible from the accelerator.
with torch_accelerator_module.device(g.device):
kernels_optim.optimizer_update_32bit_impl(
optimizer_name=optimizer_name,
g=g,
Expand Down
12 changes: 12 additions & 0 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def __init__(self, lib: ct.CDLL):
lib.cget_managed_ptr.restype = ct.c_void_p


class XpuBNBNativeLibrary(BNBNativeLibrary):
    """Wrapper for the XPU build of the native library.

    Adds SYCL USM (unified shared memory) paged-memory support by declaring
    the ``ctypes`` return type of ``cget_managed_ptr`` when the binary
    exposes that symbol.
    """

    def __init__(self, lib: ct.CDLL):
        super().__init__(lib)
        # Older binaries may lack the managed-pointer entry point; only
        # declare the restype when the symbol is actually present.
        managed_alloc = getattr(lib, "cget_managed_ptr", None)
        if managed_alloc is not None:
            managed_alloc.restype = ct.c_void_p


def get_available_cuda_binary_versions() -> list[str]:
"""Get formatted CUDA versions from existing library files using cuda_specs logic"""
lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}"
Expand Down Expand Up @@ -312,6 +321,9 @@ def get_native_library() -> BNBNativeLibrary:
if hasattr(dll, "get_context"): # only a CUDA-built library exposes this
return CudaBNBNativeLibrary(dll)

if torch._C._has_xpu:
return XpuBNBNativeLibrary(dll)

return BNBNativeLibrary(dll)


Expand Down
16 changes: 12 additions & 4 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def _cuda_device_of(a: torch.Tensor):

def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
num_bytes = dtype.itemsize * prod(shape)
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
managed_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(managed_ptr, ct.POINTER(ct.c_int))
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
out.is_paged = True
Expand Down Expand Up @@ -132,7 +132,10 @@ def elementwise_func(func_name, A, B, value, prefetch=True):
# if we return from this function, we want to the tensor
# to be in the correct state, that is the final state after the
# operation occurred. So we synchronize.
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.synchronize()


def fill(A, value, device=None, prefetch=True):
Expand Down Expand Up @@ -384,7 +387,12 @@ def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
# We use the raw stream for performance reasons.
if tensor.device.type == "xpu":
return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index))
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
if tensor.device.type == "cuda":
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
# For CPU tensors (e.g. paged optimizer states), use current device's stream.
if hasattr(torch, "xpu") and torch.xpu.is_available():
return ct.c_void_p(torch._C._xpu_getCurrentRawStream(torch.xpu.current_device()))
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(torch.cuda.current_device()))


def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
Expand Down
62 changes: 62 additions & 0 deletions csrc/pythonInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,19 @@ void gemv_4bit_inference_fp32(

#endif

#if BUILD_XPU
// Helper: get default SYCL queue for XPU paged memory operations.
// SYCL USM (Unified Shared Memory) provides equivalent functionality to:
// - CUDA's cudaMallocManaged / Level Zero's zeMemAllocShared
// - CUDA's cudaMemPrefetchAsync / Level Zero's zeCommandListAppendMemoryPrefetch
// Level Zero has no equivalent to cudaPeekAtLastError; each L0 call returns ze_result_t.
// SYCL wraps L0 and uses exceptions for error reporting.
// Lazily-constructed, process-wide SYCL queue on the default GPU.
// Meyers singleton: initialization is thread-safe (C++11 magic statics) and
// the single in-order queue is shared by all paged-memory entry points below,
// so allocations, prefetches, and fills target one device/context.
static sycl::queue& xpu_default_queue() {
    static sycl::queue q{sycl::gpu_selector_v, sycl::property::queue::in_order{}};
    return q;
}
#endif

extern "C" {
#if BUILD_CUDA || BUILD_HIP

Expand Down Expand Up @@ -687,6 +700,55 @@ void cgemv_4bit_inference_fp32(
gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

// XPU Paged Memory Support using SYCL USM (Unified Shared Memory)
// Equivalent CUDA APIs -> SYCL/Level Zero APIs:
// cudaMallocManaged -> sycl::malloc_shared / zeMemAllocShared
// cudaMemPrefetchAsync -> sycl::queue::prefetch / zeCommandListAppendMemoryPrefetch
// cudaPeekAtLastError -> N/A (SYCL uses exceptions; L0 returns ze_result_t per call)

// Allocate `bytes` of SYCL USM shared memory, accessible from both host and
// device (the XPU analogue of cudaMallocManaged). Returns nullptr on failure;
// failures are reported on stderr rather than thrown across the C ABI.
void* cget_managed_ptr(size_t bytes) {
    void* result = nullptr;
    try {
        result = sycl::malloc_shared(bytes, xpu_default_queue());
        if (!result) {
            fprintf(stderr, "XPU Error: sycl::malloc_shared returned nullptr for %zu bytes\n", bytes);
        }
    } catch (const sycl::exception& e) {
        // Never let a SYCL exception escape through the extern "C" boundary.
        fprintf(stderr, "XPU SYCL Error in cget_managed_ptr: %s\n", e.what());
        result = nullptr;
    }
    return result;
}

// Asynchronously migrate a USM shared allocation toward the default queue's
// device (analogous to cudaMemPrefetchAsync). A negative `device` means
// "prefetch to host" in the CUDA convention; sycl::queue::prefetch can only
// target the queue's own device, so that case is deliberately a no-op.
void cprefetch(void* ptr, size_t bytes, int device) {
    if (device < 0) {
        return;
    }
    try {
        xpu_default_queue().prefetch(ptr, bytes);
    } catch (const sycl::exception& e) {
        // Prefetch is a performance hint only — log and continue.
        fprintf(stderr, "XPU Warning: sycl::queue::prefetch failed: %s\n", e.what());
    }
}

// Synchronously fill n floats at A with `value` via the device queue.
// B is unused here — presumably the signature mirrors an existing C entry
// point for other backends (TODO confirm against the header declaration).
void cfill_fp32(float* A, float* B, float value, long n) {
    try {
        const size_t count = static_cast<size_t>(n);
        xpu_default_queue().fill(A, value, count).wait();
    } catch (const sycl::exception& e) {
        fprintf(stderr, "XPU Error in cfill_fp32: %s\n", e.what());
    }
}

// Fill n bytes at A with `value` from the host side.
// sycl::queue::fill<unsigned char> segfaults on certain Intel GPU drivers
// (e.g. Max 1550), and USM shared memory is host-accessible, so a plain
// memset is a safe equivalent. B is unused here — presumably kept for
// signature parity with other backends (TODO confirm).
void cfill_uint8(unsigned char* A, unsigned char* B, unsigned char value, long n) {
    const size_t count = static_cast<size_t>(n);
    memset(A, value, count);
}

#endif

void cquantize_blockwise_cpu_fp32(
Expand Down
Loading
Loading