mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-02 21:30:36 +00:00
Pin bitsandbytes to continuous-release_main on ROCm (4-bit decode fix) (#4954)
* Pin bitsandbytes to continuous-release_main on ROCm for 4-bit decode fix
bitsandbytes 0.49.2 on PyPI ships with a broken 4-bit GEMV kernel on
every ROCm target:
- CDNA (gfx90a / gfx942 / gfx950 = MI210 / MI300X / MI350) via a
broken blocksize=32/64 warp64 GEMV kernel whose tests were
explicitly skipped with ROCM_WARP_SIZE_64 guards because the
code was known broken.
- RDNA3 / RDNA3.5 (gfx1100-1103 / gfx1150-1152) via a compile-time
BNB_WARP_SIZE macro in the host-side dispatch that resolves to
64 when the multi-arch wheel is compiled with CDNA as the
primary target, so num_blocks is wrong on RDNA and half the GEMV
output is never written.
At decode shape (1, 1, hidden) both bugs produce NaN. Training is
unaffected because training shapes are (batch, seq_len > 1, hidden)
and never touch the GEMV path. The crash during autoregressive
inference surfaces as _assert_async_cuda_kernel in torch.multinomial,
which on HIP becomes a hard HSA_STATUS_ERROR_EXCEPTION instead of
a clean Python error.
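The RDNA launch-geometry mismatch is easy to see with a toy calculation. This is an illustrative sketch only: `gemv_blocks` is a hypothetical stand-in for the kernel launch math, not the actual bnb launcher.

```python
import math

def gemv_blocks(rows: int, warp_size: int) -> int:
    """Hypothetical launch-geometry helper: one warp covers warp_size rows."""
    return math.ceil(rows / warp_size)

rows = 4096
# BNB_WARP_SIZE baked in at compile time resolves to the CDNA value (64)...
compiled_blocks = gemv_blocks(rows, 64)   # 64 blocks launched
# ...but an RDNA3 device executes wave32, so it actually needs:
needed_blocks = gemv_blocks(rows, 32)     # 128 blocks
# Launching 64 blocks covers only half the rows; the other half of the
# GEMV output is never written, which shows up as NaN at decode shape.
```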
Both bugs are fixed by bitsandbytes commit 713a3b8 ("[ROCm] Enable
blocksize 32 4-bit quantization and GEMV kernels on AMD CDNA",
PR #1887, merged 2026-03-09) which replaces BNB_WARP_SIZE with a
runtime hipDeviceGetAttribute query and ships a working CDNA warp64
kernel. That commit has not shipped to PyPI yet, but
continuous-release_main wheels are published on every push to bnb
main via GitHub Releases.
Point the ROCm install path at the continuous-release_main x86_64 and
aarch64 wheels and fall back to PyPI >=0.49.1 when the pre-release is
unreachable (offline installs, firewalled hosts, or architectures not
covered by the pre-release wheels). Drop the pin once bnb cuts a
0.50+ tag on PyPI.
Verified on MI300X (gfx942, ROCm 7.2, torch 2.10.0+rocm7.1): direct
bnb GEMV shape test now returns 0.0078 max abs error at seq_len=1
(no NaN) vs NaN on 0.49.2, and full Unsloth + for_inference + 4-bit
sampling generation works end-to-end.
NVIDIA / CPU / Mac / Windows paths are unaffected -- the helper is
gated on the ROCm torch index and on platform.machine().
* Drop Studio ROCm 16-bit fallback now that bnb 0.50+ fixes 4-bit decode
The 16-bit fallback in studio/backend/core/inference/inference.py was
added as a workaround for a bug that this PR already fixes at the
install layer: bitsandbytes <= 0.49.2 has a broken 4-bit GEMV kernel
on every ROCm target, which NaNs at decode shape (seq_len=1) and
crashes autoregressive inference. bnb PR #1887 (commit 713a3b8, in
0.50.0.dev0+, pinned by install.sh / install_python_stack.py in this
PR) restores correct 4-bit decode on MI300X and was verified working
end-to-end with full Unsloth + for_inference + sampling.
Revert the dual code path so ROCm and NVIDIA both go through the
normal FastLanguageModel.from_pretrained + for_inference flow:
- Remove the conditional `from unsloth import` that skipped the
import on ROCm. The monkey-patches it was trying to avoid were
never the cause of the crash; bnb 4-bit GEMV was.
- Remove the `if _hw_module.IS_ROCM:` branch in load_model that
loaded with plain transformers + PEFT + bfloat16, and the
`_resolve_fp16_base` helper it relied on.
- Remove the `get_chat_template is not None` fallback in
_load_chat_template_info -- get_chat_template is now always
imported.
- Refactor the audio/vision ROCm guard to check _hw_module.IS_ROCM
directly instead of the removed _IS_ROCM_ENV global. Audio and
vision on ROCm still need separate validation (FastVisionModel
and the CSM audio codecs were never tested on HIP) so the guard
stays for now.
Add _bnb_rocm_4bit_ok() as a runtime safety net for users who
install from this PR before the install.sh bnb pin kicks in, or
whose installer fell back to the PyPI pin because the continuous-
release wheel was unreachable. When the installed bnb is < 0.50 on
ROCm, force load_in_4bit=False and strip any -unsloth-bnb-4bit /
-bnb-4bit suffix from the model path so a pre-quantized repo
resolves to its FP16 sibling instead of pulling bnb back in via
the repo's quantization_config. LoRA adapters whose base is a
pre-quantized repo on old bnb will still fail inside Unsloth's
loader -- the only real fix there is `unsloth studio update`.
Verified on MI300X (gfx942, ROCm 7.2, torch 2.10.0+rocm7.1):
- HAPPY path (bnb 0.50.0.dev0, load_in_4bit=True, pre-quantized
repo): loads in 4-bit via the fixed GEMV, generation returns
"Paris." for greedy and sampling.
- SAFETY-NET path (simulated old bnb, suffix-stripped to the
FP16 sibling, load_in_4bit=False): loads in bf16, generation
returns "Paris." for greedy and sampling.
Net diff is ~45 lines smaller than the pre-revert state because
the entire plain-transformers 16-bit branch is gone.
* Cache _bnb_rocm_4bit_ok() with functools.cache
load_model() can be called many times in a single session but the bnb
version and hardware state cannot change at runtime, so memoise the
check. First call is ~1.9 ms (dominated by the lazy `import bitsandbytes`
inside the try block), subsequent calls drop to sub-microsecond dict
lookups. Zero behavioral change.
* Shorten verbose bnb/ROCm comments
Comment-only cleanup across install.sh, studio/install_python_stack.py,
and studio/backend/core/inference/inference.py. No behavioral change.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Remove _bnb_rocm_4bit_ok safety net from inference.py
Studio's ROCm support is brand new (PR #4720, merged today) and every
fresh install pulls the bnb continuous-release_main wheel via
install.sh / install_python_stack.py in this same PR. There are no
existing ROCm Studio installs carrying bnb < 0.50, so the defensive
version-check fallback is guarding against a scenario that cannot
actually occur. Delete the helper, the functools import, and the
safety-net block -- inference.py now calls FastLanguageModel.from_pretrained
directly with no ROCm branching.
* Drop audio/vision ROCm guard in inference.py — verified unblocked by bnb fix
Vision inference was blocked by the same bnb 4-bit GEMV bug that affected
text inference (vision models use bnb 4-bit for the LM backbone). With
bnb 0.50+ pinned in install.sh / install_python_stack.py, vision works
end-to-end on MI300X: Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit
loaded in 4-bit via FastVisionModel + for_inference returns a correct
answer to a multimodal prompt.
Audio (CSM) was never actually blocked by HIP — on this hardware CSM
loads and runs its backbone forward pass fine with bnb 0.50, then fails
during generate() with a transformers-level kwarg validation mismatch
in generation_csm.py (`backbone_last_hidden_state` rejected). That's a
pre-existing transformers/CSM integration bug that reproduces identically
on NVIDIA, so the ROCm-gated guard was never actually protecting users
from anything HIP-specific.
Remove the combined audio/vision guard and the now-unused _hw_module
import. Also restore the one-word "Can be" in an inline comment that
drifted during the earlier comment-shortening pass, so the inference.py
delta vs pre-#4720 is exactly the max_seq_length<=0 crash fix and
nothing else.
* Shorten max_seq_length=0 guard comment to one line
---------
Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent cad8c6ad05
commit 65b4028560
3 changed files with 137 additions and 132 deletions
@@ -43,6 +43,32 @@ _ROCM_TORCH_INDEX: dict[tuple[int, int], str] = {
 }
 _PYTORCH_WHL_BASE = "https://download.pytorch.org/whl"
 
+# bitsandbytes continuous-release_main wheels with the ROCm 4-bit GEMV fix
+# (bnb PR #1887, post-0.49.2). bnb <= 0.49.2 NaNs at decode shape on every
+# AMD GPU. Drop the pin once bnb 0.50+ ships on PyPI.
+_BNB_ROCM_PRERELEASE_URLS: dict[str, str] = {
+    "x86_64": (
+        "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/"
+        "download/continuous-release_main/"
+        "bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl"
+    ),
+    "aarch64": (
+        "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/"
+        "download/continuous-release_main/"
+        "bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_aarch64.whl"
+    ),
+}
+_BNB_ROCM_PYPI_FALLBACK = "bitsandbytes>=0.49.1"
+
+
+def _bnb_rocm_prerelease_url() -> str | None:
+    """Return the continuous-release_main bnb wheel URL for the current
+    architecture, or None when no pre-release wheel is available.
+    """
+    arch = platform.machine().lower()
+    arch = {"amd64": "x86_64", "arm64": "aarch64"}.get(arch, arch)
+    return _BNB_ROCM_PRERELEASE_URLS.get(arch)
+
 
 def _detect_rocm_version() -> tuple[int, int] | None:
     """Return (major, minor) of the installed ROCm stack, or None."""
@@ -284,21 +310,37 @@ def _ensure_rocm_torch() -> None:
         )
     rocm_torch_ready = True
 
-    # Install bitsandbytes only when the venv has a ROCm-compatible torch
-    # (either already present or just installed). Avoids leaving an AMD
-    # bitsandbytes on top of a CPU/CUDA torch on hosts where the ROCm
-    # runtime is older than any published torch wheel. Uses
-    # --force-reinstall so an existing CPU/CUDA bitsandbytes is replaced
-    # by the AMD build during upgrades.
+    # Install bitsandbytes only when torch links against ROCm. Prefers the
+    # continuous-release_main wheel (bnb PR #1887 4-bit GEMV fix) and falls
+    # back to PyPI when the pre-release URL is unreachable.
     if rocm_torch_ready:
-        pip_install(
-            "bitsandbytes (AMD)",
-            "--force-reinstall",
-            "--no-cache-dir",
-            "--no-deps",
-            "bitsandbytes>=0.49.1",
-            constrain = False,
-        )
+        _bnb_url = _bnb_rocm_prerelease_url()
+        _bnb_installed = False
+        if _bnb_url is not None:
+            _bnb_installed = pip_install_try(
+                "bitsandbytes (AMD, pre-release main)",
+                "--force-reinstall",
+                "--no-cache-dir",
+                "--no-deps",
+                _bnb_url,
+                constrain = False,
+            )
+            if not _bnb_installed:
+                print(
+                    _red(
+                        " bnb pre-release unreachable; falling back to PyPI "
+                        "(4-bit decode will be broken on ROCm)"
+                    )
+                )
+        if not _bnb_installed:
+            pip_install(
+                "bitsandbytes (AMD)",
+                "--force-reinstall",
+                "--no-cache-dir",
+                "--no-deps",
+                _BNB_ROCM_PYPI_FALLBACK,
+                constrain = False,
+            )
 
 
 def _infer_no_torch() -> bool:
@@ -593,6 +635,37 @@ def _build_uv_cmd(args: tuple[str, ...]) -> list[str]:
     return cmd
 
 
+def pip_install_try(
+    label: str,
+    *args: str,
+    constrain: bool = True,
+) -> bool:
+    """Like pip_install but returns False on failure instead of exiting.
+
+    For optional installs with a follow-up fallback.
+    """
+    constraint_args: list[str] = []
+    if constrain and CONSTRAINTS.is_file():
+        constraint_args = ["-c", str(CONSTRAINTS)]
+
+    if USE_UV:
+        cmd = _build_uv_cmd(args) + constraint_args
+    else:
+        cmd = _build_pip_cmd(args) + constraint_args
+
+    if VERBOSE:
+        _step(_LABEL, f"{label}...", _dim)
+    result = subprocess.run(
+        cmd,
+        stdout = subprocess.PIPE,
+        stderr = subprocess.STDOUT,
+    )
+    if result.returncode == 0:
+        return True
+    if VERBOSE and result.stdout:
+        print(result.stdout.decode(errors = "replace"))
+    return False
+
+
 def pip_install(
     label: str,
     *args: str,