[Studio] Install flash attn at setup time for linux (#4979)

* [Studio] Install flash attn at setup time for linux

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up changes

Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test cases

* wheel_utils: narrow url_exists exceptions and log at debug level

---------

Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
This commit is contained in:
Datta Nimmaturi 2026-04-14 18:10:17 +05:30 committed by GitHub
parent dccc0ebada
commit da78c6be71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 815 additions and 207 deletions

View file

@ -21,6 +21,14 @@ import tempfile
import urllib.request
from pathlib import Path
from backend.utils.wheel_utils import (
flash_attn_package_version,
flash_attn_wheel_url,
install_wheel,
probe_torch_wheel_env,
url_exists,
)
# Platform-detection flags; gate the platform-specific install steps below.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
# Intel Macs (x86_64) differ from Apple Silicon in wheel availability.
IS_MAC_INTEL = IS_MACOS and platform.machine() == "x86_64"
@ -368,7 +376,6 @@ NO_TORCH = _infer_no_torch()
VERBOSE: bool = os.environ.get("UNSLOTH_VERBOSE", "0") == "1"
# Progress bar state -- updated by _progress() as each install step runs.
# _TOTAL counts: pip-upgrade + 7 shared steps + triton (non-Windows) + local-plugin + finalize
# Update _TOTAL here if you add or remove install steps in install_python_stack().
_STEP: int = 0
_TOTAL: int = 0 # set at runtime in install_python_stack() based on platform
@ -535,6 +542,66 @@ NO_TORCH_SKIP_PACKAGES = {
"transformers-cfg",
}
def _select_flash_attn_version(torch_mm: str) -> str | None:
    """Resolve the flash-attn package version matching torch major.minor *torch_mm*.

    Thin wrapper over ``backend.utils.wheel_utils.flash_attn_package_version``;
    returns ``None`` when no version is known for this torch release.
    """
    selected = flash_attn_package_version(torch_mm)
    return selected
def _build_flash_attn_wheel_url(env: dict[str, str]) -> str | None:
    """Build the candidate prebuilt flash-attn wheel URL for *env*.

    Thin wrapper over ``backend.utils.wheel_utils.flash_attn_wheel_url``;
    returns ``None`` when no URL can be constructed for this environment.
    """
    candidate_url = flash_attn_wheel_url(env)
    return candidate_url
def _print_optional_install_failure(
    label: str,
    result: subprocess.CompletedProcess[str],
) -> None:
    """Warn that an optional install step failed and echo its captured output.

    Emits a warning step line with the exit code, then prints the stripped
    stdout of *result* when any was captured.
    """
    message = f"{label} failed (exit code {result.returncode})"
    _step("warning", message, _cyan)
    captured = result.stdout
    if captured:
        print(captured.strip())
def _flash_attn_install_disabled() -> bool:
return os.getenv("UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL") == "1"
def _ensure_flash_attn() -> None:
    """Best-effort install of a prebuilt flash-attn wheel (Linux only).

    No-ops when torch is skipped (NO_TORCH), on Windows/macOS, when the
    UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL opt-out is set, or when
    ``flash_attn`` already imports in a subprocess. Otherwise probes the
    torch wheel environment, builds a candidate wheel URL, and tries each
    installer attempt yielded by ``install_wheel``. flash-attn is optional,
    so every failure path only emits a warning and returns.
    """
    if NO_TORCH or IS_WINDOWS or IS_MACOS:
        return
    if _flash_attn_install_disabled():
        return
    # Probe in a subprocess so a broken flash_attn cannot crash this process;
    # exit code 0 means it is already importable and there is nothing to do.
    if (
        subprocess.run(
            [sys.executable, "-c", "import flash_attn"],
            stdout = subprocess.DEVNULL,
            stderr = subprocess.DEVNULL,
        ).returncode
        == 0
    ):
        return
    env = probe_torch_wheel_env()
    # No probe result -> cannot construct a wheel URL at all.
    wheel_url = _build_flash_attn_wheel_url(env) if env else None
    if wheel_url and url_exists(wheel_url):
        # install_wheel yields (installer name, CompletedProcess) attempts;
        # stop at the first success, warn once per failed attempt.
        for installer, wheel_result in install_wheel(
            wheel_url,
            python_executable = sys.executable,
            use_uv = USE_UV,
            uv_needs_system = UV_NEEDS_SYSTEM,
        ):
            if wheel_result.returncode == 0:
                return
            _print_optional_install_failure(
                f"Installing flash-attn prebuilt wheel with {installer}",
                wheel_result,
            )
        # Every installer attempt failed -- proceed without flash-attn.
        _step("warning", "Continuing without flash-attn", _cyan)
        return
    # Distinguish "no URL could be built for this env" from
    # "URL was built but is not published at the index".
    if wheel_url is None:
        _step("warning", "No compatible flash-attn prebuilt wheel found", _cyan)
    else:
        _step("warning", "No published flash-attn prebuilt wheel found", _cyan)
# -- uv bootstrap ------------------------------------------------------
USE_UV = False # Set by _bootstrap_uv() at the start of install_python_stack()
@ -762,10 +829,8 @@ def install_python_stack() -> int:
base_total = 10 if IS_WINDOWS else 11
if IS_MACOS:
base_total -= 1 # triton step is skipped on macOS
# ROCm torch check steps (Linux only, non-macOS, non-no-torch):
# one early check (step 2b) and one final repair (step 13).
if not IS_WINDOWS and not IS_MACOS and not NO_TORCH:
base_total += 2
base_total += 3
_TOTAL = (base_total - 1) if skip_base else base_total
# 1. Try to use uv for faster installs (must happen before pip upgrade
@ -979,6 +1044,10 @@ def install_python_stack() -> int:
constrain = False,
)
if not IS_WINDOWS and not IS_MACOS and not NO_TORCH:
_progress("flash-attn")
_ensure_flash_attn()
# # 6. Patch: override llama_cpp.py with fix from unsloth-zoo feature/llama-cpp-windows-support branch
# patch_package_file(
# "unsloth-zoo",