19 changes: 14 additions & 5 deletions bitsandbytes/cuda_specs.py
@@ -6,6 +6,12 @@
 from typing import Optional

 import torch
+import sys
+
+if (sys.platform == "win32"):
+    rocminfo = "hipinfo"
+else:
+    rocminfo = "rocminfo"
Collaborator:
`import sys` should be placed before `import torch` (stdlib before third-party per PEP 8). Also, this module-level platform check runs for all users on import. Consider keeping the platform-conditional logic inside the functions that need it, as PR #1846 does with `platform.system() == "Windows"` checks.



@dataclasses.dataclass(frozen=True)
@@ -83,8 +89,8 @@ def get_rocm_gpu_arch() -> str:
     logger = logging.getLogger(__name__)
     try:
         if torch.version.hip:
-            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
-            match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
+            result = subprocess.run([rocminfo], capture_output=True, text=True)
+            match = re.search(r"Name:\s+gfx([a-z\d]+)", result.stdout, re.IGNORECASE)
Collaborator:
This regex `Name:\s+gfx(...)` will not match `hipinfo.exe` output on Windows. The `hipinfo.exe` utility reports GPU architecture under `gcnArchName:`, not `Name:`. For example, hipinfo outputs `gcnArchName: gfx1100`, not `Name: gfx1100`. This means `get_rocm_gpu_arch()` will always return `"unknown"` on Windows, defeating the purpose of this PR. PR #1846 handles this correctly with a separate `gcnArchName:` pattern for Windows.

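A hedged sketch of the separated-pattern approach the comment points to. The sample output strings are illustrative assumptions, not captured tool output, and `parse_arch` is a hypothetical helper:

```python
import re

# Illustrative (assumed) excerpts of each tool's output:
ROCMINFO_SAMPLE = "  Name:                    gfx90a\n"
HIPINFO_SAMPLE = "gcnArchName:               gfx1100\n"

# One pattern per tool instead of a single combined regex.
_LINUX_ARCH = re.compile(r"Name:\s+gfx([a-zA-Z\d]+)")
_WINDOWS_ARCH = re.compile(r"gcnArchName:\s+gfx([a-zA-Z\d]+)")


def parse_arch(output: str, windows: bool) -> str:
    """Extract the gfx architecture string from rocminfo/hipinfo output."""
    match = (_WINDOWS_ARCH if windows else _LINUX_ARCH).search(output)
    return "gfx" + match.group(1) if match else "unknown"
```

Separating the patterns keeps each one trivially readable and lets each be tested against its own tool's output format.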
             if match:
                 return "gfx" + match.group(1)
             else:
@@ -102,15 +108,18 @@ def get_rocm_gpu_arch() -> str:
         return "unknown"


+# Wavefront size (or warp size) in GPU computing is the number of threads that execute
+# together in lockstep on a GPU core, typically 32 or 64, depending on the architecture
+# (e.g., Nvidia is 32, older AMD GCN was 64, newer AMD RDNA can be 32 or 64).
 def get_rocm_warpsize() -> int:
     """Get ROCm warp size."""
     logger = logging.getLogger(__name__)
     try:
         if torch.version.hip:
-            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
-            match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+            result = subprocess.run([rocminfo], capture_output=True, text=True)
+            match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE)
Collaborator:
The combined regex `(wavefront\s|warp)size:` is creative but fragile. The `[0-9]{2}` in group 2 requires exactly two digits, which works for 32/64 but would break for unexpected values. More importantly, the third optional group `(\([x0-9]{4}\))?` requires exactly 4 characters inside parentheses; this happens to match `(0x40)` for warp size 64 but would fail for other hex representations. PR #1846's approach of using cleanly separated patterns (`warpSize:\s+(\d+)` for Windows, the existing `Wavefront Size:` pattern for Linux) is more robust and readable.

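A sketch of the separated warp-size patterns the comment recommends. The sample strings are assumed output formats, and `parse_warpsize` is a hypothetical helper, not code from either PR:

```python
import re

ROCMINFO_SAMPLE = "  Wavefront Size:          64(0x40)\n"  # assumed rocminfo format
HIPINFO_SAMPLE = "warpSize:                 32\n"          # assumed hipinfo format


def parse_warpsize(output: str, windows: bool) -> int:
    """Extract the warp/wavefront size from rocminfo or hipinfo output.

    Using \\d+ instead of [0-9]{2} avoids hard-coding a two-digit width,
    and skipping the trailing hex group avoids the fragile (\\([x0-9]{4}\\))? match.
    """
    pattern = r"warpSize:\s+(\d+)" if windows else r"Wavefront Size:\s+(\d+)"
    match = re.search(pattern, output)
    return int(match.group(1)) if match else 64  # default to 64 to be safe
```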
             if match:
-                return int(match.group(1))
+                return int(match.group(2))
             else:
                 # default to 64 to be safe
                 return 64