unsloth/tests/python/test_flash_attn_install_python_stack.py
Roland Tannous 79adfd9c71
studio: skip flash-attn install on Blackwell GPUs (sm_100+) (#5420)
* studio: skip flash-attn install on Blackwell GPUs (sm_100+)

Dao-AILab does not publish prebuilt flash-attn wheels for sm_100, sm_120,
or sm_121, and the older-arch wheels fail to load on Blackwell. Add a
shared has_blackwell_gpu() helper and gate both the install-time
(install_python_stack._ensure_flash_attn) and runtime
(worker._ensure_flash_attn_for_long_context) paths on it. Detection uses
nvidia-smi --query-gpu=compute_cap, which works on Linux and Windows.

* test: stub has_blackwell_gpu in pre-existing runtime flash-attn tests

prefers_prebuilt_wheel and falls_back_to_pypi exercise the install
paths that the Blackwell guard now short-circuits. Make them explicit
about non-Blackwell so they pass on real Blackwell hosts.

* studio: cache has_blackwell_gpu, skip Blackwell warning under NO_TORCH

- Wrap has_blackwell_gpu in functools.lru_cache so repeated calls in a
  single process avoid redundant nvidia-smi spawns. Tests clear the
  cache via setup_method/teardown_method.
- In _ensure_flash_attn, run the NO_TORCH short-circuit before the
  Blackwell check so GGUF-only users (who never install torch anyway)
  do not see a Blackwell warning. Blackwell check still runs above the
  IS_WINDOWS / IS_MACOS gates so Blackwell-on-Windows users still see
  the explicit reason rather than a silent OS skip.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test: add has_blackwell_gpu to mlx worker test wheel_utils stub

test_mlx_training_worker_config loads worker.py against a hand-rolled
utils.wheel_utils stub. Add has_blackwell_gpu to the stub's symbol
list so worker's import line resolves.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2026-05-14 18:13:50 +04:00

477 lines
18 KiB
Python

"""Tests for the optional FlashAttention installer."""
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from unittest import mock
# This file lives at <repo>/tests/python/, so parents[2] is the repo root and
# STUDIO_DIR is <repo>/studio.
STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio"
# studio/ must be importable for `install_python_stack` and the `backend`
# package. studio/backend is added as well, presumably so backend modules'
# own top-level imports (e.g. `utils.wheel_utils`) resolve. TODO confirm.
sys.path.insert(0, str(STUDIO_DIR))
sys.path.insert(0, str(STUDIO_DIR / "backend"))
import install_python_stack as ips
from backend.utils import wheel_utils
def _smi_result(stdout: str, returncode: int = 0) -> subprocess.CompletedProcess:
return subprocess.CompletedProcess(["nvidia-smi"], returncode, stdout, "")
class TestHasBlackwellGpu:
    """Unit tests for ``wheel_utils.has_blackwell_gpu``.

    The helper is ``lru_cache``-wrapped, so the cache is cleared before and
    after every test. Otherwise a result computed under one set of mocks, or
    against the real host, would leak into the next test.
    """

    # Fake location returned by the patched ``shutil.which``.
    _SMI_PATH = "/usr/bin/nvidia-smi"

    def setup_method(self):
        wheel_utils.has_blackwell_gpu.cache_clear()

    def teardown_method(self):
        wheel_utils.has_blackwell_gpu.cache_clear()

    def _detect(self, **run_kwargs) -> bool:
        """Call ``has_blackwell_gpu`` with nvidia-smi "found" and ``subprocess.run`` mocked.

        ``run_kwargs`` are forwarded to ``mock.patch.object`` for
        ``wheel_utils.subprocess.run``, e.g. ``return_value`` or ``side_effect``.
        """
        with (
            mock.patch.object(
                wheel_utils.shutil, "which", return_value = self._SMI_PATH
            ),
            mock.patch.object(wheel_utils.subprocess, "run", **run_kwargs),
        ):
            return wheel_utils.has_blackwell_gpu()

    def test_returns_false_when_nvidia_smi_missing(self):
        with mock.patch.object(wheel_utils.shutil, "which", return_value = None):
            assert wheel_utils.has_blackwell_gpu() is False

    def test_returns_true_for_sm_100(self):
        assert self._detect(return_value = _smi_result("10.0\n")) is True

    def test_returns_true_for_sm_120(self):
        assert self._detect(return_value = _smi_result("12.0\n")) is True

    def test_returns_true_for_sm_121(self):
        assert self._detect(return_value = _smi_result("12.1\n")) is True

    def test_returns_false_for_sm_90(self):
        assert self._detect(return_value = _smi_result("9.0\n")) is False

    def test_returns_false_for_sm_89(self):
        assert self._detect(return_value = _smi_result("8.9\n")) is False

    def test_mixed_gpus_with_one_blackwell_returns_true(self):
        # One line per GPU: any Blackwell device is enough.
        assert self._detect(return_value = _smi_result("8.0\n10.0\n")) is True

    def test_returns_false_when_nvidia_smi_fails(self):
        assert self._detect(return_value = _smi_result("", returncode = 1)) is False

    def test_returns_false_on_subprocess_timeout(self):
        timeout = subprocess.TimeoutExpired(cmd = "nvidia-smi", timeout = 10)
        assert self._detect(side_effect = timeout) is False

    def test_returns_false_on_malformed_output(self):
        assert self._detect(return_value = _smi_result("not-a-number\n\n")) is False

    def test_result_is_cached_across_calls(self):
        # has_blackwell_gpu is lru_cache-wrapped precisely so that repeated
        # calls in one process do not respawn nvidia-smi.
        with (
            mock.patch.object(
                wheel_utils.shutil, "which", return_value = self._SMI_PATH
            ),
            mock.patch.object(
                wheel_utils.subprocess, "run", return_value = _smi_result("10.0\n")
            ) as run,
        ):
            assert wheel_utils.has_blackwell_gpu() is True
            calls_after_first = run.call_count
            assert wheel_utils.has_blackwell_gpu() is True
            assert run.call_count == calls_after_first
class TestFlashAttnWheelSelection:
    """Torch-version -> flash-attn release mapping and prebuilt-wheel URL construction."""

    # Wheel-environment tuple shared by the URL tests. Each test passes a
    # fresh copy, so nothing can leak between tests through mutation.
    _ENV = {
        "python_tag": "cp313",
        "torch_mm": "2.10",
        "cuda_major": "12",
        "cxx11abi": "TRUE",
        "platform_tag": "linux_x86_64",
    }

    def test_torch_210_maps_to_v281(self):
        version = ips._select_flash_attn_version("2.10")
        assert version == "2.8.1"

    def test_torch_29_maps_to_v283(self):
        version = ips._select_flash_attn_version("2.9")
        assert version == "2.8.3"

    def test_unsupported_torch_has_no_wheel_mapping(self):
        version = ips._select_flash_attn_version("2.11")
        assert version is None

    def test_exact_wheel_url_uses_full_env_tuple(self):
        expected_filename = (
            "flash_attn-2.8.1+cu12torch2.10cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"
        )
        url = ips._build_flash_attn_wheel_url(dict(self._ENV))
        assert url is not None
        assert "v2.8.1" in url
        assert expected_filename in url

    def test_missing_cuda_major_disables_wheel_lookup(self):
        env_without_cuda = {**self._ENV, "cuda_major": ""}
        assert ips._build_flash_attn_wheel_url(env_without_cuda) is None
class TestEnsureFlashAttn:
    """Install-time behaviour of ``install_python_stack._ensure_flash_attn``.

    ``has_blackwell_gpu`` is ``lru_cache``-wrapped and reachable as an ``ips``
    attribute, so every test pins it explicitly. Tests that exercise the
    install path stub it to ``False``. Without that stub, a cached ``True``
    (from a real Blackwell host, or an earlier test in the same session) would
    short-circuit ``_ensure_flash_attn`` and break the assertions.
    """

    # Probe result for CPython 3.13 / torch 2.10 / CUDA 12 on Linux x86_64.
    # Tests hand out copies so no test can mutate the shared dict.
    _WHEEL_ENV = {
        "python_tag": "cp313",
        "torch_mm": "2.10",
        "cuda_major": "12",
        "cxx11abi": "TRUE",
        "platform_tag": "linux_x86_64",
    }

    def _import_check(self, code: int = 1):
        """Fake ``subprocess.run`` result for the ``import flash_attn`` probe.

        The non-zero default presumably signals "flash-attn not importable yet".
        TODO confirm against ``ips``.
        """
        return subprocess.CompletedProcess(["python", "-c", "import flash_attn"], code)

    @staticmethod
    def _step_recorder():
        """Return ``(messages, fake_step)``.

        ``fake_step`` records each ``(label, value)`` pair into ``messages``.
        """
        messages: list[tuple[str, str]] = []

        def fake_step(label: str, value: str, color_fn = None):
            messages.append((label, value))

        return messages, fake_step

    def test_prefers_exact_match_wheel(self):
        install_calls = []

        def fake_install_wheel(*args, **kwargs):
            install_calls.append((args, kwargs))
            return [("uv", subprocess.CompletedProcess(["uv"], 0, ""))]

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(
                ips, "probe_torch_wheel_env", return_value = dict(self._WHEEL_ENV)
            ),
            mock.patch.object(ips, "url_exists", return_value = True),
            mock.patch.object(ips, "install_wheel", side_effect = fake_install_wheel),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        assert len(install_calls) == 1
        args, kwargs = install_calls[0]
        assert args == (
            "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.10cxx11abiTRUE-cp313-cp313-linux_x86_64.whl",
        )
        assert kwargs["python_executable"] == sys.executable
        assert kwargs["use_uv"] is True
        assert kwargs["uv_needs_system"] is False

    def test_uv_install_respects_system_flag(self):
        install_calls = []

        def fake_install_wheel(*args, **kwargs):
            install_calls.append((args, kwargs))
            return [("uv", subprocess.CompletedProcess(["uv"], 0, ""))]

        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", True),
            mock.patch.object(
                ips, "probe_torch_wheel_env", return_value = dict(self._WHEEL_ENV)
            ),
            mock.patch.object(ips, "url_exists", return_value = True),
            mock.patch.object(ips, "install_wheel", side_effect = fake_install_wheel),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        assert len(install_calls) == 1
        _, kwargs = install_calls[0]
        assert kwargs["uv_needs_system"] is True

    def test_wheel_failure_warns_and_continues(self):
        step_messages, fake_step = self._step_recorder()
        printed_failures: list[str] = []
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(
                ips, "probe_torch_wheel_env", return_value = dict(self._WHEEL_ENV)
            ),
            mock.patch.object(ips, "url_exists", return_value = True),
            mock.patch.object(
                ips,
                "install_wheel",
                return_value = [
                    ("uv", subprocess.CompletedProcess(["uv"], 1, "uv wheel failed")),
                    (
                        "pip",
                        subprocess.CompletedProcess(["pip"], 1, "pip wheel failed"),
                    ),
                ],
            ),
            mock.patch.object(
                ips,
                "_print_optional_install_failure",
                side_effect = lambda label, result: printed_failures.append(label),
            ),
            mock.patch.object(ips, "_step", side_effect = fake_step),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        assert printed_failures == [
            "Installing flash-attn prebuilt wheel with uv",
            "Installing flash-attn prebuilt wheel with pip",
        ]
        assert ("warning", "Continuing without flash-attn") in step_messages

    def test_wheel_missing_skips_install_at_setup_time(self):
        step_messages, fake_step = self._step_recorder()
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(
                ips,
                "probe_torch_wheel_env",
                return_value = {**self._WHEEL_ENV, "cuda_major": "13"},
            ),
            mock.patch.object(ips, "url_exists", return_value = False),
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(ips, "_step", side_effect = fake_step),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        mock_install_wheel.assert_not_called()
        assert (
            "warning",
            "No published flash-attn prebuilt wheel found",
        ) in step_messages

    def test_skip_env_disables_setup_install(self):
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.dict(os.environ, {"UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL": "1"}),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()

    def test_blackwell_gpu_skips_install_with_warning(self):
        step_messages, fake_step = self._step_recorder()
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", False),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = True),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(ips, "_step", side_effect = fake_step),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
        assert any(
            label == "warning" and "Blackwell" in msg for label, msg in step_messages
        )

    def test_blackwell_gpu_on_windows_emits_blackwell_warning(self):
        step_messages, fake_step = self._step_recorder()
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", True),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = True),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(ips, "_step", side_effect = fake_step),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
        assert any(
            label == "warning" and "Blackwell" in msg for label, msg in step_messages
        )

    def test_non_blackwell_windows_does_not_emit_blackwell_warning(self):
        step_messages, fake_step = self._step_recorder()
        with (
            mock.patch.object(ips, "NO_TORCH", False),
            mock.patch.object(ips, "IS_WINDOWS", True),
            mock.patch.object(ips, "IS_MACOS", False),
            mock.patch.object(ips, "has_blackwell_gpu", return_value = False),
            mock.patch.object(ips, "probe_torch_wheel_env") as mock_probe,
            mock.patch.object(ips, "install_wheel") as mock_install_wheel,
            mock.patch.object(ips, "_step", side_effect = fake_step),
            mock.patch("subprocess.run", return_value = self._import_check()),
        ):
            ips._ensure_flash_attn()
        mock_probe.assert_not_called()
        mock_install_wheel.assert_not_called()
        assert not any("Blackwell" in msg for _, msg in step_messages)
class TestInstallPythonStackFlashAttnIntegration:
    """``install_python_stack`` reaches ``_ensure_flash_attn`` only on the Linux + torch path."""

    def _run_install(self, *, no_torch: bool, is_macos: bool, is_windows: bool) -> int:
        """Run ``install_python_stack`` with external effects stubbed out.

        Returns the number of times ``_ensure_flash_attn`` was invoked.
        """

        def succeed(cmd, **kw):
            # Every subprocess invocation "succeeds" with empty byte output.
            return subprocess.CompletedProcess(cmd, 0, b"", b"")

        with (
            mock.patch.object(ips, "NO_TORCH", no_torch),
            mock.patch.object(ips, "IS_MACOS", is_macos),
            mock.patch.object(ips, "IS_WINDOWS", is_windows),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(ips, "VERBOSE", False),
            mock.patch.object(ips, "_bootstrap_uv", return_value = True),
            mock.patch.object(
                ips, "_ensure_flash_attn", side_effect = lambda: None
            ) as ensure_flash_attn,
            mock.patch("subprocess.run", side_effect = succeed),
            mock.patch.object(ips, "_has_usable_nvidia_gpu", return_value = False),
            mock.patch.object(ips, "_has_rocm_gpu", return_value = False),
            mock.patch.object(
                ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
            ),
            mock.patch("pathlib.Path.is_dir", return_value = True),
            mock.patch("pathlib.Path.is_file", return_value = True),
            mock.patch.dict(os.environ, {"SKIP_STUDIO_BASE": "1"}, clear = False),
        ):
            ips.install_python_stack()
        return ensure_flash_attn.call_count

    def test_linux_torch_install_calls_flash_attn_step(self):
        calls = self._run_install(no_torch = False, is_macos = False, is_windows = False)
        assert calls == 1

    def test_no_torch_install_skips_flash_attn_step(self):
        calls = self._run_install(no_torch = True, is_macos = False, is_windows = False)
        assert calls == 0

    def test_macos_install_skips_flash_attn_step(self):
        calls = self._run_install(no_torch = False, is_macos = True, is_windows = False)
        assert calls == 0

    def test_windows_install_skips_flash_attn_step(self):
        calls = self._run_install(no_torch = False, is_macos = False, is_windows = True)
        assert calls == 0