-
-
Notifications
You must be signed in to change notification settings - Fork 827
Update cuda_specs.py #1833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update cuda_specs.py #1833
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,12 @@ | |
| from typing import Optional | ||
|
|
||
| import torch | ||
| import sys | ||
|
|
||
| if (sys.platform == "win32"): | ||
| rocminfo = "hipinfo" | ||
| else: | ||
| rocminfo = "rocminfo" | ||
|
|
||
|
|
||
| @dataclasses.dataclass(frozen=True) | ||
|
|
@@ -83,8 +89,8 @@ def get_rocm_gpu_arch() -> str: | |
| logger = logging.getLogger(__name__) | ||
| try: | ||
| if torch.version.hip: | ||
| result = subprocess.run(["rocminfo"], capture_output=True, text=True) | ||
| match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) | ||
| result = subprocess.run([rocminfo], capture_output=True, text=True) | ||
| match = re.search(r"Name:\s+gfx([a-z\d]+)", result.stdout, re.IGNORECASE) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This regex |
||
| if match: | ||
| return "gfx" + match.group(1) | ||
| else: | ||
|
|
@@ -102,15 +108,18 @@ def get_rocm_gpu_arch() -> str: | |
| return "unknown" | ||
|
|
||
|
|
||
| # Wavefront size (or warp size) in GPU computing is the number of threads that execute | ||
| # together in lockstep on a GPU core, typically 32 or 64, depending on the architecture | ||
| # (e.g., Nvidia is 32, older AMD GCN was 64, newer AMD RDNA can be 32 or 64). | ||
| def get_rocm_warpsize() -> int: | ||
| """Get ROCm warp size.""" | ||
| logger = logging.getLogger(__name__) | ||
| try: | ||
| if torch.version.hip: | ||
| result = subprocess.run(["rocminfo"], capture_output=True, text=True) | ||
| match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) | ||
| result = subprocess.run([rocminfo], capture_output=True, text=True) | ||
| match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The combined regex |
||
| if match: | ||
| return int(match.group(1)) | ||
| return int(match.group(2)) | ||
| else: | ||
| # default to 64 to be safe | ||
| return 64 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import sysshould be placed beforeimport torch(stdlib before third-party per PEP 8). Also, this module-level platform check runs for all users on import. Consider keeping the platform-conditional logic inside the functions that need it, as PR #1846 does withplatform.system() == "Windows"checks.