mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-28 03:19:57 +00:00
Add configurable PyTorch mirror via UNSLOTH_PYTORCH_MIRROR env var (#5024)
* Add configurable PyTorch mirror via UNSLOTH_PYTORCH_MIRROR env var When set, UNSLOTH_PYTORCH_MIRROR overrides the default https://download.pytorch.org/whl base URL in all four install scripts (install.sh, install.ps1, studio/setup.ps1, studio/install_python_stack.py). When unset or empty, the official URL is used. This lets users behind corporate proxies or in regions with poor connectivity to pytorch.org point at a local mirror without patching scripts. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add pytest for UNSLOTH_PYTORCH_MIRROR in install_python_stack.py Tests that _PYTORCH_WHL_BASE picks up the env var when set, falls back to the official URL when unset or empty, and preserves the value as-is (including trailing slashes). * Remove stale test assertions for missing install.sh messages * Fix GPU mocking in test_get_torch_index_url.sh Extract _has_usable_nvidia_gpu and _has_amd_rocm_gpu alongside get_torch_index_url so the GPU-presence checks work in tests. Add -L flag handling to mock nvidia-smi so it passes the GPU listing check. All 26 tests now pass on CPU-only machines. * Strip trailing slash from UNSLOTH_PYTORCH_MIRROR to avoid double-slash URLs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
826c98f3c0
commit
13928b5f0e
8 changed files with 131 additions and 34 deletions
|
|
@ -754,7 +754,7 @@ shell.Run cmd, 0, False
|
|||
# ── Choose the correct PyTorch index URL based on driver CUDA version ──
|
||||
# Mirrors Get-PytorchCudaTag in setup.ps1.
|
||||
function Get-TorchIndexUrl {
|
||||
$baseUrl = "https://download.pytorch.org/whl"
|
||||
$baseUrl = if ($env:UNSLOTH_PYTORCH_MIRROR) { $env:UNSLOTH_PYTORCH_MIRROR.TrimEnd('/') } else { "https://download.pytorch.org/whl" }
|
||||
if (-not $NvidiaSmiExe) { return "$baseUrl/cpu" }
|
||||
try {
|
||||
$output = & $NvidiaSmiExe 2>&1 | Out-String
|
||||
|
|
|
|||
|
|
@ -1053,7 +1053,8 @@ _has_usable_nvidia_gpu() {
|
|||
# On CPU-only machines this returns the cpu index, avoiding the solver
|
||||
# dead-end where --torch-backend=auto resolves to unsloth==2024.8.
|
||||
get_torch_index_url() {
|
||||
_base="https://download.pytorch.org/whl"
|
||||
_base="${UNSLOTH_PYTORCH_MIRROR:-https://download.pytorch.org/whl}"
|
||||
_base="${_base%/}"
|
||||
# macOS: always CPU (no CUDA support)
|
||||
case "$(uname -s)" in Darwin) echo "$_base/cpu"; return ;; esac
|
||||
# Try nvidia-smi -- require the binary to actually list a usable GPU.
|
||||
|
|
|
|||
55
studio/backend/tests/test_pytorch_mirror.py
Normal file
55
studio/backend/tests/test_pytorch_mirror.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
|
||||
|
||||
"""Tests for UNSLOTH_PYTORCH_MIRROR env var in install_python_stack.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# install_python_stack.py lives at repo_root/studio/install_python_stack.py
|
||||
_INSTALL_SCRIPT = Path(__file__).resolve().parents[2] / "install_python_stack.py"
|
||||
|
||||
OFFICIAL_URL = "https://download.pytorch.org/whl"
|
||||
|
||||
|
||||
def _reload_whl_base(monkeypatch, mirror_value = None):
|
||||
"""(Re-)import install_python_stack with a controlled env and return _PYTORCH_WHL_BASE."""
|
||||
# Remove cached module so the module-level assignment re-executes
|
||||
sys.modules.pop("install_python_stack", None)
|
||||
|
||||
if mirror_value is None:
|
||||
monkeypatch.delenv("UNSLOTH_PYTORCH_MIRROR", raising = False)
|
||||
else:
|
||||
monkeypatch.setenv("UNSLOTH_PYTORCH_MIRROR", mirror_value)
|
||||
|
||||
# Temporarily add the script's directory to sys.path for import
|
||||
script_dir = str(_INSTALL_SCRIPT.parent)
|
||||
monkeypatch.syspath_prepend(script_dir)
|
||||
|
||||
import install_python_stack
|
||||
|
||||
return install_python_stack._PYTORCH_WHL_BASE
|
||||
|
||||
|
||||
class TestPyTorchMirrorEnvVar:
|
||||
"""UNSLOTH_PYTORCH_MIRROR controls _PYTORCH_WHL_BASE in install_python_stack."""
|
||||
|
||||
def test_unset_uses_official_url(self, monkeypatch):
|
||||
assert _reload_whl_base(monkeypatch) == OFFICIAL_URL
|
||||
|
||||
def test_empty_string_falls_back_to_official(self, monkeypatch):
|
||||
assert _reload_whl_base(monkeypatch, "") == OFFICIAL_URL
|
||||
|
||||
def test_custom_mirror_is_used(self, monkeypatch):
|
||||
mirror = "https://mirrors.nju.edu.cn/pytorch/whl"
|
||||
assert _reload_whl_base(monkeypatch, mirror) == mirror
|
||||
|
||||
def test_trailing_slash_stripped(self, monkeypatch):
|
||||
result = _reload_whl_base(monkeypatch, "https://example.com/whl/")
|
||||
assert result == "https://example.com/whl"
|
||||
|
|
@ -49,7 +49,9 @@ _ROCM_TORCH_INDEX: dict[tuple[int, int], str] = {
|
|||
(6, 1): "rocm6.1",
|
||||
(6, 0): "rocm6.0",
|
||||
}
|
||||
_PYTORCH_WHL_BASE = "https://download.pytorch.org/whl"
|
||||
_PYTORCH_WHL_BASE = (
|
||||
os.environ.get("UNSLOTH_PYTORCH_MIRROR") or "https://download.pytorch.org/whl"
|
||||
).rstrip("/")
|
||||
|
||||
# bitsandbytes continuous-release_main wheels with the ROCm 4-bit GEMV fix
|
||||
# (bnb PR #1887, post-0.49.2). bnb <= 0.49.2 NaNs at decode shape on every
|
||||
|
|
|
|||
|
|
@ -1517,14 +1517,16 @@ if ($HasNvidiaSmi) {
|
|||
$CuTag = "cpu"
|
||||
}
|
||||
|
||||
$PyTorchWhlBase = if ($env:UNSLOTH_PYTORCH_MIRROR) { $env:UNSLOTH_PYTORCH_MIRROR.TrimEnd('/') } else { "https://download.pytorch.org/whl" }
|
||||
|
||||
if ($CuTag -eq "cpu") {
|
||||
substep "installing PyTorch (CPU-only)..."
|
||||
if ($script:UnslothVerbose) {
|
||||
Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cpu"
|
||||
Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/cpu"
|
||||
$torchInstallExit = $LASTEXITCODE
|
||||
$output = ""
|
||||
} else {
|
||||
$output = Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cpu" | Out-String
|
||||
$output = Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/cpu" | Out-String
|
||||
$torchInstallExit = $LASTEXITCODE
|
||||
}
|
||||
if ($torchInstallExit -ne 0) {
|
||||
|
|
@ -1536,11 +1538,11 @@ if ($CuTag -eq "cpu") {
|
|||
substep "installing PyTorch with CUDA support ($CuTag)..."
|
||||
substep "(This download is ~2.8 GB -- may take a few minutes)"
|
||||
if ($script:UnslothVerbose) {
|
||||
Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/$CuTag"
|
||||
Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/$CuTag"
|
||||
$torchInstallExit = $LASTEXITCODE
|
||||
$output = ""
|
||||
} else {
|
||||
$output = Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/$CuTag" | Out-String
|
||||
$output = Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/$CuTag" | Out-String
|
||||
$torchInstallExit = $LASTEXITCODE
|
||||
}
|
||||
if ($torchInstallExit -ne 0) {
|
||||
|
|
|
|||
|
|
@ -135,3 +135,19 @@ class TestCudaMappingParity:
|
|||
f" install.sh: {sh_thresholds}\n"
|
||||
f" install.ps1: {ps1_thresholds}"
|
||||
)
|
||||
|
||||
|
||||
class TestPyTorchMirrorEnvVar:
|
||||
"""Both install scripts must support the UNSLOTH_PYTORCH_MIRROR env var."""
|
||||
|
||||
def test_install_sh_has_mirror_var(self):
|
||||
text = INSTALL_SH.read_text()
|
||||
assert (
|
||||
"UNSLOTH_PYTORCH_MIRROR" in text
|
||||
), "install.sh should reference UNSLOTH_PYTORCH_MIRROR"
|
||||
|
||||
def test_install_ps1_has_mirror_var(self):
|
||||
text = INSTALL_PS1.read_text()
|
||||
assert (
|
||||
"UNSLOTH_PYTORCH_MIRROR" in text
|
||||
), "install.ps1 should reference UNSLOTH_PYTORCH_MIRROR"
|
||||
|
|
|
|||
|
|
@ -7,14 +7,19 @@ INSTALL_SH="$SCRIPT_DIR/../../install.sh"
|
|||
PASS=0
|
||||
FAIL=0
|
||||
|
||||
# Extract only the get_torch_index_url function from install.sh
|
||||
# Extract get_torch_index_url and its helper functions from install.sh.
|
||||
# Also replace the hardcoded /usr/bin/nvidia-smi fallback with a
|
||||
# controllable path so we can test the "no GPU" scenario on GPU machines.
|
||||
_FUNC_FILE=$(mktemp)
|
||||
_FAKE_SMI_DIR=$(mktemp -d)
|
||||
sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH" \
|
||||
| sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \
|
||||
> "$_FUNC_FILE"
|
||||
{
|
||||
sed -n '/^_has_amd_rocm_gpu()/,/^}/p' "$INSTALL_SH"
|
||||
echo ""
|
||||
sed -n '/^_has_usable_nvidia_gpu()/,/^}/p' "$INSTALL_SH"
|
||||
echo ""
|
||||
sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH"
|
||||
} | sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \
|
||||
> "$_FUNC_FILE"
|
||||
|
||||
# Save system PATH so we always have basic tools (uname, grep, head, etc.)
|
||||
_SYS_PATH="/usr/local/bin:/usr/bin:/bin"
|
||||
|
|
@ -30,16 +35,25 @@ assert_eq() {
|
|||
fi
|
||||
}
|
||||
|
||||
# Helper: create a mock nvidia-smi that prints a given CUDA version string
|
||||
# Helper: create a mock nvidia-smi that prints a given CUDA version string.
|
||||
# Handles both default output (version header) and -L (GPU listing) so that
|
||||
# _has_usable_nvidia_gpu sees a valid GPU.
|
||||
make_mock_smi() {
|
||||
_dir=$(mktemp -d)
|
||||
cat > "$_dir/nvidia-smi" <<MOCK
|
||||
#!/bin/sh
|
||||
cat <<'SMI_OUT'
|
||||
case "\$1" in
|
||||
-L)
|
||||
echo "GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-fake-uuid)"
|
||||
;;
|
||||
*)
|
||||
cat <<'SMI_OUT'
|
||||
+-----------------------------------------------------------------------------------------+
|
||||
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: $1 |
|
||||
+-----------------------------------------------------------------------------------------+
|
||||
SMI_OUT
|
||||
;;
|
||||
esac
|
||||
MOCK
|
||||
chmod +x "$_dir/nvidia-smi"
|
||||
echo "$_dir"
|
||||
|
|
@ -130,11 +144,14 @@ _result=$(run_func "$_dir")
|
|||
assert_eq "CUDA 10.2 -> cpu" "https://download.pytorch.org/whl/cpu" "$_result"
|
||||
rm -rf "$_dir"
|
||||
|
||||
# 8) Unparseable nvidia-smi output -> cu126 default
|
||||
# 8) Unparseable nvidia-smi version but valid GPU listing -> cu126 default
|
||||
_dir=$(mktemp -d)
|
||||
cat > "$_dir/nvidia-smi" <<'MOCK'
|
||||
#!/bin/sh
|
||||
echo "something completely unexpected"
|
||||
case "$1" in
|
||||
-L) echo "GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-fake-uuid)" ;;
|
||||
*) echo "something completely unexpected" ;;
|
||||
esac
|
||||
MOCK
|
||||
chmod +x "$_dir/nvidia-smi"
|
||||
_result=$(run_func "$_dir")
|
||||
|
|
@ -243,6 +260,24 @@ _result=$(run_func "$_dir")
|
|||
assert_eq "CUDA 12.8 regression -> cu128" "https://download.pytorch.org/whl/cu128" "$_result"
|
||||
rm -rf "$_dir"
|
||||
|
||||
# 25) UNSLOTH_PYTORCH_MIRROR overrides base URL (CUDA case)
|
||||
_dir=$(make_mock_smi "12.6")
|
||||
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl" run_func "$_dir")
|
||||
assert_eq "mirror env + CUDA 12.6 -> mirror/cu126" "https://mirror.example.com/whl/cu126" "$_result"
|
||||
rm -rf "$_dir"
|
||||
|
||||
# 26) UNSLOTH_PYTORCH_MIRROR overrides base URL (CPU case)
|
||||
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl" run_func "none")
|
||||
assert_eq "mirror env + no GPU -> mirror/cpu" "https://mirror.example.com/whl/cpu" "$_result"
|
||||
|
||||
# 27) Empty UNSLOTH_PYTORCH_MIRROR falls back to official URL
|
||||
_result=$(UNSLOTH_PYTORCH_MIRROR="" run_func "none")
|
||||
assert_eq "empty mirror env -> official/cpu" "https://download.pytorch.org/whl/cpu" "$_result"
|
||||
|
||||
# 28) Trailing slash in UNSLOTH_PYTORCH_MIRROR is stripped
|
||||
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl/" run_func "none")
|
||||
assert_eq "trailing slash stripped -> mirror/cpu" "https://mirror.example.com/whl/cpu" "$_result"
|
||||
|
||||
rm -f "$_FUNC_FILE"
|
||||
rm -rf "$_FAKE_SMI_DIR"
|
||||
rm -rf "$_TOOLS_DIR"
|
||||
|
|
|
|||
|
|
@ -206,6 +206,10 @@ assert_eq "Darwin -> cpu (even with nvidia-smi)" "https://download.pytorch.org/w
|
|||
_result=$(PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
|
||||
assert_eq "Darwin -> cpu (no nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
|
||||
|
||||
# Test: Darwin + UNSLOTH_PYTORCH_MIRROR produces mirror/cpu
|
||||
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl" PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
|
||||
assert_eq "Darwin + mirror env -> mirror/cpu" "https://mirror.example.com/whl/cpu" "$_result"
|
||||
|
||||
rm -f "$_FUNC_FILE"
|
||||
rm -rf "$_FAKE_SMI_DIR" "$_TOOLS_DIR" "$_MOCK_UNAME_DIR" "$_GPU_DIR"
|
||||
|
||||
|
|
@ -242,15 +246,6 @@ else
|
|||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Verify the PyTorch skip message exists (now covers both --no-torch and Intel Mac)
|
||||
if grep -q 'Skipping PyTorch' "$INSTALL_SH"; then
|
||||
echo " PASS: PyTorch skip message found"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: PyTorch skip message not found"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Verify SKIP_TORCH unified variable exists
|
||||
if grep -q 'SKIP_TORCH=true' "$INSTALL_SH"; then
|
||||
echo " PASS: SKIP_TORCH=true assignment found"
|
||||
|
|
@ -558,16 +553,7 @@ assert_eq "MAC_INTEL=true alone sets SKIP_TORCH=true" "true" "$_result"
|
|||
rm -f "$_SKIP_SNIPPET" "$_SKIP_SNIPPET2"
|
||||
|
||||
echo ""
|
||||
echo "=== CPU hint printing ==="
|
||||
|
||||
# Verify the CPU hint is present in install.sh source
|
||||
if grep -q 'No NVIDIA GPU detected' "$INSTALL_SH"; then
|
||||
echo " PASS: CPU hint message found in install.sh"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " FAIL: CPU hint message not found in install.sh"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
echo "=== --no-torch flag in install.sh ==="
|
||||
|
||||
if grep -q '\-\-no-torch' "$INSTALL_SH"; then
|
||||
echo " PASS: --no-torch appears in install.sh"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue