mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-18 06:03:39 +00:00
* studio: skip flash-attn install on Blackwell GPUs (sm_100+) Dao-AILab does not publish prebuilt flash-attn wheels for sm_100, sm_120, or sm_121, and the older-arch wheels fail to load on Blackwell. Add a shared has_blackwell_gpu() helper and gate both the install-time (install_python_stack._ensure_flash_attn) and runtime (worker._ensure_flash_attn_for_long_context) paths on it. Detection uses nvidia-smi --query-gpu=compute_cap, which works on Linux and Windows. * test: stub has_blackwell_gpu in pre-existing runtime flash-attn tests prefers_prebuilt_wheel and falls_back_to_pypi exercise the install paths that the Blackwell guard now short-circuits. Make them explicit about non-Blackwell so they pass on real Blackwell hosts. * studio: cache has_blackwell_gpu, skip Blackwell warning under NO_TORCH - Wrap has_blackwell_gpu in functools.lru_cache so repeated calls in a single process avoid redundant nvidia-smi spawns. Tests clear the cache via setup_method/teardown_method. - In _ensure_flash_attn, run the NO_TORCH short-circuit before the Blackwell check so GGUF-only users (who never install torch anyway) do not see a Blackwell warning. Blackwell check still runs above the IS_WINDOWS / IS_MACOS gates so Blackwell-on-Windows users still see the explicit reason rather than a silent OS skip. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test: add has_blackwell_gpu to mlx worker test wheel_utils stub test_mlx_training_worker_config loads worker.py against a hand-rolled utils.wheel_utils stub. Adding has_blackwell_gpu to the stub symbol list so worker's import line resolves. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
477 lines
18 KiB
Python
477 lines
18 KiB
Python
"""Tests for the optional FlashAttention installer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
# Locate the repository's ``studio`` directory relative to this test file.
STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio"

# Prepend both the studio root and its ``backend`` subdirectory to the import
# path in a single step. ``backend`` lands first, exactly as if the studio
# root had been inserted at index 0 and ``backend`` inserted at index 0 after.
sys.path[:0] = [str(STUDIO_DIR / "backend"), str(STUDIO_DIR)]
|
|
|
|
import install_python_stack as ips
|
|
from backend.utils import wheel_utils
|
|
|
|
|
|
def _smi_result(stdout: str, returncode: int = 0) -> subprocess.CompletedProcess:
|
|
return subprocess.CompletedProcess(["nvidia-smi"], returncode, stdout, "")
|
|
|
|
|
|
class TestHasBlackwellGpu:
    """Checks for ``wheel_utils.has_blackwell_gpu`` with ``nvidia-smi`` stubbed out."""

    _SMI_PATH = "/usr/bin/nvidia-smi"

    def setup_method(self):
        # Start every test from an empty cache so a previous answer cannot leak in.
        wheel_utils.has_blackwell_gpu.cache_clear()

    def teardown_method(self):
        # Leave no cached answer behind for whatever runs next.
        wheel_utils.has_blackwell_gpu.cache_clear()

    def _probe(self, **run_mock_kwargs):
        """Call ``has_blackwell_gpu`` with ``nvidia-smi`` reported as present.

        ``run_mock_kwargs`` are forwarded to the ``subprocess.run`` patch, so
        callers choose between ``return_value`` and ``side_effect``.
        """
        which_patch = mock.patch.object(
            wheel_utils.shutil, "which", return_value = self._SMI_PATH
        )
        run_patch = mock.patch.object(
            wheel_utils.subprocess, "run", **run_mock_kwargs
        )
        with which_patch, run_patch:
            return wheel_utils.has_blackwell_gpu()

    def test_returns_false_when_nvidia_smi_missing(self):
        which_patch = mock.patch.object(
            wheel_utils.shutil, "which", return_value = None
        )
        with which_patch:
            detected = wheel_utils.has_blackwell_gpu()
        assert detected is False

    def test_returns_true_for_sm_100(self):
        detected = self._probe(return_value = _smi_result("10.0\n"))
        assert detected is True

    def test_returns_true_for_sm_120(self):
        detected = self._probe(return_value = _smi_result("12.0\n"))
        assert detected is True

    def test_returns_true_for_sm_121(self):
        detected = self._probe(return_value = _smi_result("12.1\n"))
        assert detected is True

    def test_returns_false_for_sm_90(self):
        detected = self._probe(return_value = _smi_result("9.0\n"))
        assert detected is False

    def test_returns_false_for_sm_89(self):
        detected = self._probe(return_value = _smi_result("8.9\n"))
        assert detected is False

    def test_mixed_gpus_with_one_blackwell_returns_true(self):
        # One compute capability per line: an 8.0 card followed by a 10.0 card.
        detected = self._probe(return_value = _smi_result("8.0\n10.0\n"))
        assert detected is True

    def test_returns_false_when_nvidia_smi_fails(self):
        detected = self._probe(return_value = _smi_result("", returncode = 1))
        assert detected is False

    def test_returns_false_on_subprocess_timeout(self):
        timeout_error = subprocess.TimeoutExpired(cmd = "nvidia-smi", timeout = 10)
        detected = self._probe(side_effect = timeout_error)
        assert detected is False

    def test_returns_false_on_malformed_output(self):
        detected = self._probe(return_value = _smi_result("not-a-number\n\n"))
        assert detected is False
|
|
|
|
|
|
class TestFlashAttnWheelSelection:
    """Checks the torch-to-flash-attn version mapping and wheel URL construction."""

    @staticmethod
    def _env(cuda_major):
        """Return a torch wheel environment that varies only in ``cuda_major``."""
        return {
            "python_tag": "cp313",
            "torch_mm": "2.10",
            "cuda_major": cuda_major,
            "cxx11abi": "TRUE",
            "platform_tag": "linux_x86_64",
        }

    def test_torch_210_maps_to_v281(self):
        selected = ips._select_flash_attn_version("2.10")
        assert selected == "2.8.1"

    def test_torch_29_maps_to_v283(self):
        selected = ips._select_flash_attn_version("2.9")
        assert selected == "2.8.3"

    def test_unsupported_torch_has_no_wheel_mapping(self):
        selected = ips._select_flash_attn_version("2.11")
        assert selected is None

    def test_exact_wheel_url_uses_full_env_tuple(self):
        expected_filename = (
            "flash_attn-2.8.1+cu12torch2.10cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"
        )
        url = ips._build_flash_attn_wheel_url(self._env("12"))
        assert url is not None
        assert "v2.8.1" in url
        assert expected_filename in url

    def test_missing_cuda_major_disables_wheel_lookup(self):
        # An empty CUDA major must yield no URL at all.
        url = ips._build_flash_attn_wheel_url(self._env(""))
        assert url is None
|
|
|
|
|
|
class TestEnsureFlashAttn:
    """Checks for ``ips._ensure_flash_attn``, the install-time flash-attn step.

    Every test that exercises a non-Blackwell path pins
    ``ips.has_blackwell_gpu`` to ``False`` explicitly. Without that stub the
    outcome would depend on the cached detection result and on the GPU of the
    machine running the tests.
    """

    @staticmethod
    def _torch_env(cuda_major: str = "12") -> dict[str, str]:
        """Return the stubbed ``probe_torch_wheel_env`` result used by these tests."""
        return {
            "python_tag": "cp313",
            "torch_mm": "2.10",
            "cuda_major": cuda_major,
            "cxx11abi": "TRUE",
            "platform_tag": "linux_x86_64",
        }

    @staticmethod
    def _recording_step(sink: list[tuple[str, str]]):
        """Return an ``ips._step`` stand-in that appends ``(label, value)`` to *sink*."""

        def fake_step(label: str, value: str, color_fn = None):
            sink.append((label, value))

        return fake_step

    def _import_check(self, code: int = 1):
        """Return a fake ``import flash_attn`` probe result with exit status *code*."""
        return subprocess.CompletedProcess(["python", "-c", "import flash_attn"], code)

    def _install_with_matching_wheel(self, *, uv_needs_system: bool):
        """Run ``_ensure_flash_attn`` where a matching wheel exists and installs cleanly.

        Returns:
            The list of ``(args, kwargs)`` tuples passed to ``ips.install_wheel``.
        """
        install_calls = []

        def fake_install_wheel(*args, **kwargs):
            install_calls.append((args, kwargs))
            return [("uv", subprocess.CompletedProcess(["uv"], 0, ""))]

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", uv_needs_system),
            # Keep the Blackwell guard out of the way regardless of host GPU.
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(
                ips, "probe_torch_wheel_env", return_value = self._torch_env()
            ),
            mock.patch.object(ips, "url_exists", return_value = True),
            mock.patch.object(ips, "install_wheel", side_effect = fake_install_wheel),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        return install_calls

    def test_prefers_exact_match_wheel(self):
        install_calls = self._install_with_matching_wheel(uv_needs_system = False)

        assert len(install_calls) == 1
        args, kwargs = install_calls[0]
        assert args == (
            "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.10cxx11abiTRUE-cp313-cp313-linux_x86_64.whl",
        )
        assert kwargs["python_executable"] == sys.executable
        assert kwargs["use_uv"] is True
        assert kwargs["uv_needs_system"] is False

    def test_uv_install_respects_system_flag(self):
        install_calls = self._install_with_matching_wheel(uv_needs_system = True)

        assert len(install_calls) == 1
        _, kwargs = install_calls[0]
        assert kwargs["uv_needs_system"] is True

    def test_wheel_failure_warns_and_continues(self):
        step_messages: list[tuple[str, str]] = []
        printed_failures: list[str] = []

        # Both the uv and the pip attempt report a non-zero exit status.
        failed_attempts = [
            ("uv", subprocess.CompletedProcess(["uv"], 1, "uv wheel failed")),
            ("pip", subprocess.CompletedProcess(["pip"], 1, "pip wheel failed")),
        ]

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(
                ips, "probe_torch_wheel_env", return_value = self._torch_env()
            ),
            mock.patch.object(ips, "url_exists", return_value = True),
            mock.patch.object(ips, "install_wheel", return_value = failed_attempts),
            mock.patch.object(
                ips,
                "_print_optional_install_failure",
                side_effect = lambda label, result: printed_failures.append(label),
            ),
            mock.patch.object(
                ips, "_step", side_effect = self._recording_step(step_messages)
            ),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        assert printed_failures == [
            "Installing flash-attn prebuilt wheel with uv",
            "Installing flash-attn prebuilt wheel with pip",
        ]
        assert ("warning", "Continuing without flash-attn") in step_messages

    def test_wheel_missing_skips_install_at_setup_time(self):
        step_messages: list[tuple[str, str]] = []

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(
                ips,
                "probe_torch_wheel_env",
                return_value = self._torch_env(cuda_major = "13"),
            ),
            # The wheel URL is reported as not existing.
            mock.patch.object(ips, "url_exists", return_value = False),
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(
                ips, "_step", side_effect = self._recording_step(step_messages)
            ),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_install_wheel.assert_not_called()
        assert (
            "warning",
            "No published flash-attn prebuilt wheel found",
        ) in step_messages

    def test_skip_env_disables_setup_install(self):
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.dict(os.environ, {"UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL": "1"}),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()

    def test_blackwell_gpu_skips_install_with_warning(self):
        step_messages: list[tuple[str, str]] = []

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = True),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(
                ips, "_step", side_effect = self._recording_step(step_messages)
            ),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
        assert any(
            label == "warning" and "Blackwell" in msg for label, msg in step_messages
        )

    def test_blackwell_gpu_on_windows_emits_blackwell_warning(self):
        step_messages: list[tuple[str, str]] = []

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", True),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = True),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(
                ips, "_step", side_effect = self._recording_step(step_messages)
            ),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
        assert any(
            label == "warning" and "Blackwell" in msg for label, msg in step_messages
        )

    def test_non_blackwell_windows_does_not_emit_blackwell_warning(self):
        step_messages: list[tuple[str, str]] = []

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", True),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(
                ips, "_step", side_effect = self._recording_step(step_messages)
            ),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()

        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
        assert not any("Blackwell" in msg for _, msg in step_messages)
|
|
|
|
|
|
class TestInstallPythonStackFlashAttnIntegration:
    """Checks on which platforms ``ips.install_python_stack`` runs the flash-attn step."""

    def _run_install(self, *, no_torch: bool, is_macos: bool, is_windows: bool) -> int:
        """Run ``install_python_stack`` under the given platform flags.

        Returns:
            How many times the stubbed ``ips._ensure_flash_attn`` was invoked.
        """
        invocations = []

        def record_flash_attn():
            invocations.append(None)

        def fake_run(cmd, **_ignored):
            # Every subprocess "succeeds" with empty byte output.
            return subprocess.CompletedProcess(
                args = cmd, returncode = 0, stdout = b"", stderr = b""
            )

        with (
            mock.patch.object(ips, "NO_TORCH", no_torch),
            mock.patch.object(ips, "IS_MACOS", is_macos),
            mock.patch.object(ips, "IS_WINDOWS", is_windows),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(ips, "VERBOSE", False),
            mock.patch.object(ips, "_bootstrap_uv", return_value = True),
            mock.patch.object(ips, "_ensure_flash_attn", side_effect = record_flash_attn),
            mock.patch("subprocess.run", side_effect = fake_run),
            mock.patch.object(ips, "_has_usable_nvidia_gpu", return_value = False),
            mock.patch.object(ips, "_has_rocm_gpu", return_value = False),
            mock.patch.object(
                ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
            ),
            mock.patch("pathlib.Path.is_dir", return_value = True),
            mock.patch("pathlib.Path.is_file", return_value = True),
            mock.patch.dict(os.environ, {"SKIP_STUDIO_BASE": "1"}, clear = False),
        ):
            ips.install_python_stack()

        return len(invocations)

    def test_linux_torch_install_calls_flash_attn_step(self):
        calls = self._run_install(no_torch = False, is_macos = False, is_windows = False)
        assert calls == 1

    def test_no_torch_install_skips_flash_attn_step(self):
        calls = self._run_install(no_torch = True, is_macos = False, is_windows = False)
        assert calls == 0

    def test_macos_install_skips_flash_attn_step(self):
        calls = self._run_install(no_torch = False, is_macos = True, is_windows = False)
        assert calls == 0

    def test_windows_install_skips_flash_attn_step(self):
        calls = self._run_install(no_torch = False, is_macos = False, is_windows = True)
        assert calls == 0
|