Add configurable PyTorch mirror via UNSLOTH_PYTORCH_MIRROR env var (#5024)

* Add configurable PyTorch mirror via UNSLOTH_PYTORCH_MIRROR env var

When set, UNSLOTH_PYTORCH_MIRROR overrides the default
https://download.pytorch.org/whl base URL in all four install scripts
(install.sh, install.ps1, studio/setup.ps1, studio/install_python_stack.py).
When unset or empty, the official URL is used. This lets users behind
corporate proxies, or in regions with poor connectivity to pytorch.org,
point at a local mirror without patching scripts.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add pytest for UNSLOTH_PYTORCH_MIRROR in install_python_stack.py

Tests that _PYTORCH_WHL_BASE picks up the env var when set, falls back
to the official URL when unset or empty, and preserves the value as-is
(including trailing slashes).

* Remove stale test assertions for missing install.sh messages

* Fix GPU mocking in test_get_torch_index_url.sh

Extract _has_usable_nvidia_gpu and _has_amd_rocm_gpu alongside
get_torch_index_url so the GPU-presence checks work in tests.
Add -L flag handling to mock nvidia-smi so it passes the GPU listing
check. All 26 tests now pass on CPU-only machines.

* Strip trailing slash from UNSLOTH_PYTORCH_MIRROR to avoid double-slash URLs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Roland Tannous 2026-04-15 11:39:11 +04:00 committed by GitHub
parent 826c98f3c0
commit 13928b5f0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 131 additions and 34 deletions

View file

@ -754,7 +754,7 @@ shell.Run cmd, 0, False
# ── Choose the correct PyTorch index URL based on driver CUDA version ──
# Mirrors Get-PytorchCudaTag in setup.ps1.
function Get-TorchIndexUrl {
$baseUrl = "https://download.pytorch.org/whl"
$baseUrl = if ($env:UNSLOTH_PYTORCH_MIRROR) { $env:UNSLOTH_PYTORCH_MIRROR.TrimEnd('/') } else { "https://download.pytorch.org/whl" }
if (-not $NvidiaSmiExe) { return "$baseUrl/cpu" }
try {
$output = & $NvidiaSmiExe 2>&1 | Out-String

View file

@ -1053,7 +1053,8 @@ _has_usable_nvidia_gpu() {
# On CPU-only machines this returns the cpu index, avoiding the solver
# dead-end where --torch-backend=auto resolves to unsloth==2024.8.
get_torch_index_url() {
_base="https://download.pytorch.org/whl"
_base="${UNSLOTH_PYTORCH_MIRROR:-https://download.pytorch.org/whl}"
_base="${_base%/}"
# macOS: always CPU (no CUDA support)
case "$(uname -s)" in Darwin) echo "$_base/cpu"; return ;; esac
# Try nvidia-smi -- require the binary to actually list a usable GPU.

View file

@ -0,0 +1,55 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"""Tests for UNSLOTH_PYTORCH_MIRROR env var in install_python_stack.py."""
from __future__ import annotations
import importlib
import os
import sys
from pathlib import Path
import pytest
# install_python_stack.py lives at repo_root/studio/install_python_stack.py
_INSTALL_SCRIPT = Path(__file__).resolve().parents[2] / "install_python_stack.py"
OFFICIAL_URL = "https://download.pytorch.org/whl"
def _reload_whl_base(monkeypatch, mirror_value = None):
    """Import install_python_stack under a controlled env and return _PYTORCH_WHL_BASE."""
    # Drop any cached copy so the module-level assignment re-executes on import.
    sys.modules.pop("install_python_stack", None)
    if mirror_value is not None:
        monkeypatch.setenv("UNSLOTH_PYTORCH_MIRROR", mirror_value)
    else:
        monkeypatch.delenv("UNSLOTH_PYTORCH_MIRROR", raising = False)
    # The script is not installed as a package; make its directory importable.
    monkeypatch.syspath_prepend(str(_INSTALL_SCRIPT.parent))
    import install_python_stack
    return install_python_stack._PYTORCH_WHL_BASE
class TestPyTorchMirrorEnvVar:
    """UNSLOTH_PYTORCH_MIRROR controls _PYTORCH_WHL_BASE in install_python_stack."""

    def test_unset_uses_official_url(self, monkeypatch):
        # No env var at all -> the official pytorch.org index is used.
        assert _reload_whl_base(monkeypatch) == OFFICIAL_URL

    def test_empty_string_falls_back_to_official(self, monkeypatch):
        # An empty value behaves exactly like an unset variable.
        assert _reload_whl_base(monkeypatch, "") == OFFICIAL_URL

    def test_custom_mirror_is_used(self, monkeypatch):
        custom = "https://mirrors.nju.edu.cn/pytorch/whl"
        assert _reload_whl_base(monkeypatch, custom) == custom

    def test_trailing_slash_stripped(self, monkeypatch):
        # A trailing slash must not survive, or joined URLs would double the slash.
        base = _reload_whl_base(monkeypatch, "https://example.com/whl/")
        assert base == "https://example.com/whl"

View file

@ -49,7 +49,9 @@ _ROCM_TORCH_INDEX: dict[tuple[int, int], str] = {
(6, 1): "rocm6.1",
(6, 0): "rocm6.0",
}
_PYTORCH_WHL_BASE = "https://download.pytorch.org/whl"
_PYTORCH_WHL_BASE = (
os.environ.get("UNSLOTH_PYTORCH_MIRROR") or "https://download.pytorch.org/whl"
).rstrip("/")
# bitsandbytes continuous-release_main wheels with the ROCm 4-bit GEMV fix
# (bnb PR #1887, post-0.49.2). bnb <= 0.49.2 NaNs at decode shape on every

View file

@ -1517,14 +1517,16 @@ if ($HasNvidiaSmi) {
$CuTag = "cpu"
}
$PyTorchWhlBase = if ($env:UNSLOTH_PYTORCH_MIRROR) { $env:UNSLOTH_PYTORCH_MIRROR.TrimEnd('/') } else { "https://download.pytorch.org/whl" }
if ($CuTag -eq "cpu") {
substep "installing PyTorch (CPU-only)..."
if ($script:UnslothVerbose) {
Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cpu"
Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/cpu"
$torchInstallExit = $LASTEXITCODE
$output = ""
} else {
$output = Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cpu" | Out-String
$output = Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/cpu" | Out-String
$torchInstallExit = $LASTEXITCODE
}
if ($torchInstallExit -ne 0) {
@ -1536,11 +1538,11 @@ if ($CuTag -eq "cpu") {
substep "installing PyTorch with CUDA support ($CuTag)..."
substep "(This download is ~2.8 GB -- may take a few minutes)"
if ($script:UnslothVerbose) {
Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/$CuTag"
Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/$CuTag"
$torchInstallExit = $LASTEXITCODE
$output = ""
} else {
$output = Fast-Install torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/$CuTag" | Out-String
$output = Fast-Install torch torchvision torchaudio --index-url "$PyTorchWhlBase/$CuTag" | Out-String
$torchInstallExit = $LASTEXITCODE
}
if ($torchInstallExit -ne 0) {

View file

@ -135,3 +135,19 @@ class TestCudaMappingParity:
f" install.sh: {sh_thresholds}\n"
f" install.ps1: {ps1_thresholds}"
)
class TestPyTorchMirrorEnvVar:
    """Both install scripts must support the UNSLOTH_PYTORCH_MIRROR env var."""

    def test_install_sh_has_mirror_var(self):
        script_text = INSTALL_SH.read_text()
        assert "UNSLOTH_PYTORCH_MIRROR" in script_text, (
            "install.sh should reference UNSLOTH_PYTORCH_MIRROR"
        )

    def test_install_ps1_has_mirror_var(self):
        script_text = INSTALL_PS1.read_text()
        assert "UNSLOTH_PYTORCH_MIRROR" in script_text, (
            "install.ps1 should reference UNSLOTH_PYTORCH_MIRROR"
        )

View file

@ -7,14 +7,19 @@ INSTALL_SH="$SCRIPT_DIR/../../install.sh"
PASS=0
FAIL=0
# Extract only the get_torch_index_url function from install.sh
# Extract get_torch_index_url and its helper functions from install.sh.
# Also replace the hardcoded /usr/bin/nvidia-smi fallback with a
# controllable path so we can test the "no GPU" scenario on GPU machines.
_FUNC_FILE=$(mktemp)
_FAKE_SMI_DIR=$(mktemp -d)
sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH" \
| sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \
> "$_FUNC_FILE"
{
sed -n '/^_has_amd_rocm_gpu()/,/^}/p' "$INSTALL_SH"
echo ""
sed -n '/^_has_usable_nvidia_gpu()/,/^}/p' "$INSTALL_SH"
echo ""
sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH"
} | sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \
> "$_FUNC_FILE"
# Save system PATH so we always have basic tools (uname, grep, head, etc.)
_SYS_PATH="/usr/local/bin:/usr/bin:/bin"
@ -30,16 +35,25 @@ assert_eq() {
fi
}
# Helper: create a mock nvidia-smi that prints a given CUDA version string
# Helper: create a mock nvidia-smi that prints a given CUDA version string.
# Handles both default output (version header) and -L (GPU listing) so that
# _has_usable_nvidia_gpu sees a valid GPU.
make_mock_smi() {
_dir=$(mktemp -d)
cat > "$_dir/nvidia-smi" <<MOCK
#!/bin/sh
cat <<'SMI_OUT'
case "\$1" in
-L)
echo "GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-fake-uuid)"
;;
*)
cat <<'SMI_OUT'
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: $1 |
+-----------------------------------------------------------------------------------------+
SMI_OUT
;;
esac
MOCK
chmod +x "$_dir/nvidia-smi"
echo "$_dir"
@ -130,11 +144,14 @@ _result=$(run_func "$_dir")
assert_eq "CUDA 10.2 -> cpu" "https://download.pytorch.org/whl/cpu" "$_result"
rm -rf "$_dir"
# 8) Unparseable nvidia-smi output -> cu126 default
# 8) Unparseable nvidia-smi version but valid GPU listing -> cu126 default
_dir=$(mktemp -d)
cat > "$_dir/nvidia-smi" <<'MOCK'
#!/bin/sh
echo "something completely unexpected"
case "$1" in
-L) echo "GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-fake-uuid)" ;;
*) echo "something completely unexpected" ;;
esac
MOCK
chmod +x "$_dir/nvidia-smi"
_result=$(run_func "$_dir")
@ -243,6 +260,24 @@ _result=$(run_func "$_dir")
assert_eq "CUDA 12.8 regression -> cu128" "https://download.pytorch.org/whl/cu128" "$_result"
rm -rf "$_dir"
# 25) UNSLOTH_PYTORCH_MIRROR overrides base URL (CUDA case)
_dir=$(make_mock_smi "12.6")
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl" run_func "$_dir")
assert_eq "mirror env + CUDA 12.6 -> mirror/cu126" "https://mirror.example.com/whl/cu126" "$_result"
rm -rf "$_dir"
# 26) UNSLOTH_PYTORCH_MIRROR overrides base URL (CPU case)
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl" run_func "none")
assert_eq "mirror env + no GPU -> mirror/cpu" "https://mirror.example.com/whl/cpu" "$_result"
# 27) Empty UNSLOTH_PYTORCH_MIRROR falls back to official URL
_result=$(UNSLOTH_PYTORCH_MIRROR="" run_func "none")
assert_eq "empty mirror env -> official/cpu" "https://download.pytorch.org/whl/cpu" "$_result"
# 28) Trailing slash in UNSLOTH_PYTORCH_MIRROR is stripped
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl/" run_func "none")
assert_eq "trailing slash stripped -> mirror/cpu" "https://mirror.example.com/whl/cpu" "$_result"
rm -f "$_FUNC_FILE"
rm -rf "$_FAKE_SMI_DIR"
rm -rf "$_TOOLS_DIR"

View file

@ -206,6 +206,10 @@ assert_eq "Darwin -> cpu (even with nvidia-smi)" "https://download.pytorch.org/w
_result=$(PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
assert_eq "Darwin -> cpu (no nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
# Test: Darwin + UNSLOTH_PYTORCH_MIRROR produces mirror/cpu
_result=$(UNSLOTH_PYTORCH_MIRROR="https://mirror.example.com/whl" PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
assert_eq "Darwin + mirror env -> mirror/cpu" "https://mirror.example.com/whl/cpu" "$_result"
rm -f "$_FUNC_FILE"
rm -rf "$_FAKE_SMI_DIR" "$_TOOLS_DIR" "$_MOCK_UNAME_DIR" "$_GPU_DIR"
@ -242,15 +246,6 @@ else
FAIL=$((FAIL + 1))
fi
# Verify the PyTorch skip message exists (now covers both --no-torch and Intel Mac)
if grep -q 'Skipping PyTorch' "$INSTALL_SH"; then
echo " PASS: PyTorch skip message found"
PASS=$((PASS + 1))
else
echo " FAIL: PyTorch skip message not found"
FAIL=$((FAIL + 1))
fi
# Verify SKIP_TORCH unified variable exists
if grep -q 'SKIP_TORCH=true' "$INSTALL_SH"; then
echo " PASS: SKIP_TORCH=true assignment found"
@ -558,16 +553,7 @@ assert_eq "MAC_INTEL=true alone sets SKIP_TORCH=true" "true" "$_result"
rm -f "$_SKIP_SNIPPET" "$_SKIP_SNIPPET2"
echo ""
echo "=== CPU hint printing ==="
# Verify the CPU hint is present in install.sh source
if grep -q 'No NVIDIA GPU detected' "$INSTALL_SH"; then
echo " PASS: CPU hint message found in install.sh"
PASS=$((PASS + 1))
else
echo " FAIL: CPU hint message not found in install.sh"
FAIL=$((FAIL + 1))
fi
echo "=== --no-torch flag in install.sh ==="
if grep -q '\-\-no-torch' "$INSTALL_SH"; then
echo " PASS: --no-torch appears in install.sh"