tests: add no-torch / Intel Mac test suite (#4646)

* tests: add no-torch / Intel Mac test suite

Add comprehensive test coverage for the no-torch / --no-torch installer
and Studio backend changes introduced in #4624.

Shell tests (tests/sh/test_mac_intel_compat.sh):
- version_ge edge cases (9 tests)
- Architecture detection + Python version resolution (4 tests)
- get_torch_index_url on Darwin (2 tests)
- UNSLOTH_NO_TORCH propagation via SKIP_TORCH (5 tests)
- E2E uv venv creation at Python 3.12 (3 tests)
- E2E torch skip with mock uv shim (4 tests)
- UNSLOTH_NO_TORCH env propagation (4 tests)
- --python override flag parsing + resolution (11 tests)
- --no-torch flag parsing (4 tests)
- SKIP_TORCH unification (3 tests)
- CPU hint printing (2 tests)

Python tests (tests/python/test_no_torch_filtering.py):
- _filter_requirements unit tests with synthetic + real requirements files
- NO_TORCH / IS_MACOS constant parsing
- Subprocess mock of install_python_stack() across platform configs
- install.sh --no-torch flag structural + subprocess tests

Python tests (tests/python/test_studio_import_no_torch.py):
- AST checks for data_collators.py, chat_templates.py, format_conversion.py
- Parametrized venv tests (Python 3.12 + 3.13) for no-torch exec
- Dataclass instantiation without torch
- format_conversion convert functions without torch
- Negative controls (import torch fails, torchao fails)

Python tests (tests/python/test_e2e_no_torch_sandbox.py):
- Before/after import chain tests
- Edge cases (broken torch, fake torch, lazy import)
- Hardware detection without torch
- install.sh logic tests (flag parsing, version resolution)
- install_python_stack filtering tests
- Live server startup tests (opt-in via @server marker)

* fix: address review comments on test suite

- Fix always-true assertion in test_studio_import_no_torch.py (or True)
- Make IS_MACOS test platform-aware instead of hardcoding Linux
- Restore torchvision + torchaudio in server test cleanup (not just torch)
- Include server stderr in skip message for easier debugging

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-03-27 02:33:45 -07:00 committed by GitHub
parent e9ac785346
commit 2ffc8d2cea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 3166 additions and 0 deletions

7
tests/python/conftest.py Normal file
View file

@ -0,0 +1,7 @@
"""Shared pytest configuration for tests/python/."""
def pytest_configure(config):
config.addinivalue_line(
"markers", "server: heavyweight tests requiring studio venv"
)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,753 @@
"""Tests for install_python_stack NO_TORCH / IS_MACOS filtering logic.
Covers:
- _filter_requirements unit tests (synthetic + REAL requirements files)
- NO_TORCH / IS_MACOS / IS_WINDOWS env var parsing
- Subprocess-mock of install_python_stack() to verify overrides/triton/filtering
actually happen (or get skipped) under each platform/config combination
- VCS URL and environment marker edge cases in filtering
"""
from __future__ import annotations
import importlib
import os
import re
import subprocess
import sys
import textwrap
from pathlib import Path
from unittest import mock
import pytest
# Add the studio directory so we can import install_python_stack
STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio"
sys.path.insert(0, str(STUDIO_DIR))
import install_python_stack as ips
# Paths to the REAL requirements files
REQ_ROOT = Path(__file__).resolve().parents[2] / "studio" / "backend" / "requirements"
EXTRAS_TXT = REQ_ROOT / "extras.txt"
EXTRAS_NO_DEPS_TXT = REQ_ROOT / "extras-no-deps.txt"
OVERRIDES_TXT = REQ_ROOT / "overrides.txt"
TRITON_KERNELS_TXT = REQ_ROOT / "triton-kernels.txt"
# ── _filter_requirements unit tests (synthetic) ───────────────────────
class TestFilterRequirements:
"""Verify _filter_requirements correctly removes packages by prefix."""
def _write_req(self, tmp_path: Path, content: str) -> Path:
req = tmp_path / "requirements.txt"
req.write_text(textwrap.dedent(content), encoding = "utf-8")
return req
def test_filters_no_torch_packages(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
torch-stoi==0.1
timm>=1.0
numpy
torchcodec>=0.1
torch-c-dlpack-ext
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
# Only numpy should remain (non-blank lines)
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == ["numpy"], f"Expected only numpy, got: {non_blank}"
def test_empty_file(self, tmp_path):
req = self._write_req(tmp_path, "")
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
content = Path(result).read_text(encoding = "utf-8")
assert content.strip() == ""
def test_comments_preserved(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
# torch-stoi is needed for audio
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
# Comment starts with "#", not "torch-stoi", so it's preserved
assert len(non_blank) == 2
assert non_blank[0].startswith("#")
assert non_blank[1] == "numpy"
def test_version_specifiers_filtered(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
torch-stoi>=0.1.0
timm==1.2.3
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [], f"Expected empty, got: {non_blank}"
def test_prefix_match_catches_extensions(self, tmp_path):
"""Prefix matching catches torch-stoi-extra (correct for pip names)."""
req = self._write_req(
tmp_path,
"""\
torch-stoi-extra
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == ["numpy"]
def test_mixed_case_filtered(self, tmp_path):
"""Package names are lowercased before matching."""
req = self._write_req(
tmp_path,
"""\
Timm>=1.0
TORCH-STOI
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == ["numpy"]
def test_whitespace_and_blank_lines_preserved(self, tmp_path):
req = self._write_req(
tmp_path,
"""\
numpy
pandas
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
content = Path(result).read_text(encoding = "utf-8")
# Blank lines should be preserved (not stripped)
assert "\n\n" in content or content.count("\n") >= 3
def test_stacked_windows_and_no_torch_filters(self, tmp_path):
"""Both WINDOWS_SKIP_PACKAGES and NO_TORCH_SKIP_PACKAGES applied."""
req = self._write_req(
tmp_path,
"""\
open_spiel
triton_kernels
torch-stoi
timm
numpy
""",
)
# First filter Windows packages, then NO_TORCH packages
intermediate = ips._filter_requirements(req, ips.WINDOWS_SKIP_PACKAGES)
result = ips._filter_requirements(
Path(intermediate), ips.NO_TORCH_SKIP_PACKAGES
)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [
"numpy"
], f"Expected only numpy after stacked filters, got: {non_blank}"
def test_vcs_url_with_skip_package_name(self, tmp_path):
"""VCS URLs like git+https://...torch-stoi should also be filtered (startswith matches)."""
req = self._write_req(
tmp_path,
"""\
numpy
torch-stoi @ git+https://github.com/example/torch-stoi.git
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [
"numpy"
], f"VCS URL line should be filtered, got: {non_blank}"
def test_env_marker_line_filtered(self, tmp_path):
"""Package lines with env markers are still filtered by prefix."""
req = self._write_req(
tmp_path,
"""\
timm>=1.0; python_version>="3.10"
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
assert non_blank == [
"numpy"
], f"Env marker line should be filtered, got: {non_blank}"
def test_git_plus_url_not_over_matched(self, tmp_path):
"""A git+ URL whose path contains a skip package name but does NOT start with it."""
req = self._write_req(
tmp_path,
"""\
git+https://github.com/meta-pytorch/OpenEnv.git
numpy
""",
)
result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
lines = Path(result).read_text(encoding = "utf-8").splitlines()
non_blank = [l.strip() for l in lines if l.strip()]
# The git+ URL doesn't start with any skip package, so it is preserved
assert len(non_blank) == 2, f"git+ URL should be preserved, got: {non_blank}"
# ── Real requirements file filtering ──────────────────────────────────
class TestRealRequirementsFiltering:
"""Filter the ACTUAL extras.txt and extras-no-deps.txt with NO_TORCH_SKIP_PACKAGES."""
@pytest.fixture(autouse = True)
def _check_req_files(self):
if not EXTRAS_TXT.is_file():
pytest.skip("extras.txt not found in repo")
if not EXTRAS_NO_DEPS_TXT.is_file():
pytest.skip("extras-no-deps.txt not found in repo")
def _non_blank_non_comment(self, path: Path) -> list[str]:
"""Return non-blank, non-comment lines from a requirements file."""
lines = path.read_text(encoding = "utf-8").splitlines()
return [l.strip() for l in lines if l.strip() and not l.strip().startswith("#")]
def test_extras_txt_torch_packages_removed(self):
"""extras.txt: all NO_TORCH_SKIP_PACKAGES must be removed, everything else preserved."""
result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
filtered = self._non_blank_non_comment(Path(result))
original = self._non_blank_non_comment(EXTRAS_TXT)
# These must be gone
for pkg in ["torch-stoi", "timm", "openai-whisper", "transformers-cfg"]:
assert not any(
l.lower().startswith(pkg) for l in filtered
), f"{pkg} should be removed from extras.txt"
# Everything else must remain
expected = [
l
for l in original
if not any(
l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
)
]
assert filtered == expected, (
f"Filtered extras.txt should match expected.\n"
f"Missing: {set(expected) - set(filtered)}\n"
f"Extra: {set(filtered) - set(expected)}"
)
def test_extras_no_deps_txt_torchcodec_and_dlpack_removed(self):
"""extras-no-deps.txt: torchcodec and torch-c-dlpack-ext must be removed."""
result = ips._filter_requirements(
EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
)
filtered = self._non_blank_non_comment(Path(result))
original = self._non_blank_non_comment(EXTRAS_NO_DEPS_TXT)
for pkg in ["torchcodec", "torch-c-dlpack-ext"]:
assert not any(
l.lower().startswith(pkg) for l in filtered
), f"{pkg} should be removed from extras-no-deps.txt"
expected = [
l
for l in original
if not any(
l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
)
]
assert filtered == expected
def test_extras_txt_most_packages_preserved(self):
"""Ensure a representative set of non-torch packages survive filtering."""
result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
filtered_text = Path(result).read_text(encoding = "utf-8").lower()
must_survive = ["scikit-learn", "loguru", "tiktoken", "einops", "tabulate"]
for pkg in must_survive:
if pkg in EXTRAS_TXT.read_text(encoding = "utf-8").lower():
assert pkg in filtered_text, f"{pkg} should survive NO_TORCH filtering"
def test_extras_no_deps_txt_trl_preserved(self):
"""trl should survive NO_TORCH filtering in extras-no-deps.txt."""
result = ips._filter_requirements(
EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
)
filtered_text = Path(result).read_text(encoding = "utf-8").lower()
assert "trl" in filtered_text, "trl should survive NO_TORCH filtering"
# ── NO_TORCH constant tests ──────────────────────────────────────────
class TestNoTorchConstant:
"""Verify NO_TORCH is derived correctly from UNSLOTH_NO_TORCH env var."""
def _reimport_no_torch(self) -> bool:
return os.environ.get("UNSLOTH_NO_TORCH", "false").lower() in ("1", "true")
def test_true_lowercase(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "true"}):
assert self._reimport_no_torch() is True
def test_true_one(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "1"}):
assert self._reimport_no_torch() is True
def test_true_uppercase(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "TRUE"}):
assert self._reimport_no_torch() is True
def test_false_string(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}):
assert self._reimport_no_torch() is False
def test_false_zero(self):
with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "0"}):
assert self._reimport_no_torch() is False
def test_not_set(self):
env = os.environ.copy()
env.pop("UNSLOTH_NO_TORCH", None)
with mock.patch.dict(os.environ, env, clear = True):
assert self._reimport_no_torch() is False
def test_infer_no_torch_on_intel_mac(self):
"""_infer_no_torch falls back to platform detection when env var is unset."""
env = os.environ.copy()
env.pop("UNSLOTH_NO_TORCH", None)
with (
mock.patch.dict(os.environ, env, clear = True),
mock.patch.object(ips, "IS_MAC_INTEL", True),
):
assert ips._infer_no_torch() is True
def test_infer_no_torch_respects_explicit_false_on_intel_mac(self):
"""Explicit UNSLOTH_NO_TORCH=false overrides platform detection."""
with (
mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}),
mock.patch.object(ips, "IS_MAC_INTEL", True),
):
assert ips._infer_no_torch() is False
def test_infer_no_torch_linux_unset(self):
"""On Linux with env var unset, _infer_no_torch returns False."""
env = os.environ.copy()
env.pop("UNSLOTH_NO_TORCH", None)
with (
mock.patch.dict(os.environ, env, clear = True),
mock.patch.object(ips, "IS_MAC_INTEL", False),
):
assert ips._infer_no_torch() is False
# ── IS_MACOS constant tests ──────────────────────────────────────────
class TestIsMacosConstant:
"""Verify IS_MACOS detection logic."""
def test_is_macos_matches_platform(self):
import sys
expected = sys.platform == "darwin"
assert ips.IS_MACOS is expected
# ── Subprocess mock of install_python_stack() ─────────────────────────
class TestInstallPythonStackSubprocessMock:
"""Monkeypatch subprocess.run to capture all pip/uv commands,
then verify which requirements files are used/skipped under
different NO_TORCH / IS_MACOS / IS_WINDOWS configurations."""
@pytest.fixture(autouse = True)
def _check_req_files(self):
"""Skip if requirements files are missing."""
for f in [EXTRAS_TXT, EXTRAS_NO_DEPS_TXT, OVERRIDES_TXT]:
if not f.is_file():
pytest.skip(f"{f.name} not found in repo")
def _capture_install(
self,
no_torch: bool,
is_macos: bool,
is_windows: bool,
*,
skip_base: bool = True,
):
"""Run install_python_stack() with mocked subprocess, capturing all commands.
Returns a list of string-joined commands (each element is ' '.join(cmd)).
"""
captured_cmds: list[list[str]] = []
def mock_run(cmd, **kw):
captured_cmds.append(
list(cmd) if isinstance(cmd, (list, tuple)) else [str(cmd)]
)
return subprocess.CompletedProcess(cmd, 0, b"", b"")
env = {"SKIP_STUDIO_BASE": "1"} if skip_base else {}
with (
mock.patch.object(ips, "NO_TORCH", no_torch),
mock.patch.object(ips, "IS_MACOS", is_macos),
mock.patch.object(ips, "IS_WINDOWS", is_windows),
mock.patch.object(ips, "USE_UV", True),
mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
mock.patch.object(ips, "VERBOSE", False),
mock.patch("subprocess.run", side_effect = mock_run),
mock.patch.object(ips, "_bootstrap_uv", return_value = True),
mock.patch.object(
ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
),
mock.patch("pathlib.Path.is_dir", return_value = True),
mock.patch("pathlib.Path.is_file", return_value = True),
):
with mock.patch.dict(os.environ, env, clear = False):
ips.install_python_stack()
return [" ".join(str(c) for c in cmd) for cmd in captured_cmds]
def _cmds_contain_file(self, cmds: list[str], filename: str) -> bool:
"""Check if any captured command references the given filename."""
return any(filename in cmd for cmd in cmds)
# -- NO_TORCH=True, IS_MACOS=True (Intel Mac scenario) --
def test_no_torch_macos_skips_overrides(self):
"""With NO_TORCH=True, overrides.txt pip_install must NOT be called."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
assert not self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be skipped when NO_TORCH=True"
def test_no_torch_macos_skips_triton(self):
"""With IS_MACOS=True, triton-kernels.txt must NOT be called."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on macOS"
def test_no_torch_macos_extras_called(self):
"""With NO_TORCH=True, extras.txt is still called (but filtered)."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
has_extras = self._cmds_contain_file(cmds, "extras.txt") or any(
"-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
)
assert has_extras, "extras.txt (or its filtered temp) should be called"
def test_no_torch_macos_extras_no_deps_called(self):
"""With NO_TORCH=True, extras-no-deps.txt is still called (but filtered)."""
cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
has_extras_nd = self._cmds_contain_file(cmds, "extras-no-deps.txt") or any(
"-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
)
assert (
has_extras_nd
), "extras-no-deps.txt (or its filtered temp) should be called"
# -- IS_WINDOWS=True + NO_TORCH=True (stacked) --
def test_windows_no_torch_skips_overrides(self):
"""Windows+NO_TORCH: overrides.txt must be skipped."""
cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
assert not self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be skipped with NO_TORCH=True on Windows"
def test_windows_no_torch_skips_triton(self):
"""Windows: triton-kernels.txt must be skipped (IS_WINDOWS guard)."""
cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on Windows"
# -- Normal Linux path (NO_TORCH=False, IS_MACOS=False, IS_WINDOWS=False) --
def test_normal_linux_includes_overrides(self):
"""Normal Linux: overrides.txt IS called."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be called on normal Linux"
def test_normal_linux_includes_triton(self):
"""Normal Linux: triton-kernels.txt IS called."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be called on normal Linux"
def test_normal_linux_includes_extras(self):
"""Normal Linux: extras.txt IS called (no filtering)."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "extras.txt"
), "extras.txt should be called on normal Linux"
def test_normal_linux_includes_extras_no_deps(self):
"""Normal Linux: extras-no-deps.txt IS called (no filtering)."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
assert self._cmds_contain_file(
cmds, "extras-no-deps.txt"
), "extras-no-deps.txt should be called on normal Linux"
# -- Windows-only (NO_TORCH=False) to verify triton is still skipped --
def test_windows_only_skips_triton(self):
"""Windows (without NO_TORCH): triton still skipped."""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on Windows even without NO_TORCH"
def test_windows_only_includes_overrides(self):
"""Windows (without NO_TORCH): overrides IS called (via filtered temp file).
On Windows, all req files go through _filter_requirements(WINDOWS_SKIP_PACKAGES),
so the command uses a temp file, not overrides.txt directly. We check for
--reinstall (uv translation of --force-reinstall) which is unique to overrides.
"""
cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
assert any(
"--reinstall" in cmd for cmd in cmds
), "overrides step (--reinstall) should be called on Windows when NO_TORCH=False"
# -- Update path (skip_base=False) to verify no-torch mode is durable --
def test_update_path_intel_macos_still_skips_overrides(self):
"""Update path (no SKIP_STUDIO_BASE): overrides still skipped on Intel Mac."""
cmds = self._capture_install(
no_torch = True, is_macos = True, is_windows = False, skip_base = False
)
assert not self._cmds_contain_file(
cmds, "overrides.txt"
), "overrides.txt should be skipped on Intel Mac even via studio update"
def test_update_path_intel_macos_still_skips_triton(self):
"""Update path (no SKIP_STUDIO_BASE): triton still skipped on macOS."""
cmds = self._capture_install(
no_torch = True, is_macos = True, is_windows = False, skip_base = False
)
assert not self._cmds_contain_file(
cmds, "triton-kernels.txt"
), "triton-kernels.txt should be skipped on macOS even via studio update"
# ── Overrides skip structural checks ─────────────────────────────────
class TestOverridesSkip:
"""Verify overrides.txt is skipped when NO_TORCH is True (source-level check)."""
def test_no_torch_guard_exists_in_source(self):
"""The install_python_stack source must contain a NO_TORCH guard around overrides."""
source = Path(ips.__file__).read_text(encoding = "utf-8")
assert (
"if NO_TORCH:" in source
), "NO_TORCH guard not found in install_python_stack.py"
def test_overrides_skipped_when_no_torch(self):
"""With NO_TORCH=True on the module, pip_install should NOT be called for overrides."""
source = Path(ips.__file__).read_text(encoding = "utf-8")
overrides_match = re.search(r"if NO_TORCH:.*?overrides", source, re.DOTALL)
assert (
overrides_match is not None
), "Expected NO_TORCH conditional before overrides install"
# ── install.sh --no-torch flag tests ──────────────────────────────────
class TestInstallShNoTorchFlag:
"""Verify install.sh has the --no-torch flag and SKIP_TORCH variable."""
@pytest.fixture(autouse = True)
def _check_install_sh(self):
install_sh = Path(__file__).resolve().parents[2] / "install.sh"
if not install_sh.is_file():
pytest.skip("install.sh not found")
self.install_sh = install_sh
self.source = install_sh.read_text(encoding = "utf-8")
def test_no_torch_flag_in_case_statement(self):
"""--no-torch must appear in the flag parser case statement."""
assert (
"--no-torch)" in self.source
), "--no-torch not found in install.sh flag parser"
def test_no_torch_flag_variable_initialized(self):
"""_NO_TORCH_FLAG must be initialized to false."""
assert (
"_NO_TORCH_FLAG=false" in self.source
), "_NO_TORCH_FLAG=false not found in install.sh"
def test_skip_torch_variable_exists(self):
"""SKIP_TORCH variable must be defined."""
assert (
"SKIP_TORCH=false" in self.source
), "SKIP_TORCH=false not found in install.sh"
assert (
"SKIP_TORCH=true" in self.source
), "SKIP_TORCH=true not found in install.sh"
def test_skip_torch_driven_by_flag_and_mac_intel(self):
"""SKIP_TORCH must check both _NO_TORCH_FLAG and MAC_INTEL."""
assert (
"_NO_TORCH_FLAG" in self.source
), "_NO_TORCH_FLAG not referenced in SKIP_TORCH logic"
assert (
"MAC_INTEL" in self.source
), "MAC_INTEL not referenced in SKIP_TORCH logic"
def test_unsloth_no_torch_uses_skip_torch(self):
"""UNSLOTH_NO_TORCH must reference $SKIP_TORCH, not $MAC_INTEL."""
import re
matches = re.findall(r'UNSLOTH_NO_TORCH="\$(\w+)"', self.source)
for var in matches:
assert (
var == "SKIP_TORCH"
), f"UNSLOTH_NO_TORCH references ${var} instead of $SKIP_TORCH"
def test_cpu_hint_message_exists(self):
"""CPU hint message must exist in install.sh."""
assert (
"No NVIDIA GPU detected" in self.source
), "CPU hint message not found in install.sh"
assert (
"--no-torch" in self.source
), "--no-torch suggestion not found in CPU hint"
def test_no_torch_flag_parsing_subprocess(self):
"""--no-torch flag sets _NO_TORCH_FLAG=true (subprocess test)."""
script = textwrap.dedent("""\
_NO_TORCH_FLAG=false
_next_is_package=false
STUDIO_LOCAL_INSTALL=false
PACKAGE_NAME="unsloth"
for arg in "$@"; do
if [ "$_next_is_package" = true ]; then
PACKAGE_NAME="$arg"
_next_is_package=false
continue
fi
case "$arg" in
--local) STUDIO_LOCAL_INSTALL=true ;;
--package) _next_is_package=true ;;
--no-torch) _NO_TORCH_FLAG=true ;;
esac
done
echo "$_NO_TORCH_FLAG"
""")
result = subprocess.run(
["bash", "-c", script, "_", "--no-torch"],
capture_output = True,
text = True,
)
assert (
result.stdout.strip() == "true"
), f"Expected _NO_TORCH_FLAG=true, got: {result.stdout.strip()}"
def test_no_torch_with_local_flag(self):
"""--no-torch and --local can be used together."""
script = textwrap.dedent("""\
_NO_TORCH_FLAG=false
_next_is_package=false
STUDIO_LOCAL_INSTALL=false
PACKAGE_NAME="unsloth"
for arg in "$@"; do
if [ "$_next_is_package" = true ]; then
PACKAGE_NAME="$arg"
_next_is_package=false
continue
fi
case "$arg" in
--local) STUDIO_LOCAL_INSTALL=true ;;
--package) _next_is_package=true ;;
--no-torch) _NO_TORCH_FLAG=true ;;
esac
done
echo "$_NO_TORCH_FLAG $STUDIO_LOCAL_INSTALL"
""")
result = subprocess.run(
["bash", "-c", script, "_", "--local", "--no-torch"],
capture_output = True,
text = True,
)
assert (
result.stdout.strip() == "true true"
), f"Expected 'true true', got: {result.stdout.strip()}"
def test_cpu_hint_only_when_not_skip_torch(self):
"""CPU hint should only print when SKIP_TORCH=false and OS!=macos."""
script = textwrap.dedent("""\
TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
SKIP_TORCH=false
OS="linux"
case "$TORCH_INDEX_URL" in
*/cpu)
if [ "$SKIP_TORCH" = false ] && [ "$OS" != "macos" ]; then
echo "HINT_PRINTED"
fi
;;
esac
""")
result = subprocess.run(
["bash", "-c", script],
capture_output = True,
text = True,
)
assert "HINT_PRINTED" in result.stdout, "CPU hint should print"
# With SKIP_TORCH=true, hint should NOT print
script2 = script.replace("SKIP_TORCH=false", "SKIP_TORCH=true")
result2 = subprocess.run(
["bash", "-c", script2],
capture_output = True,
text = True,
)
assert (
"HINT_PRINTED" not in result2.stdout
), "CPU hint should NOT print when SKIP_TORCH=true"
# ── Triton macOS skip structural checks ──────────────────────────────
class TestTritonMacosSkip:
"""Verify triton is skipped on macOS (source-level check)."""
def test_triton_guard_in_source(self):
"""Source must skip triton on both Windows and macOS."""
source = Path(ips.__file__).read_text(encoding = "utf-8")
assert (
"not IS_MACOS" in source
), "IS_MACOS guard for triton not found in install_python_stack.py"
assert (
"not IS_WINDOWS and not IS_MACOS" in source
), "Expected 'not IS_WINDOWS and not IS_MACOS' guard for triton"

View file

@ -0,0 +1,582 @@
"""End-to-end sandbox tests: Studio modules in isolated no-torch venvs.
Covers:
- Python 3.12 and 3.13 venv creation (Intel Mac uses 3.12, Apple Silicon/Linux 3.13)
- data_collators.py loads and dataclasses instantiate without torch
- chat_templates.py top-level exec works with stubs for relative imports
- Negative control: prepending 'import torch' fails in no-torch venv
- Negative control: installing torchao (from overrides.txt) fails in no-torch venv
- AST structural checks for top-level torch imports
"""
from __future__ import annotations
import ast
import os
import shutil
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
DATA_COLLATORS = (
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "data_collators.py"
)
CHAT_TEMPLATES = (
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "chat_templates.py"
)
FORMAT_CONVERSION = (
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "format_conversion.py"
)
def _has_uv() -> bool:
return shutil.which("uv") is not None
def _create_venv(venv_dir: Path, python_version: str) -> Path | None:
"""Create a uv venv at the given Python version. Returns python path or None."""
result = subprocess.run(
["uv", "venv", str(venv_dir), "--python", python_version],
capture_output = True,
)
if result.returncode != 0:
return None
venv_python = venv_dir / "bin" / "python"
if not venv_python.exists():
venv_python = venv_dir / "Scripts" / "python.exe"
return venv_python if venv_python.exists() else None
@pytest.fixture(params = ["3.12", "3.13"], scope = "module")
def no_torch_venv(request, tmp_path_factory):
"""Create a temporary venv at the requested Python version with no torch.
Parametrized for 3.12 (Intel Mac) and 3.13 (Apple Silicon / Linux).
"""
if not _has_uv():
pytest.skip("uv not available")
py_version = request.param
venv_dir = tmp_path_factory.mktemp(f"no_torch_venv_{py_version}")
venv_python = _create_venv(venv_dir, py_version)
if venv_python is None:
pytest.skip(f"Could not create Python {py_version} venv")
# Verify torch is NOT importable
check = subprocess.run(
[str(venv_python), "-c", "import torch"],
capture_output = True,
)
assert (
check.returncode != 0
), f"torch should NOT be importable in fresh {py_version} venv"
return str(venv_python)
# ── AST structural checks ─────────────────────────────────────────────
class TestDataCollatorsAST:
"""Static analysis: data_collators.py has no top-level torch imports."""
def test_ast_parse(self):
"""data_collators.py must be valid Python syntax."""
source = DATA_COLLATORS.read_text(encoding = "utf-8")
tree = ast.parse(source, filename = str(DATA_COLLATORS))
assert tree is not None
def test_no_top_level_torch_import(self):
"""No top-level 'import torch' or 'from torch' statements."""
source = DATA_COLLATORS.read_text(encoding = "utf-8")
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Import):
for alias in node.names:
assert not alias.name.startswith(
"torch"
), f"Top-level 'import {alias.name}' found at line {node.lineno}"
elif isinstance(node, ast.ImportFrom):
if node.module:
assert not node.module.startswith(
"torch"
), f"Top-level 'from {node.module}' found at line {node.lineno}"
class TestChatTemplatesAST:
"""Static analysis: chat_templates.py has no top-level torch imports."""
def test_ast_parse(self):
"""chat_templates.py must be valid Python syntax."""
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
tree = ast.parse(source, filename = str(CHAT_TEMPLATES))
assert tree is not None
def test_no_top_level_torch_import(self):
"""No top-level 'import torch' or 'from torch' at module level."""
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Import):
for alias in node.names:
assert not alias.name.startswith(
"torch"
), f"Top-level 'import {alias.name}' found at line {node.lineno}"
elif isinstance(node, ast.ImportFrom):
if node.module:
assert not node.module.startswith(
"torch"
), f"Top-level 'from {node.module}' found at line {node.lineno}"
def test_torch_imports_only_inside_functions(self):
"""All 'from torch' imports must be inside function/method bodies."""
source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
tree = ast.parse(source)
torch_imports = []
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
module = None
if isinstance(node, ast.ImportFrom):
module = node.module
elif isinstance(node, ast.Import):
module = node.names[0].name if node.names else None
if module and module.startswith("torch"):
torch_imports.append(node)
top_level = set(id(n) for n in ast.iter_child_nodes(tree))
for imp in torch_imports:
assert id(imp) not in top_level, (
f"torch import at line {imp.lineno} is at top level"
" (should be inside a function)"
)
# ── data_collators.py: exec + dataclass instantiation in no-torch venv ──
class TestDataCollatorsNoTorchVenv:
"""Run data_collators.py in an isolated no-torch venv, verify classes load."""
def test_exec_in_no_torch_venv(self, no_torch_venv):
"""data_collators.py executes in a venv without torch (with loggers stub)."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
print("OK: exec succeeded")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"data_collators.py failed in no-torch venv:\n{result.stderr.decode()}"
assert b"OK: exec succeeded" in result.stdout
def test_dataclass_speech_collator_instantiable(self, no_torch_venv):
"""DataCollatorSpeechSeq2SeqWithPadding can be instantiated with processor=None."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None)
assert obj.processor is None, "processor should be None"
print("OK: DataCollatorSpeechSeq2SeqWithPadding instantiated")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"DataCollatorSpeechSeq2SeqWithPadding failed:\n{result.stderr.decode()}"
assert b"OK: DataCollatorSpeechSeq2SeqWithPadding instantiated" in result.stdout
def test_dataclass_deepseek_collator_instantiable(self, no_torch_venv):
"""DeepSeekOCRDataCollator can be instantiated with processor=None."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
obj = DeepSeekOCRDataCollator(processor=None)
assert obj.processor is None, "processor should be None"
assert obj.max_length == 2048, "default max_length should be 2048"
assert obj.ignore_index == -100, "default ignore_index should be -100"
print("OK: DeepSeekOCRDataCollator instantiated")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"DeepSeekOCRDataCollator failed:\n{result.stderr.decode()}"
assert b"OK: DeepSeekOCRDataCollator instantiated" in result.stdout
def test_dataclass_vlm_collator_instantiable(self, no_torch_venv):
"""VLMDataCollator can be instantiated with processor=None."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({str(DATA_COLLATORS)!r}).read())
obj = VLMDataCollator(processor=None)
assert obj.processor is None
assert obj.mask_input_tokens is True, "default mask_input_tokens should be True"
print("OK: VLMDataCollator instantiated")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"VLMDataCollator failed:\n{result.stderr.decode()}"
assert b"OK: VLMDataCollator instantiated" in result.stdout
# ── chat_templates.py: exec in no-torch venv ─────────────────────────
class TestChatTemplatesNoTorchVenv:
"""Run chat_templates.py in an isolated no-torch venv with stubs."""
def test_exec_with_stubs(self, no_torch_venv):
"""chat_templates.py top-level exec works with stubs for relative imports."""
code = textwrap.dedent(f"""\
import sys, types
# Stub loggers
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
sys.modules['loggers'] = loggers
# Stub relative imports (.format_detection, .model_mappings)
format_detection = types.ModuleType('format_detection')
format_detection.detect_dataset_format = lambda *a, **k: None
format_detection.detect_multimodal_dataset = lambda *a, **k: None
format_detection.detect_custom_format_heuristic = lambda *a, **k: None
sys.modules['format_detection'] = format_detection
model_mappings = types.ModuleType('model_mappings')
model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
sys.modules['model_mappings'] = model_mappings
# Read and transform the source: replace relative imports with absolute
source = open({str(CHAT_TEMPLATES)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
source = source.replace('from .model_mappings import', 'from model_mappings import')
exec(source)
# Verify module-level constants are defined
ns = dict(locals())
assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined after exec"
print("OK: chat_templates.py exec succeeded")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"chat_templates.py failed in no-torch venv:\n{result.stderr.decode()}"
assert b"OK: chat_templates.py exec succeeded" in result.stdout
def test_default_alpaca_template_defined(self, no_torch_venv):
"""DEFAULT_ALPACA_TEMPLATE constant is accessible after exec."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
sys.modules['loggers'] = loggers
format_detection = types.ModuleType('format_detection')
format_detection.detect_dataset_format = lambda *a, **k: None
format_detection.detect_multimodal_dataset = lambda *a, **k: None
format_detection.detect_custom_format_heuristic = lambda *a, **k: None
sys.modules['format_detection'] = format_detection
model_mappings = types.ModuleType('model_mappings')
model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
sys.modules['model_mappings'] = model_mappings
ns = {{}}
source = open({str(CHAT_TEMPLATES)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
source = source.replace('from .model_mappings import', 'from model_mappings import')
exec(source, ns)
assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined"
assert 'Instruction' in ns['DEFAULT_ALPACA_TEMPLATE'], "Template content unexpected"
print("OK: DEFAULT_ALPACA_TEMPLATE defined and valid")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"DEFAULT_ALPACA_TEMPLATE check failed:\n{result.stderr.decode()}"
assert b"OK: DEFAULT_ALPACA_TEMPLATE defined and valid" in result.stdout
# ── format_conversion.py: AST + runtime tests ────────────────────────
class TestFormatConversionAST:
"""Static analysis: format_conversion.py torch imports are guarded."""
def test_ast_parse(self):
"""format_conversion.py must be valid Python syntax."""
source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
tree = ast.parse(source, filename = str(FORMAT_CONVERSION))
assert tree is not None
def test_no_bare_torch_import_in_functions(self):
"""All 'from torch' imports in function bodies must be inside try/except."""
source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
tree = ast.parse(source)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
for child in ast.walk(node):
if (
isinstance(child, ast.ImportFrom)
and child.module
and child.module.startswith("torch")
):
# This torch import must be inside a Try node
found_in_try = False
for try_node in ast.walk(node):
if isinstance(try_node, ast.Try):
for try_child in ast.walk(try_node):
if try_child is child:
found_in_try = True
break
if found_in_try:
break
assert found_in_try, (
f"torch import at line {child.lineno} in {node.name}() "
"is not inside a try/except block"
)
class TestFormatConversionNoTorchVenv:
"""Run format_conversion.py functions in a no-torch venv."""
def test_convert_chatml_to_alpaca_no_torch(self, no_torch_venv):
"""convert_chatml_to_alpaca works without torch (via try/except ImportError)."""
code = textwrap.dedent(f"""\
import sys, types
# Stub loggers
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{
'info': lambda s, m: None,
'warning': lambda s, m: None,
'debug': lambda s, m: None,
}})()
sys.modules['loggers'] = loggers
# Stub datasets.IterableDataset (HF datasets, not torch)
datasets_mod = types.ModuleType('datasets')
datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
sys.modules['datasets'] = datasets_mod
# Stub utils.hardware
utils_mod = types.ModuleType('utils')
hardware_mod = types.ModuleType('utils.hardware')
hardware_mod.dataset_map_num_proc = lambda n=None: 1
utils_mod.hardware = hardware_mod
sys.modules['utils'] = utils_mod
sys.modules['utils.hardware'] = hardware_mod
# Read and exec format_conversion.py
source = open({str(FORMAT_CONVERSION)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
ns = {{'__name__': '__test__'}}
exec(source, ns)
# Test convert_chatml_to_alpaca with a simple dataset
class FakeDataset:
def map(self, fn, **kw):
result = fn({{
'messages': [[
{{'role': 'user', 'content': 'Hello'}},
{{'role': 'assistant', 'content': 'Hi there'}},
]]
}})
return result
result = ns['convert_chatml_to_alpaca'](FakeDataset())
assert 'instruction' in result, f"Expected 'instruction' in result, got {{result.keys()}}"
assert result['instruction'] == ['Hello']
assert result['output'] == ['Hi there']
print("OK: convert_chatml_to_alpaca works without torch")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"convert_chatml_to_alpaca failed without torch:\n{result.stderr.decode()}"
assert b"OK: convert_chatml_to_alpaca works without torch" in result.stdout
def test_convert_alpaca_to_chatml_no_torch(self, no_torch_venv):
"""convert_alpaca_to_chatml works without torch (via try/except ImportError)."""
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: type('L', (), {{
'info': lambda s, m: None,
'warning': lambda s, m: None,
'debug': lambda s, m: None,
}})()
sys.modules['loggers'] = loggers
datasets_mod = types.ModuleType('datasets')
datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
sys.modules['datasets'] = datasets_mod
utils_mod = types.ModuleType('utils')
hardware_mod = types.ModuleType('utils.hardware')
hardware_mod.dataset_map_num_proc = lambda n=None: 1
utils_mod.hardware = hardware_mod
sys.modules['utils'] = utils_mod
sys.modules['utils.hardware'] = hardware_mod
source = open({str(FORMAT_CONVERSION)!r}).read()
source = source.replace('from .format_detection import', 'from format_detection import')
ns = {{'__name__': '__test__'}}
exec(source, ns)
class FakeDataset:
def map(self, fn, **kw):
result = fn({{
'instruction': ['Write a poem'],
'input': [''],
'output': ['Roses are red'],
}})
return result
result = ns['convert_alpaca_to_chatml'](FakeDataset())
assert 'conversations' in result
convo = result['conversations'][0]
assert convo[0]['role'] == 'user'
assert convo[1]['role'] == 'assistant'
print("OK: convert_alpaca_to_chatml works without torch")
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode == 0
), f"convert_alpaca_to_chatml failed without torch:\n{result.stderr.decode()}"
assert b"OK: convert_alpaca_to_chatml works without torch" in result.stdout
# ── Negative controls ─────────────────────────────────────────────────
class TestNegativeControls:
"""Prove the fix is necessary by showing what fails WITHOUT it."""
def test_import_torch_prepended_fails(self, no_torch_venv):
"""Prepending 'import torch' to data_collators.py causes ModuleNotFoundError."""
with tempfile.NamedTemporaryFile(
mode = "w", suffix = ".py", delete = False, encoding = "utf-8"
) as f:
f.write("import torch\n")
f.write(DATA_COLLATORS.read_text(encoding = "utf-8"))
temp_file = f.name
try:
code = textwrap.dedent(f"""\
import sys, types
loggers = types.ModuleType('loggers')
loggers.get_logger = lambda n: None
sys.modules['loggers'] = loggers
exec(open({temp_file!r}).read())
""")
result = subprocess.run(
[no_torch_venv, "-c", code],
capture_output = True,
timeout = 30,
)
assert (
result.returncode != 0
), "Expected failure when 'import torch' is prepended"
assert (
b"ModuleNotFoundError" in result.stderr
or b"ImportError" in result.stderr
), f"Expected ImportError, got:\n{result.stderr.decode()}"
finally:
os.unlink(temp_file)
def test_torchao_install_fails_no_torch_venv(self, no_torch_venv):
"""Installing torchao (from overrides.txt) fails in a no-torch venv.
This proves the overrides.txt skip is necessary for Intel Mac.
"""
result = subprocess.run(
[
no_torch_venv,
"-m",
"pip",
"install",
"torchao==0.14.0",
"--dry-run",
],
capture_output = True,
timeout = 60,
)
if result.returncode != 0:
# torchao install/resolution failed as expected
pass
else:
# pip dry-run may not catch dependency issues; verify torch is missing
check = subprocess.run(
[no_torch_venv, "-c", "import torch"],
capture_output = True,
)
assert (
check.returncode != 0
), "torch should not be importable -- torchao would fail at runtime"
def test_direct_torch_import_fails(self, no_torch_venv):
"""Direct 'import torch' fails in the no-torch venv."""
result = subprocess.run(
[no_torch_venv, "-c", "import torch; print('torch loaded')"],
capture_output = True,
timeout = 30,
)
assert result.returncode != 0, "import torch should fail in no-torch venv"
assert (
b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr
)