diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index f9fdd295d..3db42d536 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,7 @@ import dataclasses from functools import lru_cache import logging +import os import platform import re import subprocess @@ -97,6 +98,14 @@ def get_rocm_gpu_arch() -> str: if match: return "gfx" + match.group(1) else: + override = os.environ.get("BNB_ROCM_GPU_ARCH") + if override: + override = override.strip() + if override: + if not override.startswith("gfx"): + override = f"gfx{override}" + logger.info("Using ROCm GPU arch override: %s", override) + return override return "unknown" else: return "unknown" @@ -108,6 +117,14 @@ def get_rocm_gpu_arch() -> str: ROCm GPU architecture detection failed despite ROCm being available. """, ) + override = os.environ.get("BNB_ROCM_GPU_ARCH") + if override: + override = override.strip() + if override: + if not override.startswith("gfx"): + override = f"gfx{override}" + logger.info("Using ROCm GPU arch override: %s", override) + return override return "unknown" diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index c97996b75..61c143b37 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -189,6 +189,9 @@ pip install -e . * All features are supported for both consumer RDNA devices and Data Center CDNA products. * A compatible PyTorch version with AMD ROCm support is required. It is recommended to use the latest stable release. See [PyTorch on ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/3rd-party/pytorch-install.html) for guidance. +> [!TIP] +> If `rocminfo` is unavailable or fails in your runtime environment, set `BNB_ROCM_GPU_ARCH` (for example `BNB_ROCM_GPU_ARCH=gfx90a`) as a fallback. + ### Installation from PyPI[[rocm-pip]] This is the most straightforward and recommended installation option.