mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-02 05:10:32 +00:00
[Studio] Install flash attn at setup time for linux (#4979)
* [Studio] Install flash attn at setup time for linux * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup changes Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Test cases * wheel_utils: narrow url_exists exceptions and log at debug level --------- Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com> Co-authored-by: Roland Tannous <rolandtannous@gravityq.ai>
This commit is contained in:
parent
dccc0ebada
commit
da78c6be71
6 changed files with 815 additions and 207 deletions
|
|
@ -21,6 +21,14 @@ import tempfile
|
|||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
from backend.utils.wheel_utils import (
|
||||
flash_attn_package_version,
|
||||
flash_attn_wheel_url,
|
||||
install_wheel,
|
||||
probe_torch_wheel_env,
|
||||
url_exists,
|
||||
)
|
||||
|
||||
IS_WINDOWS = sys.platform == "win32"
|
||||
IS_MACOS = sys.platform == "darwin"
|
||||
IS_MAC_INTEL = IS_MACOS and platform.machine() == "x86_64"
|
||||
|
|
@ -368,7 +376,6 @@ NO_TORCH = _infer_no_torch()
|
|||
VERBOSE: bool = os.environ.get("UNSLOTH_VERBOSE", "0") == "1"
|
||||
|
||||
# Progress bar state -- updated by _progress() as each install step runs.
|
||||
# _TOTAL counts: pip-upgrade + 7 shared steps + triton (non-Windows) + local-plugin + finalize
|
||||
# Update _TOTAL here if you add or remove install steps in install_python_stack().
|
||||
_STEP: int = 0
|
||||
_TOTAL: int = 0 # set at runtime in install_python_stack() based on platform
|
||||
|
|
@ -535,6 +542,66 @@ NO_TORCH_SKIP_PACKAGES = {
|
|||
"transformers-cfg",
|
||||
}
|
||||
|
||||
|
||||
def _select_flash_attn_version(torch_mm: str) -> str | None:
    """Resolve the flash-attn package version for a torch major.minor string.

    Thin wrapper over ``flash_attn_package_version`` so this module has a
    single local call site for the lookup.  Returns ``None`` when no
    matching version is known.
    """
    selected = flash_attn_package_version(torch_mm)
    return selected
|
||||
|
||||
|
||||
def _build_flash_attn_wheel_url(env: dict[str, str]) -> str | None:
    """Build the prebuilt flash-attn wheel URL for the probed torch env.

    Thin wrapper over ``flash_attn_wheel_url``; returns ``None`` when no
    wheel URL can be constructed for *env*.
    """
    url = flash_attn_wheel_url(env)
    return url
|
||||
|
||||
|
||||
def _print_optional_install_failure(
    label: str, result: subprocess.CompletedProcess[str]
) -> None:
    """Report a failed optional-package install without aborting setup.

    Emits a warning step with the exit code, then echoes the captured
    stdout (if any) so the user can see why the install failed.
    """
    _step("warning", f"{label} failed (exit code {result.returncode})", _cyan)
    captured = result.stdout
    if not captured:
        return
    print(captured.strip())
|
||||
|
||||
|
||||
def _flash_attn_install_disabled() -> bool:
|
||||
return os.getenv("UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL") == "1"
|
||||
|
||||
|
||||
def _ensure_flash_attn() -> None:
    """Best-effort install of a prebuilt flash-attn wheel (Linux only).

    Never raises on failure: every unsuccessful path emits a warning step
    and returns, so setup continues without flash-attn.
    """
    # flash-attn is only attempted on Linux with torch present; Windows and
    # macOS (and no-torch installs) skip this step entirely.
    if NO_TORCH or IS_WINDOWS or IS_MACOS:
        return
    # Explicit user opt-out via UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL=1.
    if _flash_attn_install_disabled():
        return
    # Probe in a subprocess so a broken flash_attn install (e.g. ABI
    # mismatch) cannot crash or pollute this interpreter.  Exit code 0
    # means it already imports cleanly — nothing to do.
    if (
        subprocess.run(
            [sys.executable, "-c", "import flash_attn"],
            stdout = subprocess.DEVNULL,
            stderr = subprocess.DEVNULL,
        ).returncode
        == 0
    ):
        return

    # Probe the torch/CUDA environment, then derive a candidate wheel URL
    # from it.  Either step can yield nothing (unsupported combination).
    env = probe_torch_wheel_env()
    wheel_url = _build_flash_attn_wheel_url(env) if env else None
    if wheel_url and url_exists(wheel_url):
        # install_wheel yields (installer, result) pairs — presumably one
        # attempt per available installer (uv, then pip); verify against
        # backend.utils.wheel_utils.  Stop at the first success.
        for installer, wheel_result in install_wheel(
            wheel_url,
            python_executable = sys.executable,
            use_uv = USE_UV,
            uv_needs_system = UV_NEEDS_SYSTEM,
        ):
            if wheel_result.returncode == 0:
                return
            _print_optional_install_failure(
                f"Installing flash-attn prebuilt wheel with {installer}",
                wheel_result,
            )
        # All installers failed; flash-attn stays optional.
        _step("warning", "Continuing without flash-attn", _cyan)
        return

    # Distinguish "we could not build a URL at all" from "we built one but
    # the remote wheel does not exist" in the warning text.
    if wheel_url is None:
        _step("warning", "No compatible flash-attn prebuilt wheel found", _cyan)
    else:
        _step("warning", "No published flash-attn prebuilt wheel found", _cyan)
|
||||
|
||||
|
||||
# -- uv bootstrap ------------------------------------------------------
|
||||
|
||||
USE_UV = False # Set by _bootstrap_uv() at the start of install_python_stack()
|
||||
|
|
@ -762,10 +829,8 @@ def install_python_stack() -> int:
|
|||
base_total = 10 if IS_WINDOWS else 11
|
||||
if IS_MACOS:
|
||||
base_total -= 1 # triton step is skipped on macOS
|
||||
# ROCm torch check steps (Linux only, non-macOS, non-no-torch):
|
||||
# one early check (step 2b) and one final repair (step 13).
|
||||
if not IS_WINDOWS and not IS_MACOS and not NO_TORCH:
|
||||
base_total += 2
|
||||
base_total += 3
|
||||
_TOTAL = (base_total - 1) if skip_base else base_total
|
||||
|
||||
# 1. Try to use uv for faster installs (must happen before pip upgrade
|
||||
|
|
@ -979,6 +1044,10 @@ def install_python_stack() -> int:
|
|||
constrain = False,
|
||||
)
|
||||
|
||||
if not IS_WINDOWS and not IS_MACOS and not NO_TORCH:
|
||||
_progress("flash-attn")
|
||||
_ensure_flash_attn()
|
||||
|
||||
# # 6. Patch: override llama_cpp.py with fix from unsloth-zoo feature/llama-cpp-windows-support branch
|
||||
# patch_package_file(
|
||||
# "unsloth-zoo",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue