Mirror of https://github.com/unslothai/unsloth.git (synced 2026-04-28 11:29:57 +00:00)
* tests: add no-torch / Intel Mac test suite

Add comprehensive test coverage for the no-torch / --no-torch installer and Studio backend changes introduced in #4624.

Shell tests (tests/sh/test_mac_intel_compat.sh):
- version_ge edge cases (9 tests)
- Architecture detection + Python version resolution (4 tests)
- get_torch_index_url on Darwin (2 tests)
- UNSLOTH_NO_TORCH propagation via SKIP_TORCH (5 tests)
- E2E uv venv creation at Python 3.12 (3 tests)
- E2E torch skip with mock uv shim (4 tests)
- UNSLOTH_NO_TORCH env propagation (4 tests)
- --python override flag parsing + resolution (11 tests)
- --no-torch flag parsing (4 tests)
- SKIP_TORCH unification (3 tests)
- CPU hint printing (2 tests)

Python tests (tests/python/test_no_torch_filtering.py):
- _filter_requirements unit tests with synthetic + real requirements files
- NO_TORCH / IS_MACOS constant parsing
- Subprocess mock of install_python_stack() across platform configs
- install.sh --no-torch flag structural + subprocess tests

Python tests (tests/python/test_studio_import_no_torch.py):
- AST checks for data_collators.py, chat_templates.py, format_conversion.py
- Parametrized venv tests (Python 3.12 + 3.13) for no-torch exec
- Dataclass instantiation without torch
- format_conversion convert functions without torch
- Negative controls (import torch fails, torchao fails)

Python tests (tests/python/test_e2e_no_torch_sandbox.py):
- Before/after import chain tests
- Edge cases (broken torch, fake torch, lazy import)
- Hardware detection without torch
- install.sh logic tests (flag parsing, version resolution)
- install_python_stack filtering tests
- Live server startup tests (opt-in via @server marker)

* fix: address review comments on test suite

- Fix always-true assertion in test_studio_import_no_torch.py (or True)
- Make IS_MACOS test platform-aware instead of hardcoding Linux
- Restore torchvision + torchaudio in server test cleanup (not just torch)
- Include server stderr in skip message for easier debugging

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
582 lines
24 KiB
Python
"""End-to-end sandbox tests: Studio modules in isolated no-torch venvs.
|
|
|
|
Covers:
|
|
- Python 3.12 and 3.13 venv creation (Intel Mac uses 3.12, Apple Silicon/Linux 3.13)
|
|
- data_collators.py loads and dataclasses instantiate without torch
|
|
- chat_templates.py top-level exec works with stubs for relative imports
|
|
- Negative control: prepending 'import torch' fails in no-torch venv
|
|
- Negative control: installing torchao (from overrides.txt) fails in no-torch venv
|
|
- AST structural checks for top-level torch imports
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import textwrap
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
|
DATA_COLLATORS = (
|
|
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "data_collators.py"
|
|
)
|
|
CHAT_TEMPLATES = (
|
|
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "chat_templates.py"
|
|
)
|
|
FORMAT_CONVERSION = (
|
|
REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "format_conversion.py"
|
|
)
|
|
|
|
|
|
def _has_uv() -> bool:
|
|
return shutil.which("uv") is not None
|
|
|
|
|
|
def _create_venv(venv_dir: Path, python_version: str) -> Path | None:
|
|
"""Create a uv venv at the given Python version. Returns python path or None."""
|
|
result = subprocess.run(
|
|
["uv", "venv", str(venv_dir), "--python", python_version],
|
|
capture_output = True,
|
|
)
|
|
if result.returncode != 0:
|
|
return None
|
|
venv_python = venv_dir / "bin" / "python"
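    # POSIX venvs put the interpreter under bin/; Windows venvs use Scripts\python.exe instead.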
    if not venv_python.exists():
        venv_python = venv_dir / "Scripts" / "python.exe"
    return venv_python if venv_python.exists() else None


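# Module-scoped: each requested Python version's venv is created once and reused by every test below.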
@pytest.fixture(params = ["3.12", "3.13"], scope = "module")
def no_torch_venv(request, tmp_path_factory):
    """Create a temporary venv at the requested Python version with no torch.

    Parametrized for 3.12 (Intel Mac) and 3.13 (Apple Silicon / Linux).
    """
    if not _has_uv():
        pytest.skip("uv not available")

    py_version = request.param
    venv_dir = tmp_path_factory.mktemp(f"no_torch_venv_{py_version}")
    venv_python = _create_venv(venv_dir, py_version)
    if venv_python is None:
        pytest.skip(f"Could not create Python {py_version} venv")

    # Verify torch is NOT importable
    check = subprocess.run(
        [str(venv_python), "-c", "import torch"],
        capture_output = True,
    )
    assert (
        check.returncode != 0
    ), f"torch should NOT be importable in fresh {py_version} venv"

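    # Tests shell out to this interpreter path; nothing torch-related is imported in-process.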
    return str(venv_python)


# ── AST structural checks ─────────────────────────────────────────────


class TestDataCollatorsAST:
    """Static analysis: data_collators.py has no top-level torch imports."""

    def test_ast_parse(self):
        """data_collators.py must be valid Python syntax."""
        source = DATA_COLLATORS.read_text(encoding = "utf-8")
        tree = ast.parse(source, filename = str(DATA_COLLATORS))
        assert tree is not None

    def test_no_top_level_torch_import(self):
        """No top-level 'import torch' or 'from torch' statements."""
        source = DATA_COLLATORS.read_text(encoding = "utf-8")
        tree = ast.parse(source)
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    assert not alias.name.startswith(
                        "torch"
                    ), f"Top-level 'import {alias.name}' found at line {node.lineno}"
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    assert not node.module.startswith(
                        "torch"
                    ), f"Top-level 'from {node.module}' found at line {node.lineno}"


class TestChatTemplatesAST:
    """Static analysis: chat_templates.py has no top-level torch imports."""

    def test_ast_parse(self):
        """chat_templates.py must be valid Python syntax."""
        source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
        tree = ast.parse(source, filename = str(CHAT_TEMPLATES))
        assert tree is not None

    def test_no_top_level_torch_import(self):
        """No top-level 'import torch' or 'from torch' at module level."""
        source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
        tree = ast.parse(source)
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    assert not alias.name.startswith(
                        "torch"
                    ), f"Top-level 'import {alias.name}' found at line {node.lineno}"
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    assert not node.module.startswith(
                        "torch"
                    ), f"Top-level 'from {node.module}' found at line {node.lineno}"

    def test_torch_imports_only_inside_functions(self):
        """All 'from torch' imports must be inside function/method bodies."""
        source = CHAT_TEMPLATES.read_text(encoding = "utf-8")
        tree = ast.parse(source)
        torch_imports = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                module = None
                if isinstance(node, ast.ImportFrom):
                    module = node.module
                elif isinstance(node, ast.Import):
                    module = node.names[0].name if node.names else None
                if module and module.startswith("torch"):
                    torch_imports.append(node)

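        # Compare by node identity: iter_child_nodes() yields only module-level statements.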
        top_level = set(id(n) for n in ast.iter_child_nodes(tree))
        for imp in torch_imports:
            assert id(imp) not in top_level, (
                f"torch import at line {imp.lineno} is at top level"
                " (should be inside a function)"
            )


# ── data_collators.py: exec + dataclass instantiation in no-torch venv ──


class TestDataCollatorsNoTorchVenv:
    """Run data_collators.py in an isolated no-torch venv, verify classes load."""

    def test_exec_in_no_torch_venv(self, no_torch_venv):
        """data_collators.py executes in a venv without torch (with loggers stub)."""
        code = textwrap.dedent(f"""\
            import sys, types
            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: None
            sys.modules['loggers'] = loggers
            exec(open({str(DATA_COLLATORS)!r}).read())
            print("OK: exec succeeded")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"data_collators.py failed in no-torch venv:\n{result.stderr.decode()}"
        assert b"OK: exec succeeded" in result.stdout

    def test_dataclass_speech_collator_instantiable(self, no_torch_venv):
        """DataCollatorSpeechSeq2SeqWithPadding can be instantiated with processor=None."""
        code = textwrap.dedent(f"""\
            import sys, types
            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: None
            sys.modules['loggers'] = loggers
            exec(open({str(DATA_COLLATORS)!r}).read())
            obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None)
            assert obj.processor is None, "processor should be None"
            print("OK: DataCollatorSpeechSeq2SeqWithPadding instantiated")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"DataCollatorSpeechSeq2SeqWithPadding failed:\n{result.stderr.decode()}"
        assert b"OK: DataCollatorSpeechSeq2SeqWithPadding instantiated" in result.stdout

    def test_dataclass_deepseek_collator_instantiable(self, no_torch_venv):
        """DeepSeekOCRDataCollator can be instantiated with processor=None."""
        code = textwrap.dedent(f"""\
            import sys, types
            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: None
            sys.modules['loggers'] = loggers
            exec(open({str(DATA_COLLATORS)!r}).read())
            obj = DeepSeekOCRDataCollator(processor=None)
            assert obj.processor is None, "processor should be None"
            assert obj.max_length == 2048, "default max_length should be 2048"
            assert obj.ignore_index == -100, "default ignore_index should be -100"
            print("OK: DeepSeekOCRDataCollator instantiated")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"DeepSeekOCRDataCollator failed:\n{result.stderr.decode()}"
        assert b"OK: DeepSeekOCRDataCollator instantiated" in result.stdout

    def test_dataclass_vlm_collator_instantiable(self, no_torch_venv):
        """VLMDataCollator can be instantiated with processor=None."""
        code = textwrap.dedent(f"""\
            import sys, types
            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: None
            sys.modules['loggers'] = loggers
            exec(open({str(DATA_COLLATORS)!r}).read())
            obj = VLMDataCollator(processor=None)
            assert obj.processor is None
            assert obj.mask_input_tokens is True, "default mask_input_tokens should be True"
            print("OK: VLMDataCollator instantiated")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"VLMDataCollator failed:\n{result.stderr.decode()}"
        assert b"OK: VLMDataCollator instantiated" in result.stdout


# ── chat_templates.py: exec in no-torch venv ─────────────────────────


class TestChatTemplatesNoTorchVenv:
    """Run chat_templates.py in an isolated no-torch venv with stubs."""

    def test_exec_with_stubs(self, no_torch_venv):
        """chat_templates.py top-level exec works with stubs for relative imports."""
        code = textwrap.dedent(f"""\
            import sys, types

            # Stub loggers
            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
            sys.modules['loggers'] = loggers

            # Stub relative imports (.format_detection, .model_mappings)
            format_detection = types.ModuleType('format_detection')
            format_detection.detect_dataset_format = lambda *a, **k: None
            format_detection.detect_multimodal_dataset = lambda *a, **k: None
            format_detection.detect_custom_format_heuristic = lambda *a, **k: None
            sys.modules['format_detection'] = format_detection

            model_mappings = types.ModuleType('model_mappings')
            model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
            sys.modules['model_mappings'] = model_mappings

            # Read and transform the source: replace relative imports with absolute
            source = open({str(CHAT_TEMPLATES)!r}).read()
            source = source.replace('from .format_detection import', 'from format_detection import')
            source = source.replace('from .model_mappings import', 'from model_mappings import')

            exec(source)

            # Verify module-level constants are defined
            ns = dict(locals())
            assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined after exec"
            print("OK: chat_templates.py exec succeeded")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"chat_templates.py failed in no-torch venv:\n{result.stderr.decode()}"
        assert b"OK: chat_templates.py exec succeeded" in result.stdout

    def test_default_alpaca_template_defined(self, no_torch_venv):
        """DEFAULT_ALPACA_TEMPLATE constant is accessible after exec."""
        code = textwrap.dedent(f"""\
            import sys, types

            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})()
            sys.modules['loggers'] = loggers

            format_detection = types.ModuleType('format_detection')
            format_detection.detect_dataset_format = lambda *a, **k: None
            format_detection.detect_multimodal_dataset = lambda *a, **k: None
            format_detection.detect_custom_format_heuristic = lambda *a, **k: None
            sys.modules['format_detection'] = format_detection

            model_mappings = types.ModuleType('model_mappings')
            model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}}
            sys.modules['model_mappings'] = model_mappings

            ns = {{}}
            source = open({str(CHAT_TEMPLATES)!r}).read()
            source = source.replace('from .format_detection import', 'from format_detection import')
            source = source.replace('from .model_mappings import', 'from model_mappings import')
            exec(source, ns)

            assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined"
            assert 'Instruction' in ns['DEFAULT_ALPACA_TEMPLATE'], "Template content unexpected"
            print("OK: DEFAULT_ALPACA_TEMPLATE defined and valid")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"DEFAULT_ALPACA_TEMPLATE check failed:\n{result.stderr.decode()}"
        assert b"OK: DEFAULT_ALPACA_TEMPLATE defined and valid" in result.stdout


# ── format_conversion.py: AST + runtime tests ────────────────────────


class TestFormatConversionAST:
    """Static analysis: format_conversion.py torch imports are guarded."""

    def test_ast_parse(self):
        """format_conversion.py must be valid Python syntax."""
        source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
        tree = ast.parse(source, filename = str(FORMAT_CONVERSION))
        assert tree is not None

    def test_no_bare_torch_import_in_functions(self):
        """All 'from torch' imports in function bodies must be inside try/except."""
        source = FORMAT_CONVERSION.read_text(encoding = "utf-8")
        tree = ast.parse(source)

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                for child in ast.walk(node):
                    if (
                        isinstance(child, ast.ImportFrom)
                        and child.module
                        and child.module.startswith("torch")
                    ):
                        # This torch import must be inside a Try node
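                        # Search every try/except in the enclosing function for this exact node.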
                        found_in_try = False
                        for try_node in ast.walk(node):
                            if isinstance(try_node, ast.Try):
                                for try_child in ast.walk(try_node):
                                    if try_child is child:
                                        found_in_try = True
                                        break
                                if found_in_try:
                                    break
                        assert found_in_try, (
                            f"torch import at line {child.lineno} in {node.name}() "
                            "is not inside a try/except block"
                        )


class TestFormatConversionNoTorchVenv:
    """Run format_conversion.py functions in a no-torch venv."""

    def test_convert_chatml_to_alpaca_no_torch(self, no_torch_venv):
        """convert_chatml_to_alpaca works without torch (via try/except ImportError)."""
        code = textwrap.dedent(f"""\
            import sys, types

            # Stub loggers
            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: type('L', (), {{
                'info': lambda s, m: None,
                'warning': lambda s, m: None,
                'debug': lambda s, m: None,
            }})()
            sys.modules['loggers'] = loggers

            # Stub datasets.IterableDataset (HF datasets, not torch)
            datasets_mod = types.ModuleType('datasets')
            datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
            sys.modules['datasets'] = datasets_mod

            # Stub utils.hardware
            utils_mod = types.ModuleType('utils')
            hardware_mod = types.ModuleType('utils.hardware')
            hardware_mod.dataset_map_num_proc = lambda n=None: 1
            utils_mod.hardware = hardware_mod
            sys.modules['utils'] = utils_mod
            sys.modules['utils.hardware'] = hardware_mod

            # Read and exec format_conversion.py
            source = open({str(FORMAT_CONVERSION)!r}).read()
            source = source.replace('from .format_detection import', 'from format_detection import')
            ns = {{'__name__': '__test__'}}
            exec(source, ns)

            # Test convert_chatml_to_alpaca with a simple dataset
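            # FakeDataset is a minimal stand-in: map() calls fn once on a single batched example.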
            class FakeDataset:
                def map(self, fn, **kw):
                    result = fn({{
                        'messages': [[
                            {{'role': 'user', 'content': 'Hello'}},
                            {{'role': 'assistant', 'content': 'Hi there'}},
                        ]]
                    }})
                    return result

            result = ns['convert_chatml_to_alpaca'](FakeDataset())
            assert 'instruction' in result, f"Expected 'instruction' in result, got {{result.keys()}}"
            assert result['instruction'] == ['Hello']
            assert result['output'] == ['Hi there']
            print("OK: convert_chatml_to_alpaca works without torch")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"convert_chatml_to_alpaca failed without torch:\n{result.stderr.decode()}"
        assert b"OK: convert_chatml_to_alpaca works without torch" in result.stdout

    def test_convert_alpaca_to_chatml_no_torch(self, no_torch_venv):
        """convert_alpaca_to_chatml works without torch (via try/except ImportError)."""
        code = textwrap.dedent(f"""\
            import sys, types

            loggers = types.ModuleType('loggers')
            loggers.get_logger = lambda n: type('L', (), {{
                'info': lambda s, m: None,
                'warning': lambda s, m: None,
                'debug': lambda s, m: None,
            }})()
            sys.modules['loggers'] = loggers

            datasets_mod = types.ModuleType('datasets')
            datasets_mod.IterableDataset = type('IterableDataset', (), {{}})
            sys.modules['datasets'] = datasets_mod

            utils_mod = types.ModuleType('utils')
            hardware_mod = types.ModuleType('utils.hardware')
            hardware_mod.dataset_map_num_proc = lambda n=None: 1
            utils_mod.hardware = hardware_mod
            sys.modules['utils'] = utils_mod
            sys.modules['utils.hardware'] = hardware_mod

            source = open({str(FORMAT_CONVERSION)!r}).read()
            source = source.replace('from .format_detection import', 'from format_detection import')
            ns = {{'__name__': '__test__'}}
            exec(source, ns)

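            # Fake dataset: map() feeds one batched Alpaca example straight to the conversion fn.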
            class FakeDataset:
                def map(self, fn, **kw):
                    result = fn({{
                        'instruction': ['Write a poem'],
                        'input': [''],
                        'output': ['Roses are red'],
                    }})
                    return result

            result = ns['convert_alpaca_to_chatml'](FakeDataset())
            assert 'conversations' in result
            convo = result['conversations'][0]
            assert convo[0]['role'] == 'user'
            assert convo[1]['role'] == 'assistant'
            print("OK: convert_alpaca_to_chatml works without torch")
        """)
        result = subprocess.run(
            [no_torch_venv, "-c", code],
            capture_output = True,
            timeout = 30,
        )
        assert (
            result.returncode == 0
        ), f"convert_alpaca_to_chatml failed without torch:\n{result.stderr.decode()}"
        assert b"OK: convert_alpaca_to_chatml works without torch" in result.stdout


# ── Negative controls ─────────────────────────────────────────────────


class TestNegativeControls:
    """Prove the fix is necessary by showing what fails WITHOUT it."""

    def test_import_torch_prepended_fails(self, no_torch_venv):
        """Prepending 'import torch' to data_collators.py causes ModuleNotFoundError."""
        with tempfile.NamedTemporaryFile(
            mode = "w", suffix = ".py", delete = False, encoding = "utf-8"
        ) as f:
            f.write("import torch\n")
            f.write(DATA_COLLATORS.read_text(encoding = "utf-8"))
            temp_file = f.name

        try:
            code = textwrap.dedent(f"""\
                import sys, types
                loggers = types.ModuleType('loggers')
                loggers.get_logger = lambda n: None
                sys.modules['loggers'] = loggers
                exec(open({temp_file!r}).read())
            """)
            result = subprocess.run(
                [no_torch_venv, "-c", code],
                capture_output = True,
                timeout = 30,
            )
            assert (
                result.returncode != 0
            ), "Expected failure when 'import torch' is prepended"
            assert (
                b"ModuleNotFoundError" in result.stderr
                or b"ImportError" in result.stderr
            ), f"Expected ImportError, got:\n{result.stderr.decode()}"
        finally:
            os.unlink(temp_file)

    def test_torchao_install_fails_no_torch_venv(self, no_torch_venv):
        """Installing torchao (from overrides.txt) fails in a no-torch venv.

        This proves the overrides.txt skip is necessary for Intel Mac.
        """
        result = subprocess.run(
            [
                no_torch_venv,
                "-m",
                "pip",
                "install",
                "torchao==0.14.0",
                "--dry-run",
            ],
            capture_output = True,
            timeout = 60,
        )
        if result.returncode != 0:
            # torchao install/resolution failed as expected
            pass
        else:
            # pip dry-run may not catch dependency issues; verify torch is missing
            check = subprocess.run(
                [no_torch_venv, "-c", "import torch"],
                capture_output = True,
            )
            assert (
                check.returncode != 0
            ), "torch should not be importable -- torchao would fail at runtime"

    def test_direct_torch_import_fails(self, no_torch_venv):
        """Direct 'import torch' fails in the no-torch venv."""
        result = subprocess.run(
            [no_torch_venv, "-c", "import torch; print('torch loaded')"],
            capture_output = True,
            timeout = 30,
        )
        assert result.returncode != 0, "import torch should fail in no-torch venv"
        assert (
            b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr
        )