diff --git a/tests/python/conftest.py b/tests/python/conftest.py new file mode 100644 index 000000000..66542d245 --- /dev/null +++ b/tests/python/conftest.py @@ -0,0 +1,7 @@ +"""Shared pytest configuration for tests/python/.""" + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "server: heavyweight tests requiring studio venv" + ) diff --git a/tests/python/test_e2e_no_torch_sandbox.py b/tests/python/test_e2e_no_torch_sandbox.py new file mode 100644 index 000000000..f36f69201 --- /dev/null +++ b/tests/python/test_e2e_no_torch_sandbox.py @@ -0,0 +1,1239 @@ +"""Comprehensive E2E sandbox tests for PR #4624 (fix/install-mac-intel-no-torch). + +Proves that: +- The BEFORE state (top-level torch imports) crashes without torch +- The AFTER state (lazy/removed imports) works without torch +- Edge cases (broken torch, partial torch) are handled gracefully +- Hardware detection falls back to CPU without torch +- install.sh flag parsing and platform detection work correctly +- install_python_stack.py NO_TORCH filtering is correct +- Live server starts and responds without torch (optional, requires studio venv) + +Run: + # Lightweight tests (Groups 1-6, ~26 tests): + python -m pytest tests/python/test_e2e_no_torch_sandbox.py -v -k "not server" + + # Server tests (Group 7, 4 tests, requires studio venv): + python -m pytest tests/python/test_e2e_no_torch_sandbox.py -v -m server +""" + +from __future__ import annotations + +import os +import shutil +import signal +import subprocess +import sys +import textwrap +import time +from pathlib import Path +from unittest import mock + +import pytest + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +REPO_ROOT = Path(__file__).resolve().parents[2] +STUDIO_DIR = REPO_ROOT / "studio" +BACKEND_DIR = STUDIO_DIR / "backend" +DATASETS_DIR = BACKEND_DIR / "utils" / "datasets" +HARDWARE_DIR = BACKEND_DIR / 
"utils" / "hardware" +INSTALL_SH = REPO_ROOT / "install.sh" +INSTALL_PY = STUDIO_DIR / "install_python_stack.py" + +DATA_COLLATORS = DATASETS_DIR / "data_collators.py" +CHAT_TEMPLATES = DATASETS_DIR / "chat_templates.py" +FORMAT_DETECTION = DATASETS_DIR / "format_detection.py" +MODEL_MAPPINGS = DATASETS_DIR / "model_mappings.py" +VLM_PROCESSING = DATASETS_DIR / "vlm_processing.py" +HARDWARE_PY = HARDWARE_DIR / "hardware.py" + +# Studio venv for server tests +STUDIO_VENV = Path.home() / ".unsloth" / "studio" / "unsloth_studio" + +# Add studio to path for install_python_stack imports +sys.path.insert(0, str(STUDIO_DIR)) + + +# --------------------------------------------------------------------------- +# Cross-platform helpers +# --------------------------------------------------------------------------- + + +def _venv_python(venv_dir: Path) -> Path: + """Return the Python executable path for a venv, cross-platform.""" + if sys.platform == "win32": + return venv_dir / "Scripts" / "python.exe" + return venv_dir / "bin" / "python" + + +def _has_uv() -> bool: + return shutil.which("uv") is not None + + +def _create_no_torch_venv(venv_dir: Path, python_version: str = "3.12") -> Path | None: + """Create a uv venv with no torch. 
Returns python path or None.""" + result = subprocess.run( + ["uv", "venv", str(venv_dir), "--python", python_version], + capture_output = True, + ) + if result.returncode != 0: + return None + py = _venv_python(venv_dir) + if not py.exists(): + return None + # Verify torch is NOT importable + check = subprocess.run([str(py), "-c", "import torch"], capture_output = True) + if check.returncode == 0: + return None + return py + + +def _run_in_sandbox( + py: str | Path, + code: str, + timeout: int = 60, + env: dict | None = None, +) -> subprocess.CompletedProcess: + """Run Python code in a sandboxed interpreter.""" + return subprocess.run( + [str(py), "-c", code], + capture_output = True, + timeout = timeout, + env = env, + ) + + +def _run_sh(script: str, timeout: int = 30) -> subprocess.CompletedProcess: + """Run a bash snippet and return the result.""" + return subprocess.run( + ["bash", "-c", script], + capture_output = True, + timeout = timeout, + ) + + +# --------------------------------------------------------------------------- +# Stub generators +# --------------------------------------------------------------------------- + + +def _write_loggers_stub(sandbox: Path) -> None: + """Create a minimal loggers package stub (replaces structlog-backed real one).""" + loggers_dir = sandbox / "loggers" + loggers_dir.mkdir(exist_ok = True) + (loggers_dir / "__init__.py").write_text( + "from .handlers import get_logger\n__all__ = ['get_logger']\n", + encoding = "utf-8", + ) + (loggers_dir / "handlers.py").write_text( + textwrap.dedent("""\ + class _Logger: + def info(self, msg, *a, **k): pass + def warning(self, msg, *a, **k): pass + def debug(self, msg, *a, **k): pass + def error(self, msg, *a, **k): pass + def msg(self, msg, *a, **k): pass + def get_logger(name=None): + return _Logger() + """), + encoding = "utf-8", + ) + + +def _write_structlog_stub(sandbox: Path) -> None: + """Create a minimal structlog stub.""" + structlog_dir = sandbox / "structlog" + 
def _write_hardware_stub(sandbox: Path) -> None:
    """Create utils/hardware stub with dataset_map_num_proc."""
    hw_dir = sandbox / "utils" / "hardware"
    hw_dir.mkdir(parents = True, exist_ok = True)
    # utils/ must be a package for `from utils.datasets import ...` to resolve.
    (sandbox / "utils" / "__init__.py").write_text("", encoding = "utf-8")
    # Identity passthrough is enough for the code under test.
    (hw_dir / "__init__.py").write_text(
        "def dataset_map_num_proc(n=None): return n\n",
        encoding = "utf-8",
    )


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(scope = "session")
def repo_root():
    # Convenience handle for tests that want the checkout root.
    return REPO_ROOT


@pytest.fixture
def sandbox_dir(tmp_path):
    """Per-test temporary sandbox directory."""
    return tmp_path


@pytest.fixture(params = ["3.12", "3.13"], scope = "module")
def no_torch_venv(request, tmp_path_factory):
    """Create a temporary uv venv with no torch.

    Parametrized for 3.12 (Intel Mac default) and 3.13 (Apple Silicon/Linux).

    Skips (rather than fails) when uv is unavailable or the requested
    interpreter cannot be provisioned, so CI without uv stays green.
    Returns the venv's python executable path as a string.
    """
    if not _has_uv():
        pytest.skip("uv not available")

    py_version = request.param
    venv_dir = tmp_path_factory.mktemp(f"e2e_no_torch_{py_version}")
    py = _create_no_torch_venv(venv_dir, py_version)
    if py is None:
        pytest.skip(f"Could not create Python {py_version} no-torch venv")
    return str(py)
+ """ + if not _has_uv(): + pytest.skip("uv not available") + + py_version = request.param + venv_dir = tmp_path_factory.mktemp(f"e2e_no_torch_{py_version}") + py = _create_no_torch_venv(venv_dir, py_version) + if py is None: + pytest.skip(f"Could not create Python {py_version} no-torch venv") + return str(py) + + +# =========================================================================== +# Group 1: BEFORE vs AFTER -- Import Chain (6 tests) +# =========================================================================== + + +class TestBeforeAfterImportChain: + """Prove the bug exists in BEFORE state and is fixed in AFTER state. + + BEFORE = PR branch files with top-level torch import synthetically prepended + (simulates the main branch). + AFTER = PR branch files as-is (lazy imports / torch import removed). + """ + + # -- BEFORE: crashes -- + + def test_before_chat_templates_crashes(self, no_torch_venv, sandbox_dir): + """BEFORE: chat_templates.py with top-level 'from torch.utils.data import + IterableDataset' crashes without torch.""" + source = CHAT_TEMPLATES.read_text(encoding = "utf-8") + before_source = "from torch.utils.data import IterableDataset\n" + source + + before_file = sandbox_dir / "chat_templates_before.py" + before_file.write_text(before_source, encoding = "utf-8") + + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None}})() + sys.modules['loggers'] = loggers + fd = types.ModuleType('format_detection') + fd.detect_dataset_format = fd.detect_multimodal_dataset = fd.detect_custom_format_heuristic = lambda *a, **k: None + sys.modules['format_detection'] = fd + mm = types.ModuleType('model_mappings') + mm.MODEL_TO_TEMPLATE_MAPPER = {{}} + sys.modules['model_mappings'] = mm + source = open({str(before_file)!r}).read() + source = source.replace('from .format_detection import', 'from format_detection import') + source = source.replace('from 
.model_mappings import', 'from model_mappings import') + exec(source) + """) + result = _run_in_sandbox(no_torch_venv, code) + assert ( + result.returncode != 0 + ), "BEFORE chat_templates.py should crash without torch" + assert ( + b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr + ) + + def test_before_data_collators_crashes(self, no_torch_venv, sandbox_dir): + """BEFORE: data_collators.py with top-level 'import torch' crashes.""" + source = DATA_COLLATORS.read_text(encoding = "utf-8") + before_source = "import torch\n" + source + + before_file = sandbox_dir / "data_collators_before.py" + before_file.write_text(before_source, encoding = "utf-8") + + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(before_file)!r}).read()) + """) + result = _run_in_sandbox(no_torch_venv, code) + assert ( + result.returncode != 0 + ), "BEFORE data_collators.py should crash without torch" + assert ( + b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr + ) + + def test_before_full_import_chain_crashes(self, no_torch_venv, sandbox_dir): + """BEFORE: full utils/datasets/ package with top-level torch imports crashes.""" + _write_loggers_stub(sandbox_dir) + _write_hardware_stub(sandbox_dir) + + pkg_dir = sandbox_dir / "utils" / "datasets" + pkg_dir.mkdir(parents = True, exist_ok = True) + + # Copy torch-free modules as-is + shutil.copy2(FORMAT_DETECTION, pkg_dir / "format_detection.py") + shutil.copy2(MODEL_MAPPINGS, pkg_dir / "model_mappings.py") + shutil.copy2(VLM_PROCESSING, pkg_dir / "vlm_processing.py") + + # BEFORE data_collators: prepend top-level 'import torch' + dc_source = DATA_COLLATORS.read_text(encoding = "utf-8") + (pkg_dir / "data_collators.py").write_text( + "import torch\n" + dc_source, + encoding = "utf-8", + ) + + # BEFORE chat_templates: prepend top-level IterableDataset import + ct_source = 
CHAT_TEMPLATES.read_text(encoding = "utf-8") + (pkg_dir / "chat_templates.py").write_text( + "from torch.utils.data import IterableDataset\n" + ct_source, + encoding = "utf-8", + ) + + # Minimal __init__.py that triggers the chain + (pkg_dir / "__init__.py").write_text( + textwrap.dedent("""\ + from .format_detection import detect_dataset_format + from .data_collators import DataCollatorSpeechSeq2SeqWithPadding + from .chat_templates import DEFAULT_ALPACA_TEMPLATE + """), + encoding = "utf-8", + ) + + code = textwrap.dedent(f"""\ + import sys + sys.path.insert(0, {str(sandbox_dir)!r}) + from utils.datasets import detect_dataset_format + """) + result = _run_in_sandbox(no_torch_venv, code) + assert ( + result.returncode != 0 + ), "BEFORE full import chain should crash without torch" + assert ( + b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr + ) + + # -- AFTER: succeeds -- + + def test_after_chat_templates_imports(self, no_torch_venv): + """AFTER: PR branch chat_templates.py imports fine without torch.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None}})() + sys.modules['loggers'] = loggers + fd = types.ModuleType('format_detection') + fd.detect_dataset_format = fd.detect_multimodal_dataset = fd.detect_custom_format_heuristic = lambda *a, **k: None + sys.modules['format_detection'] = fd + mm = types.ModuleType('model_mappings') + mm.MODEL_TO_TEMPLATE_MAPPER = {{}} + sys.modules['model_mappings'] = mm + source = open({str(CHAT_TEMPLATES)!r}).read() + source = source.replace('from .format_detection import', 'from format_detection import') + source = source.replace('from .model_mappings import', 'from model_mappings import') + exec(source) + print("OK") + """) + result = _run_in_sandbox(no_torch_venv, code) + assert ( + result.returncode == 0 + ), f"AFTER chat_templates.py should work without 
torch:\n{result.stderr.decode()}" + assert b"OK" in result.stdout + + def test_after_data_collators_imports(self, no_torch_venv): + """AFTER: PR branch data_collators.py imports fine without torch.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + print("OK") + """) + result = _run_in_sandbox(no_torch_venv, code) + assert ( + result.returncode == 0 + ), f"AFTER data_collators.py should work without torch:\n{result.stderr.decode()}" + assert b"OK" in result.stdout + + def test_after_full_import_chain_imports(self, no_torch_venv, sandbox_dir): + """AFTER: full utils/datasets/ package imports fine without torch.""" + _write_loggers_stub(sandbox_dir) + _write_hardware_stub(sandbox_dir) + + pkg_dir = sandbox_dir / "utils" / "datasets" + pkg_dir.mkdir(parents = True, exist_ok = True) + + # Copy AFTER versions (PR branch -- no top-level torch) + for src in [ + FORMAT_DETECTION, + MODEL_MAPPINGS, + VLM_PROCESSING, + DATA_COLLATORS, + CHAT_TEMPLATES, + ]: + if src.exists(): + shutil.copy2(src, pkg_dir / src.name) + + # Minimal __init__.py + (pkg_dir / "__init__.py").write_text( + textwrap.dedent("""\ + from .format_detection import detect_dataset_format, detect_custom_format_heuristic + from .model_mappings import MODEL_TO_TEMPLATE_MAPPER + from .chat_templates import DEFAULT_ALPACA_TEMPLATE, get_dataset_info_summary + from .data_collators import ( + DataCollatorSpeechSeq2SeqWithPadding, + DeepSeekOCRDataCollator, + VLMDataCollator, + ) + from .vlm_processing import generate_smart_vlm_instruction + """), + encoding = "utf-8", + ) + + code = textwrap.dedent(f"""\ + import sys + sys.path.insert(0, {str(sandbox_dir)!r}) + from utils.datasets import ( + detect_dataset_format, + DEFAULT_ALPACA_TEMPLATE, + DataCollatorSpeechSeq2SeqWithPadding, + DeepSeekOCRDataCollator, + VLMDataCollator, + 
generate_smart_vlm_instruction, + ) + assert 'Instruction' in DEFAULT_ALPACA_TEMPLATE + print("OK: full import chain succeeded") + """) + result = _run_in_sandbox(no_torch_venv, code) + assert ( + result.returncode == 0 + ), f"AFTER full import chain should work:\n{result.stderr.decode()}" + assert b"OK: full import chain succeeded" in result.stdout + + +# =========================================================================== +# Group 2: Dataclass Instantiation (4 tests) +# =========================================================================== + + +class TestDataclassInstantiation: + """Verify dataclass collators can be instantiated and constants accessed + without torch in an isolated venv.""" + + def test_speech_collator_instantiate(self, no_torch_venv): + """DataCollatorSpeechSeq2SeqWithPadding(processor=None) succeeds.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None) + assert obj.processor is None + print("OK") + """) + result = _run_in_sandbox(no_torch_venv, code) + assert result.returncode == 0, f"Failed:\n{result.stderr.decode()}" + + def test_deepseek_ocr_collator_instantiate(self, no_torch_venv): + """DeepSeekOCRDataCollator has correct default field values.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + obj = DeepSeekOCRDataCollator(processor=None) + assert obj.processor is None + assert obj.max_length == 2048 + assert obj.ignore_index == -100 + print("OK") + """) + result = _run_in_sandbox(no_torch_venv, code) + assert result.returncode == 0, f"Failed:\n{result.stderr.decode()}" + + def test_vlm_collator_instantiate(self, no_torch_venv): + """VLMDataCollator 
# ===========================================================================
# Group 3: Edge Cases -- Partial/Broken Torch (4 tests)
# ===========================================================================


class TestEdgeCasesBrokenTorch:
    """Test behavior with fake or broken torch modules on sys.path."""

    def test_fake_broken_torch_module(self, no_torch_venv, sandbox_dir):
        """A fake torch that raises RuntimeError('CUDA not found') on import.

        data_collators.py (no top-level torch import) should still load fine.
        """
        # A torch package whose __init__ raises on import simulates a corrupt
        # install; it only hurts if something actually imports torch.
        torch_dir = sandbox_dir / "torch"
        torch_dir.mkdir()
        (torch_dir / "__init__.py").write_text(
            'raise RuntimeError("CUDA not found")\n',
            encoding = "utf-8",
        )
        _write_loggers_stub(sandbox_dir)
        shutil.copy2(DATA_COLLATORS, sandbox_dir / "data_collators.py")

        code = textwrap.dedent(f"""\
            import sys
            sys.path.insert(0, {str(sandbox_dir)!r})
            exec(open({str(sandbox_dir / 'data_collators.py')!r}).read())
            obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None)
            print("OK: data_collators works despite broken torch on sys.path")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert (
            result.returncode == 0
        ), f"Should work with broken torch:\n{result.stderr.decode()}"
        assert b"OK:" in result.stdout

    def test_torch_import_error_hardware_fallback(self, no_torch_venv, sandbox_dir):
        """A fake torch that raises ImportError. detect_hardware() falls back to CPU."""
        torch_dir = sandbox_dir / "torch"
        torch_dir.mkdir()
        (torch_dir / "__init__.py").write_text(
            'raise ImportError("No torch binary")\n',
            encoding = "utf-8",
        )
        _write_loggers_stub(sandbox_dir)
        _write_structlog_stub(sandbox_dir)

        # exec into ns with __name__ set so module-level guards behave as on
        # a normal (non-__main__) import.
        code = textwrap.dedent(f"""\
            import sys
            sys.path.insert(0, {str(sandbox_dir)!r})
            source = open({str(HARDWARE_PY)!r}).read()
            ns = {{'__name__': '__test__'}}
            exec(source, ns)
            result = ns['detect_hardware']()
            assert result == ns['DeviceType'].CPU, f"Expected CPU, got {{result}}"
            print("OK: detect_hardware returned CPU")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert (
            result.returncode == 0
        ), f"detect_hardware should fallback to CPU:\n{result.stderr.decode()}"
        assert b"OK: detect_hardware returned CPU" in result.stdout

    def test_fake_torch_no_cuda(self, no_torch_venv, sandbox_dir):
        """Fake torch that imports OK but torch.cuda.is_available() returns False.

        detect_hardware() should still fall back to CPU.
        """
        torch_dir = sandbox_dir / "torch"
        torch_dir.mkdir()
        # Minimal importable torch surface: cuda.is_available() and version.cuda.
        (torch_dir / "__init__.py").write_text(
            textwrap.dedent("""\
                class _Cuda:
                    @staticmethod
                    def is_available():
                        return False
                cuda = _Cuda()
                class version:
                    cuda = None
                """),
            encoding = "utf-8",
        )
        _write_loggers_stub(sandbox_dir)
        _write_structlog_stub(sandbox_dir)

        code = textwrap.dedent(f"""\
            import sys
            sys.path.insert(0, {str(sandbox_dir)!r})
            source = open({str(HARDWARE_PY)!r}).read()
            ns = {{'__name__': '__test__'}}
            exec(source, ns)
            result = ns['detect_hardware']()
            assert result == ns['DeviceType'].CPU, f"Expected CPU, got {{result}}"
            print("OK: detect_hardware returned CPU with fake torch (no CUDA)")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert (
            result.returncode == 0
        ), f"Should fall back to CPU:\n{result.stderr.decode()}"
        assert b"OK:" in result.stdout

    def test_lazy_torch_fails_at_call_time_not_import_time(
        self, no_torch_venv, sandbox_dir
    ):
        """apply_chat_template_to_dataset is importable without torch.

        Calling the alpaca branch triggers the lazy 'from torch.utils.data' inside
        the try block. This should fail at call time, not import time -- proving the
        lazy import pattern works correctly.
        """
        _write_loggers_stub(sandbox_dir)

        code = textwrap.dedent(f"""\
            import sys, types
            sys.path.insert(0, {str(sandbox_dir)!r})
            fd = types.ModuleType('format_detection')
            fd.detect_dataset_format = fd.detect_multimodal_dataset = fd.detect_custom_format_heuristic = lambda *a, **k: None
            sys.modules['format_detection'] = fd
            mm = types.ModuleType('model_mappings')
            mm.MODEL_TO_TEMPLATE_MAPPER = {{}}
            sys.modules['model_mappings'] = mm

            ns = {{}}
            source = open({str(CHAT_TEMPLATES)!r}).read()
            source = source.replace('from .format_detection import', 'from format_detection import')
            source = source.replace('from .model_mappings import', 'from model_mappings import')
            exec(source, ns)

            # Import succeeds -- this is the fix
            assert 'apply_chat_template_to_dataset' in ns
            print("OK: import succeeded")

            # Calling alpaca branch triggers lazy torch import inside the try block.
            # The function catches the error and returns it in the errors list.
            dataset_info = {{
                'dataset': type('D', (), {{'map': lambda *a, **k: None}})(),
                'final_format': 'alpaca',
                'chat_column': None,
                'is_standardized': True,
                'warnings': [],
            }}
            result = ns['apply_chat_template_to_dataset'](dataset_info, None)
            # The function has a try/except that catches the error gracefully
            if not result['success']:
                print("OK: call-time failure caught gracefully")
            else:
                print("OK: call succeeded (unexpected but not a crash)")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert (
            result.returncode == 0
        ), f"Should not crash at import time:\n{result.stderr.decode()}"
        assert b"OK: import succeeded" in result.stdout


# ===========================================================================
# Group 4: Hardware Detection Without Torch (3 tests)
# ===========================================================================


class TestHardwareDetectionNoTorch:
    """Hardware module works without torch, falling back to CPU."""

    def test_detect_hardware_no_torch(self, no_torch_venv, sandbox_dir):
        """detect_hardware() returns CPU device when torch is not installed."""
        _write_loggers_stub(sandbox_dir)
        _write_structlog_stub(sandbox_dir)

        code = textwrap.dedent(f"""\
            import sys
            sys.path.insert(0, {str(sandbox_dir)!r})
            source = open({str(HARDWARE_PY)!r}).read()
            ns = {{'__name__': '__test__'}}
            exec(source, ns)
            device = ns['detect_hardware']()
            assert device == ns['DeviceType'].CPU
            assert ns['CHAT_ONLY'] is True
            print("OK: detect_hardware returned CPU, CHAT_ONLY=True")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert result.returncode == 0, f"Failed:\n{result.stderr.decode()}"
        assert b"OK:" in result.stdout

    def test_get_package_versions_no_torch(self, no_torch_venv, sandbox_dir):
        """get_package_versions() returns torch=None, cuda=None without torch."""
        _write_loggers_stub(sandbox_dir)
        _write_structlog_stub(sandbox_dir)

        code = textwrap.dedent(f"""\
            import sys
            sys.path.insert(0, {str(sandbox_dir)!r})
            source = open({str(HARDWARE_PY)!r}).read()
            ns = {{'__name__': '__test__'}}
            exec(source, ns)
            versions = ns['get_package_versions']()
            assert versions['torch'] is None, f"Expected torch=None, got {{versions['torch']}}"
            assert versions['cuda'] is None, f"Expected cuda=None, got {{versions['cuda']}}"
            print("OK: torch=None, cuda=None")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert result.returncode == 0, f"Failed:\n{result.stderr.decode()}"
        assert b"OK:" in result.stdout

    def test_hardware_module_import_no_torch(self, no_torch_venv, sandbox_dir):
        """The hardware module imports and detect_hardware is callable without torch."""
        _write_loggers_stub(sandbox_dir)
        _write_structlog_stub(sandbox_dir)
        _write_hardware_stub(sandbox_dir)

        # Copy the real hardware module into a sandbox package
        hw_sandbox = sandbox_dir / "hw_pkg"
        hw_sandbox.mkdir()
        (hw_sandbox / "__init__.py").write_text("", encoding = "utf-8")
        shutil.copy2(HARDWARE_PY, hw_sandbox / "hardware.py")

        code = textwrap.dedent(f"""\
            import sys
            sys.path.insert(0, {str(sandbox_dir)!r})
            source = open({str(hw_sandbox / 'hardware.py')!r}).read()
            ns = {{'__name__': '__test__'}}
            exec(source, ns)
            assert callable(ns['detect_hardware'])
            assert callable(ns['get_package_versions'])
            assert callable(ns['is_apple_silicon'])
            print("OK: all hardware functions accessible")
            """)
        result = _run_in_sandbox(no_torch_venv, code)
        assert result.returncode == 0, f"Failed:\n{result.stderr.decode()}"
        assert b"OK:" in result.stdout
# ===========================================================================
# Group 5: install.sh Logic (5 tests via bash subprocess)
# ===========================================================================


class TestInstallShLogic:
    """Test install.sh flag parsing, platform detection, and guard logic.

    Each test replicates the relevant shell snippet from install.sh and runs it
    with ``bash -c SCRIPT _ ARGS...`` so positional parameters ($1, $2, "$@")
    are exercised; "_" fills $0. The installer itself is never executed.
    """

    @pytest.fixture(autouse=True)
    def _check_install_sh(self):
        # These tests are meaningless if the real script has moved/been removed.
        if not INSTALL_SH.is_file():
            pytest.skip("install.sh not found")

    def test_python_flag_parsing(self):
        """--python flag correctly sets _USER_PYTHON."""
        # Replicates the flag-parser snippet from install.sh.
        script = textwrap.dedent("""\
            _USER_PYTHON=""
            _next_is_python=false
            for arg in "$@"; do
                if [ "$_next_is_python" = true ]; then
                    _USER_PYTHON="$arg"
                    _next_is_python=false
                    continue
                fi
                case "$arg" in
                    --python) _next_is_python=true ;;
                esac
            done
            echo "$_USER_PYTHON"
            """)
        # FIX: the original also called `_run_sh(f"{script}" + "\n")` first and
        # immediately overwrote the result -- _run_sh cannot forward CLI args,
        # so subprocess.run with bash -c positional parameters is used directly.

        # Test: --python 3.12
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "--python", "3.12"],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"3.12"

        # Test: --local --python 3.11
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "--local", "--python", "3.11"],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"3.11"

        # Test: no --python flag
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "--local"],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b""

    def test_python_flag_missing_arg_errors(self):
        """--python without a version argument triggers an error."""
        # Replicates the flag parser + error guard from install.sh.
        script = textwrap.dedent("""\
            set -e
            _USER_PYTHON=""
            _next_is_python=false
            for arg in "$@"; do
                if [ "$_next_is_python" = true ]; then
                    _USER_PYTHON="$arg"
                    _next_is_python=false
                    continue
                fi
                case "$arg" in
                    --python) _next_is_python=true ;;
                esac
            done
            if [ "$_next_is_python" = true ]; then
                echo "ERROR: --python requires a version argument" >&2
                exit 1
            fi
            echo "$_USER_PYTHON"
            """)
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "--python"],
            capture_output=True,
            timeout=10,
        )
        assert r.returncode != 0
        assert b"ERROR" in r.stderr

    def test_python_version_resolution(self):
        """Python version defaults to 3.12 on Intel Mac, 3.13 elsewhere.

        --python overrides both.
        """
        script = textwrap.dedent("""\
            MAC_INTEL="$1"
            _USER_PYTHON="$2"

            if [ -n "$_USER_PYTHON" ]; then
                PYTHON_VERSION="$_USER_PYTHON"
            elif [ "$MAC_INTEL" = true ]; then
                PYTHON_VERSION="3.12"
            else
                PYTHON_VERSION="3.13"
            fi
            echo "$PYTHON_VERSION"
            """)
        # Intel Mac, no override
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "true", ""],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"3.12"

        # Non-Intel, no override
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "false", ""],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"3.13"

        # Intel Mac with --python override
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "true", "3.11"],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"3.11"

    def test_mac_intel_detection_snippet(self):
        """Architecture detection sets MAC_INTEL correctly for different platforms."""
        script = textwrap.dedent("""\
            OS="$1"
            _ARCH="$2"
            MAC_INTEL=false
            if [ "$OS" = "macos" ] && [ "$_ARCH" = "x86_64" ]; then
                MAC_INTEL=true
            fi
            echo "$MAC_INTEL"
            """)
        cases = [
            (("macos", "x86_64"), b"true"),
            (("macos", "arm64"), b"false"),
            (("linux", "x86_64"), b"false"),
            (("linux", "aarch64"), b"false"),
        ]
        for (os_val, arch), expected in cases:
            r = subprocess.run(
                ["bash", "-c", script + "\n", "_", os_val, arch],
                capture_output=True,
                timeout=10,
            )
            assert r.stdout.strip() == expected, (
                f"MAC_INTEL for ({os_val}, {arch}): "
                f"expected {expected!r}, got {r.stdout.strip()!r}"
            )

    def test_stale_venv_guard_respects_override(self):
        """When _USER_PYTHON is set, the stale venv recreation guard is skipped."""
        # The guard: if MAC_INTEL=true && -z _USER_PYTHON && venv exists ...
        script = textwrap.dedent("""\
            MAC_INTEL=true
            _USER_PYTHON="$1"
            _VENV_EXISTS=true  # simulate existing venv

            SHOULD_RECREATE=false
            if [ "$MAC_INTEL" = true ] && [ -z "$_USER_PYTHON" ] && [ "$_VENV_EXISTS" = true ]; then
                SHOULD_RECREATE=true
            fi
            echo "$SHOULD_RECREATE"
            """)
        # With override: should NOT recreate
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", "3.11"],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"false"

        # Without override: SHOULD recreate
        r = subprocess.run(
            ["bash", "-c", script + "\n", "_", ""],
            capture_output=True,
            timeout=10,
        )
        assert r.stdout.strip() == b"true"


# ===========================================================================
# Group 6: install_python_stack.py NO_TORCH Filtering (4 tests)
# ===========================================================================


class TestInstallPythonStackFiltering:
    """Test the NO_TORCH filtering logic in install_python_stack.py.

    Importable because the module-level ``sys.path.insert(0, str(STUDIO_DIR))``
    above puts the studio directory on the path.
    """

    @pytest.fixture(autouse=True)
    def _check_install_py(self):
        if not INSTALL_PY.is_file():
            pytest.skip("install_python_stack.py not found")

    def test_filter_requirements_removes_torch_deps(self):
        """_filter_requirements removes all NO_TORCH_SKIP_PACKAGES from a real extras file."""
        import install_python_stack as ips

        extras = STUDIO_DIR / "backend" / "requirements" / "extras.txt"
        if not extras.is_file():
            pytest.skip("extras.txt not found")

        result_path = ips._filter_requirements(extras, ips.NO_TORCH_SKIP_PACKAGES)
        filtered = Path(result_path).read_text(encoding="utf-8").lower()

        # FIX: the requirement-line list is loop-invariant, so compute it once
        # instead of per package (the original rebuilt it inside the loop);
        # also renamed the ambiguous loop variable `l` (E741).
        lines = [
            line.strip()
            for line in filtered.splitlines()
            if line.strip() and not line.strip().startswith("#")
        ]
        for pkg in ["torch-stoi", "timm", "openai-whisper", "transformers-cfg"]:
            assert not any(
                line.startswith(pkg) for line in lines
            ), f"{pkg} should be removed from extras.txt"

    def test_filter_requirements_preserves_non_torch(self):
        """Non-torch packages survive NO_TORCH filtering."""
        import install_python_stack as ips

        extras = STUDIO_DIR / "backend" / "requirements" / "extras.txt"
        if not extras.is_file():
            pytest.skip("extras.txt not found")

        result_path = ips._filter_requirements(extras, ips.NO_TORCH_SKIP_PACKAGES)
        filtered_text = Path(result_path).read_text(encoding="utf-8").lower()

        # Only assert survival for packages actually present in the original.
        must_survive = ["scikit-learn", "loguru", "tiktoken", "einops"]
        original_text = extras.read_text(encoding="utf-8").lower()
        for pkg in must_survive:
            if pkg in original_text:
                assert pkg in filtered_text, f"{pkg} should survive NO_TORCH filtering"

    def test_infer_no_torch_env_var_overrides_platform(self):
        """UNSLOTH_NO_TORCH=true on Linux -> True; =false on Intel Mac -> False."""
        import install_python_stack as ips

        # Explicit true on Linux
        with (
            mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "true"}),
            mock.patch.object(ips, "IS_MAC_INTEL", False),
        ):
            assert ips._infer_no_torch() is True

        # Explicit false on Intel Mac
        with (
            mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}),
            mock.patch.object(ips, "IS_MAC_INTEL", True),
        ):
            assert ips._infer_no_torch() is False

        # Unset on Intel Mac -> True (platform fallback)
        env = os.environ.copy()
        env.pop("UNSLOTH_NO_TORCH", None)
        with (
            mock.patch.dict(os.environ, env, clear=True),
            mock.patch.object(ips, "IS_MAC_INTEL", True),
        ):
            assert ips._infer_no_torch() is True

    def test_no_torch_skips_overrides_and_triton(self):
        """When NO_TORCH=True, overrides.txt and triton are skipped (source guard check)."""
        import install_python_stack as ips

        source = Path(ips.__file__).read_text(encoding="utf-8")

        # NO_TORCH guard before overrides
        assert (
            "if NO_TORCH:" in source
        ), "NO_TORCH guard not found in install_python_stack.py"

        # macOS guard for triton
        assert (
            "not IS_WINDOWS and not IS_MACOS" in source
        ), "'not IS_WINDOWS and not IS_MACOS' guard for triton not found"
# Group 7: Live Server Startup (4 tests) -- Heavyweight
# ===========================================================================


def _studio_venv_python() -> Path | None:
    """Return the studio venv Python path, or None if not found."""
    py = _venv_python(STUDIO_VENV)
    if py.exists():
        return py
    return None


def _server_port() -> int:
    """Find an available port for the test server."""
    import socket

    # Bind to port 0 so the OS picks a free port; the socket is closed before
    # the server starts, so a rare reuse race is possible but acceptable here.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


server = pytest.mark.server


@server
class TestLiveServerStartup:
    """Live server startup tests.

    These use the existing Studio venv at ~/.unsloth/studio/unsloth_studio.
    They temporarily ensure torch is not importable, test server startup,
    then leave the venv unchanged.

    Run separately: pytest -m server
    """

    @pytest.fixture(autouse = True)
    def _check_studio_venv(self):
        py = _studio_venv_python()
        if py is None:
            pytest.skip("Studio venv not found at ~/.unsloth/studio/unsloth_studio")

    @pytest.fixture(scope = "class")
    def server_process(self):
        """Start the studio backend server without torch, yield (proc, port), then stop."""
        py = _studio_venv_python()
        if py is None:
            pytest.skip("Studio venv not found")

        port = _server_port()
        backend_dir = BACKEND_DIR

        # Check if torch is installed in the studio venv
        check = subprocess.run(
            [str(py), "-c", "import torch; print(torch.__version__)"],
            capture_output = True,
            timeout = 60,  # fixed: this was the only subprocess call here with no timeout
        )
        torch_was_installed = check.returncode == 0
        torch_version = check.stdout.decode().strip() if torch_was_installed else None

        # Uninstall torch if present
        if torch_was_installed:
            subprocess.run(
                [
                    str(py),
                    "-m",
                    "pip",
                    "uninstall",
                    "-y",
                    "torch",
                    "torchvision",
                    "torchaudio",
                ],
                capture_output = True,
                timeout = 120,
            )

        # Start server
        env = os.environ.copy()
        # Fixed: prepend to any caller-provided PYTHONPATH instead of clobbering it.
        env["PYTHONPATH"] = str(backend_dir) + (
            os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else ""
        )
        proc = subprocess.Popen(
            [str(py), str(backend_dir / "run.py"), "--port", str(port)],
            env = env,
            stdout = subprocess.PIPE,
            stderr = subprocess.PIPE,
            cwd = str(backend_dir),
        )

        # Wait for server to be ready (poll /api/health)
        import urllib.request
        import urllib.error

        ready = False
        for _ in range(30):
            time.sleep(1)
            try:
                resp = urllib.request.urlopen(
                    f"http://127.0.0.1:{port}/api/health", timeout = 2
                )
                if resp.status == 200:
                    ready = True
                    break
            except (urllib.error.URLError, ConnectionRefusedError, OSError):
                continue

        if not ready:
            # Fixed: the process may still be running (merely unhealthy); calling
            # communicate() on a live process with PIPEs would raise TimeoutExpired
            # and skip the torch reinstall below. Stop it first, then drain output.
            proc.terminate()
            try:
                stdout, stderr = proc.communicate(timeout = 5)
            except subprocess.TimeoutExpired:
                proc.kill()
                stdout, stderr = proc.communicate(timeout = 5)
            # Reinstall torch + torchvision + torchaudio
            if torch_was_installed and torch_version:
                subprocess.run(
                    [
                        str(py),
                        "-m",
                        "pip",
                        "install",
                        f"torch=={torch_version}",
                        "torchvision",
                        "torchaudio",
                    ],
                    capture_output = True,
                    timeout = 300,
                )
            server_output = stdout.decode(errors = "replace") + stderr.decode(
                errors = "replace"
            )
            pytest.skip(
                f"Server failed to start within 30 seconds. Output:\n{server_output}"
            )

        yield proc, port

        # Cleanup: stop server, reinstall torch
        proc.terminate()
        try:
            proc.wait(timeout = 10)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout = 5)

        if torch_was_installed and torch_version:
            subprocess.run(
                [
                    str(py),
                    "-m",
                    "pip",
                    "install",
                    f"torch=={torch_version}",
                    "torchvision",
                    "torchaudio",
                ],
                capture_output = True,
                timeout = 300,
            )

    def test_server_starts_without_torch(self, server_process):
        """Server responds to /api/health with chat_only: true."""
        import json
        import urllib.request

        _, port = server_process
        resp = urllib.request.urlopen(f"http://127.0.0.1:{port}/api/health", timeout = 5)
        data = json.loads(resp.read())
        assert data["status"] == "healthy"
        assert data["chat_only"] is True

    def test_all_routes_registered(self, server_process):
        """OpenAPI spec shows >= 20 paths (server started fully)."""
        import json
        import urllib.request

        _, port = server_process
        resp = urllib.request.urlopen(
            f"http://127.0.0.1:{port}/openapi.json", timeout = 5
        )
        spec = json.loads(resp.read())
        assert (
            len(spec.get("paths", {})) >= 20
        ), f"Expected >= 20 routes, got {len(spec.get('paths', {}))}"

    def test_hardware_endpoint_no_torch(self, server_process):
        """GET /api/system/hardware returns torch=null, gpu_name=null."""
        import json
        import urllib.request

        _, port = server_process
        resp = urllib.request.urlopen(
            f"http://127.0.0.1:{port}/api/system/hardware",
            timeout = 5,
        )
        data = json.loads(resp.read())
        versions = data.get("versions", {})
        assert versions.get("torch") is None
        assert versions.get("cuda") is None

    def test_server_survives_multiple_requests(self, server_process):
        """Hit 5 different endpoints. Server PID should still be alive after."""
Server PID should still be alive after.""" + import urllib.request + import urllib.error + + proc, port = server_process + endpoints = [ + "/api/health", + "/openapi.json", + "/api/system/hardware", + "/api/health", + "/docs", + ] + for ep in endpoints: + try: + urllib.request.urlopen(f"http://127.0.0.1:{port}{ep}", timeout = 5) + except urllib.error.HTTPError: + pass # 4xx/5xx is fine -- server didn't crash + except urllib.error.URLError: + pytest.fail(f"Server stopped responding at {ep}") + + assert proc.poll() is None, "Server process should still be running" diff --git a/tests/python/test_no_torch_filtering.py b/tests/python/test_no_torch_filtering.py new file mode 100644 index 000000000..5c2926a1f --- /dev/null +++ b/tests/python/test_no_torch_filtering.py @@ -0,0 +1,753 @@ +"""Tests for install_python_stack NO_TORCH / IS_MACOS filtering logic. + +Covers: +- _filter_requirements unit tests (synthetic + REAL requirements files) +- NO_TORCH / IS_MACOS / IS_WINDOWS env var parsing +- Subprocess-mock of install_python_stack() to verify overrides/triton/filtering + actually happen (or get skipped) under each platform/config combination +- VCS URL and environment marker edge cases in filtering +""" + +from __future__ import annotations + +import importlib +import os +import re +import subprocess +import sys +import textwrap +from pathlib import Path +from unittest import mock + +import pytest + +# Add the studio directory so we can import install_python_stack +STUDIO_DIR = Path(__file__).resolve().parents[2] / "studio" +sys.path.insert(0, str(STUDIO_DIR)) + +import install_python_stack as ips + +# Paths to the REAL requirements files +REQ_ROOT = Path(__file__).resolve().parents[2] / "studio" / "backend" / "requirements" +EXTRAS_TXT = REQ_ROOT / "extras.txt" +EXTRAS_NO_DEPS_TXT = REQ_ROOT / "extras-no-deps.txt" +OVERRIDES_TXT = REQ_ROOT / "overrides.txt" +TRITON_KERNELS_TXT = REQ_ROOT / "triton-kernels.txt" + + +# ── _filter_requirements unit tests (synthetic) 


class TestFilterRequirements:
    """Verify _filter_requirements correctly removes packages by prefix."""

    def _write_req(self, tmp_path: Path, content: str) -> Path:
        # Helper: write a dedented synthetic requirements file into tmp_path.
        req = tmp_path / "requirements.txt"
        req.write_text(textwrap.dedent(content), encoding = "utf-8")
        return req

    def test_filters_no_torch_packages(self, tmp_path):
        req = self._write_req(
            tmp_path,
            """\
            torch-stoi==0.1
            timm>=1.0
            numpy
            torchcodec>=0.1
            torch-c-dlpack-ext
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        # Only numpy should remain (non-blank lines)
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == ["numpy"], f"Expected only numpy, got: {non_blank}"

    def test_empty_file(self, tmp_path):
        req = self._write_req(tmp_path, "")
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        content = Path(result).read_text(encoding = "utf-8")
        assert content.strip() == ""

    def test_comments_preserved(self, tmp_path):
        req = self._write_req(
            tmp_path,
            """\
            # torch-stoi is needed for audio
            numpy
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        # Comment starts with "#", not "torch-stoi", so it's preserved
        assert len(non_blank) == 2
        assert non_blank[0].startswith("#")
        assert non_blank[1] == "numpy"

    def test_version_specifiers_filtered(self, tmp_path):
        req = self._write_req(
            tmp_path,
            """\
            torch-stoi>=0.1.0
            timm==1.2.3
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == [], f"Expected empty, got: {non_blank}"

    def test_prefix_match_catches_extensions(self, tmp_path):
        """Prefix matching catches torch-stoi-extra (correct for pip names)."""
        req = self._write_req(
            tmp_path,
            """\
            torch-stoi-extra
            numpy
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == ["numpy"]

    def test_mixed_case_filtered(self, tmp_path):
        """Package names are lowercased before matching."""
        req = self._write_req(
            tmp_path,
            """\
            Timm>=1.0
            TORCH-STOI
            numpy
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == ["numpy"]

    def test_whitespace_and_blank_lines_preserved(self, tmp_path):
        req = self._write_req(
            tmp_path,
            """\
            numpy

            pandas

            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        content = Path(result).read_text(encoding = "utf-8")
        # Blank lines should be preserved (not stripped)
        assert "\n\n" in content or content.count("\n") >= 3

    def test_stacked_windows_and_no_torch_filters(self, tmp_path):
        """Both WINDOWS_SKIP_PACKAGES and NO_TORCH_SKIP_PACKAGES applied."""
        req = self._write_req(
            tmp_path,
            """\
            open_spiel
            triton_kernels
            torch-stoi
            timm
            numpy
            """,
        )
        # First filter Windows packages, then NO_TORCH packages
        intermediate = ips._filter_requirements(req, ips.WINDOWS_SKIP_PACKAGES)
        result = ips._filter_requirements(
            Path(intermediate), ips.NO_TORCH_SKIP_PACKAGES
        )
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == [
            "numpy"
        ], f"Expected only numpy after stacked filters, got: {non_blank}"

    def test_vcs_url_with_skip_package_name(self, tmp_path):
        """VCS URLs like git+https://...torch-stoi should also be filtered (startswith matches)."""
        req = self._write_req(
            tmp_path,
            """\
            numpy
            torch-stoi @ git+https://github.com/example/torch-stoi.git
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == [
            "numpy"
        ], f"VCS URL line should be filtered, got: {non_blank}"

    def test_env_marker_line_filtered(self, tmp_path):
        """Package lines with env markers are still filtered by prefix."""
        req = self._write_req(
            tmp_path,
            """\
            timm>=1.0; python_version>="3.10"
            numpy
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        assert non_blank == [
            "numpy"
        ], f"Env marker line should be filtered, got: {non_blank}"

    def test_git_plus_url_not_over_matched(self, tmp_path):
        """A git+ URL whose path contains a skip package name but does NOT start with it."""
        req = self._write_req(
            tmp_path,
            """\
            git+https://github.com/meta-pytorch/OpenEnv.git
            numpy
            """,
        )
        result = ips._filter_requirements(req, ips.NO_TORCH_SKIP_PACKAGES)
        lines = Path(result).read_text(encoding = "utf-8").splitlines()
        non_blank = [l.strip() for l in lines if l.strip()]
        # The git+ URL doesn't start with any skip package, so it is preserved
        assert len(non_blank) == 2, f"git+ URL should be preserved, got: {non_blank}"


# ── Real requirements file filtering ──────────────────────────────────


class TestRealRequirementsFiltering:
    """Filter the ACTUAL extras.txt and extras-no-deps.txt with NO_TORCH_SKIP_PACKAGES."""

    @pytest.fixture(autouse = True)
    def _check_req_files(self):
        if not EXTRAS_TXT.is_file():
            pytest.skip("extras.txt not found in repo")
        if not EXTRAS_NO_DEPS_TXT.is_file():
            pytest.skip("extras-no-deps.txt not found in repo")

    def _non_blank_non_comment(self, path: Path) -> list[str]:
        """Return non-blank, non-comment lines from a requirements file."""
        lines = path.read_text(encoding = "utf-8").splitlines()
        return [l.strip() for l in lines if l.strip() and not l.strip().startswith("#")]

    def test_extras_txt_torch_packages_removed(self):
        """extras.txt: all NO_TORCH_SKIP_PACKAGES must be removed, everything else preserved."""
        result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
        filtered = self._non_blank_non_comment(Path(result))
        original = self._non_blank_non_comment(EXTRAS_TXT)

        # These must be gone
        for pkg in ["torch-stoi", "timm", "openai-whisper", "transformers-cfg"]:
            assert not any(
                l.lower().startswith(pkg) for l in filtered
            ), f"{pkg} should be removed from extras.txt"

        # Everything else must remain
        expected = [
            l
            for l in original
            if not any(
                l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
            )
        ]
        assert filtered == expected, (
            f"Filtered extras.txt should match expected.\n"
            f"Missing: {set(expected) - set(filtered)}\n"
            f"Extra: {set(filtered) - set(expected)}"
        )

    def test_extras_no_deps_txt_torchcodec_and_dlpack_removed(self):
        """extras-no-deps.txt: torchcodec and torch-c-dlpack-ext must be removed."""
        result = ips._filter_requirements(
            EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
        )
        filtered = self._non_blank_non_comment(Path(result))
        original = self._non_blank_non_comment(EXTRAS_NO_DEPS_TXT)

        for pkg in ["torchcodec", "torch-c-dlpack-ext"]:
            assert not any(
                l.lower().startswith(pkg) for l in filtered
            ), f"{pkg} should be removed from extras-no-deps.txt"

        expected = [
            l
            for l in original
            if not any(
                l.strip().lower().startswith(p) for p in ips.NO_TORCH_SKIP_PACKAGES
            )
        ]
        assert filtered == expected

    def test_extras_txt_most_packages_preserved(self):
        """Ensure a representative set of non-torch packages survive filtering."""
        result = ips._filter_requirements(EXTRAS_TXT, ips.NO_TORCH_SKIP_PACKAGES)
        filtered_text = Path(result).read_text(encoding = "utf-8").lower()

        must_survive = ["scikit-learn", "loguru", "tiktoken", "einops", "tabulate"]
        for pkg in must_survive:
            # Only asserted when the package is actually listed in extras.txt.
            if pkg in EXTRAS_TXT.read_text(encoding = "utf-8").lower():
                assert pkg in filtered_text, f"{pkg} should survive NO_TORCH filtering"

    def test_extras_no_deps_txt_trl_preserved(self):
        """trl should survive NO_TORCH filtering in extras-no-deps.txt."""
        result = ips._filter_requirements(
            EXTRAS_NO_DEPS_TXT, ips.NO_TORCH_SKIP_PACKAGES
        )
        filtered_text = Path(result).read_text(encoding = "utf-8").lower()
        assert "trl" in filtered_text, "trl should survive NO_TORCH filtering"


# ── NO_TORCH constant tests ──────────────────────────────────────────


class TestNoTorchConstant:
    """Verify NO_TORCH is derived correctly from UNSLOTH_NO_TORCH env var."""

    def _reimport_no_torch(self) -> bool:
        # Mirrors the env-var parsing expression used by the installer.
        return os.environ.get("UNSLOTH_NO_TORCH", "false").lower() in ("1", "true")

    def test_true_lowercase(self):
        with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "true"}):
            assert self._reimport_no_torch() is True

    def test_true_one(self):
        with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "1"}):
            assert self._reimport_no_torch() is True

    def test_true_uppercase(self):
        with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "TRUE"}):
            assert self._reimport_no_torch() is True

    def test_false_string(self):
        with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}):
            assert self._reimport_no_torch() is False

    def test_false_zero(self):
        with mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "0"}):
            assert self._reimport_no_torch() is False

    def test_not_set(self):
        env = os.environ.copy()
        env.pop("UNSLOTH_NO_TORCH", None)
        with mock.patch.dict(os.environ, env, clear = True):
            assert self._reimport_no_torch() is False

    def test_infer_no_torch_on_intel_mac(self):
        """_infer_no_torch falls back to platform detection when env var is unset."""
        env = os.environ.copy()
        env.pop("UNSLOTH_NO_TORCH", None)
        with (
            mock.patch.dict(os.environ, env, clear = True),
            mock.patch.object(ips, "IS_MAC_INTEL", True),
        ):
            assert ips._infer_no_torch() is True

    def test_infer_no_torch_respects_explicit_false_on_intel_mac(self):
        """Explicit UNSLOTH_NO_TORCH=false overrides platform detection."""
        with (
            mock.patch.dict(os.environ, {"UNSLOTH_NO_TORCH": "false"}),
            mock.patch.object(ips, "IS_MAC_INTEL", True),
        ):
            assert ips._infer_no_torch() is False

    def test_infer_no_torch_linux_unset(self):
        """On Linux with env var unset, _infer_no_torch returns False."""
        env = os.environ.copy()
        env.pop("UNSLOTH_NO_TORCH", None)
        with (
            mock.patch.dict(os.environ, env, clear = True),
            mock.patch.object(ips, "IS_MAC_INTEL", False),
        ):
            assert ips._infer_no_torch() is False


# ── IS_MACOS constant tests ──────────────────────────────────────────


class TestIsMacosConstant:
    """Verify IS_MACOS detection logic."""

    def test_is_macos_matches_platform(self):
        import sys

        expected = sys.platform == "darwin"
        assert ips.IS_MACOS is expected


# ── Subprocess mock of install_python_stack() ─────────────────────────


class TestInstallPythonStackSubprocessMock:
    """Monkeypatch subprocess.run to capture all pip/uv commands,
    then verify which requirements files are used/skipped under
    different NO_TORCH / IS_MACOS / IS_WINDOWS configurations."""

    @pytest.fixture(autouse = True)
    def _check_req_files(self):
        """Skip if requirements files are missing."""
        for f in [EXTRAS_TXT, EXTRAS_NO_DEPS_TXT, OVERRIDES_TXT]:
            if not f.is_file():
                pytest.skip(f"{f.name} not found in repo")

    def _capture_install(
        self,
        no_torch: bool,
        is_macos: bool,
        is_windows: bool,
        *,
        skip_base: bool = True,
    ):
        """Run install_python_stack() with mocked subprocess, capturing all commands.

        Returns a list of string-joined commands (each element is ' '.join(cmd)).
        """
        captured_cmds: list[list[str]] = []

        def mock_run(cmd, **kw):
            # Record every subprocess invocation and pretend it succeeded.
            captured_cmds.append(
                list(cmd) if isinstance(cmd, (list, tuple)) else [str(cmd)]
            )
            return subprocess.CompletedProcess(cmd, 0, b"", b"")

        env = {"SKIP_STUDIO_BASE": "1"} if skip_base else {}

        # Freeze the platform/config flags and stub out every external side
        # effect so install_python_stack() runs fully in-process.
        with (
            mock.patch.object(ips, "NO_TORCH", no_torch),
            mock.patch.object(ips, "IS_MACOS", is_macos),
            mock.patch.object(ips, "IS_WINDOWS", is_windows),
            mock.patch.object(ips, "USE_UV", True),
            mock.patch.object(ips, "UV_NEEDS_SYSTEM", False),
            mock.patch.object(ips, "VERBOSE", False),
            mock.patch("subprocess.run", side_effect = mock_run),
            mock.patch.object(ips, "_bootstrap_uv", return_value = True),
            mock.patch.object(
                ips, "LOCAL_DD_UNSTRUCTURED_PLUGIN", Path("/fake/plugin")
            ),
            mock.patch("pathlib.Path.is_dir", return_value = True),
            mock.patch("pathlib.Path.is_file", return_value = True),
        ):
            with mock.patch.dict(os.environ, env, clear = False):
                ips.install_python_stack()

        return [" ".join(str(c) for c in cmd) for cmd in captured_cmds]

    def _cmds_contain_file(self, cmds: list[str], filename: str) -> bool:
        """Check if any captured command references the given filename."""
        return any(filename in cmd for cmd in cmds)

    # -- NO_TORCH=True, IS_MACOS=True (Intel Mac scenario) --

    def test_no_torch_macos_skips_overrides(self):
        """With NO_TORCH=True, overrides.txt pip_install must NOT be called."""
        cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
        assert not self._cmds_contain_file(
            cmds, "overrides.txt"
        ), "overrides.txt should be skipped when NO_TORCH=True"

    def test_no_torch_macos_skips_triton(self):
        """With IS_MACOS=True, triton-kernels.txt must NOT be called."""
        cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
        assert not self._cmds_contain_file(
            cmds, "triton-kernels.txt"
        ), "triton-kernels.txt should be skipped on macOS"

    def test_no_torch_macos_extras_called(self):
        """With NO_TORCH=True, extras.txt is still called (but filtered)."""
        cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
        has_extras = self._cmds_contain_file(cmds, "extras.txt") or any(
            "-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
        )
        assert has_extras, "extras.txt (or its filtered temp) should be called"

    def test_no_torch_macos_extras_no_deps_called(self):
        """With NO_TORCH=True, extras-no-deps.txt is still called (but filtered)."""
        cmds = self._capture_install(no_torch = True, is_macos = True, is_windows = False)
        has_extras_nd = self._cmds_contain_file(cmds, "extras-no-deps.txt") or any(
            "-r" in cmd and "tmp" in cmd.lower() for cmd in cmds
        )
        assert (
            has_extras_nd
        ), "extras-no-deps.txt (or its filtered temp) should be called"

    # -- IS_WINDOWS=True + NO_TORCH=True (stacked) --

    def test_windows_no_torch_skips_overrides(self):
        """Windows+NO_TORCH: overrides.txt must be skipped."""
        cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
        assert not self._cmds_contain_file(
            cmds, "overrides.txt"
        ), "overrides.txt should be skipped with NO_TORCH=True on Windows"

    def test_windows_no_torch_skips_triton(self):
        """Windows: triton-kernels.txt must be skipped (IS_WINDOWS guard)."""
        cmds = self._capture_install(no_torch = True, is_macos = False, is_windows = True)
        assert not self._cmds_contain_file(
            cmds, "triton-kernels.txt"
        ), "triton-kernels.txt should be skipped on Windows"

    # -- Normal Linux path (NO_TORCH=False, IS_MACOS=False, IS_WINDOWS=False) --

    def test_normal_linux_includes_overrides(self):
        """Normal Linux: overrides.txt IS called."""
        cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
        assert self._cmds_contain_file(
            cmds, "overrides.txt"
        ), "overrides.txt should be called on normal Linux"

    def test_normal_linux_includes_triton(self):
        """Normal Linux: triton-kernels.txt IS called."""
        cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
        assert self._cmds_contain_file(
            cmds, "triton-kernels.txt"
        ), "triton-kernels.txt should be called on normal Linux"

    def test_normal_linux_includes_extras(self):
        """Normal Linux: extras.txt IS called (no filtering)."""
        cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
        assert self._cmds_contain_file(
            cmds, "extras.txt"
        ), "extras.txt should be called on normal Linux"

    def test_normal_linux_includes_extras_no_deps(self):
        """Normal Linux: extras-no-deps.txt IS called (no filtering)."""
        cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = False)
        assert self._cmds_contain_file(
            cmds, "extras-no-deps.txt"
        ), "extras-no-deps.txt should be called on normal Linux"

    # -- Windows-only (NO_TORCH=False) to verify triton is still skipped --

    def test_windows_only_skips_triton(self):
        """Windows (without NO_TORCH): triton still skipped."""
        cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
        assert not self._cmds_contain_file(
            cmds, "triton-kernels.txt"
        ), "triton-kernels.txt should be skipped on Windows even without NO_TORCH"

    def test_windows_only_includes_overrides(self):
        """Windows (without NO_TORCH): overrides IS called (via filtered temp file).

        On Windows, all req files go through _filter_requirements(WINDOWS_SKIP_PACKAGES),
        so the command uses a temp file, not overrides.txt directly. We check for
        --reinstall (uv translation of --force-reinstall) which is unique to overrides.
        """
        cmds = self._capture_install(no_torch = False, is_macos = False, is_windows = True)
        assert any(
            "--reinstall" in cmd for cmd in cmds
        ), "overrides step (--reinstall) should be called on Windows when NO_TORCH=False"

    # -- Update path (skip_base=False) to verify no-torch mode is durable --

    def test_update_path_intel_macos_still_skips_overrides(self):
        """Update path (no SKIP_STUDIO_BASE): overrides still skipped on Intel Mac."""
        cmds = self._capture_install(
            no_torch = True, is_macos = True, is_windows = False, skip_base = False
        )
        assert not self._cmds_contain_file(
            cmds, "overrides.txt"
        ), "overrides.txt should be skipped on Intel Mac even via studio update"

    def test_update_path_intel_macos_still_skips_triton(self):
        """Update path (no SKIP_STUDIO_BASE): triton still skipped on macOS."""
        cmds = self._capture_install(
            no_torch = True, is_macos = True, is_windows = False, skip_base = False
        )
        assert not self._cmds_contain_file(
            cmds, "triton-kernels.txt"
        ), "triton-kernels.txt should be skipped on macOS even via studio update"


# ── Overrides skip structural checks ─────────────────────────────────


class TestOverridesSkip:
    """Verify overrides.txt is skipped when NO_TORCH is True (source-level check)."""

    def test_no_torch_guard_exists_in_source(self):
        """The install_python_stack source must contain a NO_TORCH guard around overrides."""
        source = Path(ips.__file__).read_text(encoding = "utf-8")
        assert (
            "if NO_TORCH:" in source
        ), "NO_TORCH guard not found in install_python_stack.py"

    def test_overrides_skipped_when_no_torch(self):
        """With NO_TORCH=True on the module, pip_install should NOT be called for overrides."""
        source = Path(ips.__file__).read_text(encoding = "utf-8")
        overrides_match = re.search(r"if NO_TORCH:.*?overrides", source, re.DOTALL)
        assert (
            overrides_match is not None
        ), "Expected NO_TORCH conditional before overrides install"


# ── install.sh --no-torch flag tests ──────────────────────────────────
tests ────────────────────────────────── + + +class TestInstallShNoTorchFlag: + """Verify install.sh has the --no-torch flag and SKIP_TORCH variable.""" + + @pytest.fixture(autouse = True) + def _check_install_sh(self): + install_sh = Path(__file__).resolve().parents[2] / "install.sh" + if not install_sh.is_file(): + pytest.skip("install.sh not found") + self.install_sh = install_sh + self.source = install_sh.read_text(encoding = "utf-8") + + def test_no_torch_flag_in_case_statement(self): + """--no-torch must appear in the flag parser case statement.""" + assert ( + "--no-torch)" in self.source + ), "--no-torch not found in install.sh flag parser" + + def test_no_torch_flag_variable_initialized(self): + """_NO_TORCH_FLAG must be initialized to false.""" + assert ( + "_NO_TORCH_FLAG=false" in self.source + ), "_NO_TORCH_FLAG=false not found in install.sh" + + def test_skip_torch_variable_exists(self): + """SKIP_TORCH variable must be defined.""" + assert ( + "SKIP_TORCH=false" in self.source + ), "SKIP_TORCH=false not found in install.sh" + assert ( + "SKIP_TORCH=true" in self.source + ), "SKIP_TORCH=true not found in install.sh" + + def test_skip_torch_driven_by_flag_and_mac_intel(self): + """SKIP_TORCH must check both _NO_TORCH_FLAG and MAC_INTEL.""" + assert ( + "_NO_TORCH_FLAG" in self.source + ), "_NO_TORCH_FLAG not referenced in SKIP_TORCH logic" + assert ( + "MAC_INTEL" in self.source + ), "MAC_INTEL not referenced in SKIP_TORCH logic" + + def test_unsloth_no_torch_uses_skip_torch(self): + """UNSLOTH_NO_TORCH must reference $SKIP_TORCH, not $MAC_INTEL.""" + import re + + matches = re.findall(r'UNSLOTH_NO_TORCH="\$(\w+)"', self.source) + for var in matches: + assert ( + var == "SKIP_TORCH" + ), f"UNSLOTH_NO_TORCH references ${var} instead of $SKIP_TORCH" + + def test_cpu_hint_message_exists(self): + """CPU hint message must exist in install.sh.""" + assert ( + "No NVIDIA GPU detected" in self.source + ), "CPU hint message not found in install.sh" + assert ( + 
"--no-torch" in self.source + ), "--no-torch suggestion not found in CPU hint" + + def test_no_torch_flag_parsing_subprocess(self): + """--no-torch flag sets _NO_TORCH_FLAG=true (subprocess test).""" + script = textwrap.dedent("""\ + _NO_TORCH_FLAG=false + _next_is_package=false + STUDIO_LOCAL_INSTALL=false + PACKAGE_NAME="unsloth" + for arg in "$@"; do + if [ "$_next_is_package" = true ]; then + PACKAGE_NAME="$arg" + _next_is_package=false + continue + fi + case "$arg" in + --local) STUDIO_LOCAL_INSTALL=true ;; + --package) _next_is_package=true ;; + --no-torch) _NO_TORCH_FLAG=true ;; + esac + done + echo "$_NO_TORCH_FLAG" + """) + result = subprocess.run( + ["bash", "-c", script, "_", "--no-torch"], + capture_output = True, + text = True, + ) + assert ( + result.stdout.strip() == "true" + ), f"Expected _NO_TORCH_FLAG=true, got: {result.stdout.strip()}" + + def test_no_torch_with_local_flag(self): + """--no-torch and --local can be used together.""" + script = textwrap.dedent("""\ + _NO_TORCH_FLAG=false + _next_is_package=false + STUDIO_LOCAL_INSTALL=false + PACKAGE_NAME="unsloth" + for arg in "$@"; do + if [ "$_next_is_package" = true ]; then + PACKAGE_NAME="$arg" + _next_is_package=false + continue + fi + case "$arg" in + --local) STUDIO_LOCAL_INSTALL=true ;; + --package) _next_is_package=true ;; + --no-torch) _NO_TORCH_FLAG=true ;; + esac + done + echo "$_NO_TORCH_FLAG $STUDIO_LOCAL_INSTALL" + """) + result = subprocess.run( + ["bash", "-c", script, "_", "--local", "--no-torch"], + capture_output = True, + text = True, + ) + assert ( + result.stdout.strip() == "true true" + ), f"Expected 'true true', got: {result.stdout.strip()}" + + def test_cpu_hint_only_when_not_skip_torch(self): + """CPU hint should only print when SKIP_TORCH=false and OS!=macos.""" + script = textwrap.dedent("""\ + TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu" + SKIP_TORCH=false + OS="linux" + case "$TORCH_INDEX_URL" in + */cpu) + if [ "$SKIP_TORCH" = false ] && [ "$OS" != "macos" 
]; then + echo "HINT_PRINTED" + fi + ;; + esac + """) + result = subprocess.run( + ["bash", "-c", script], + capture_output = True, + text = True, + ) + assert "HINT_PRINTED" in result.stdout, "CPU hint should print" + + # With SKIP_TORCH=true, hint should NOT print + script2 = script.replace("SKIP_TORCH=false", "SKIP_TORCH=true") + result2 = subprocess.run( + ["bash", "-c", script2], + capture_output = True, + text = True, + ) + assert ( + "HINT_PRINTED" not in result2.stdout + ), "CPU hint should NOT print when SKIP_TORCH=true" + + +# ── Triton macOS skip structural checks ────────────────────────────── + + +class TestTritonMacosSkip: + """Verify triton is skipped on macOS (source-level check).""" + + def test_triton_guard_in_source(self): + """Source must skip triton on both Windows and macOS.""" + source = Path(ips.__file__).read_text(encoding = "utf-8") + assert ( + "not IS_MACOS" in source + ), "IS_MACOS guard for triton not found in install_python_stack.py" + assert ( + "not IS_WINDOWS and not IS_MACOS" in source + ), "Expected 'not IS_WINDOWS and not IS_MACOS' guard for triton" diff --git a/tests/python/test_studio_import_no_torch.py b/tests/python/test_studio_import_no_torch.py new file mode 100644 index 000000000..5592a282f --- /dev/null +++ b/tests/python/test_studio_import_no_torch.py @@ -0,0 +1,582 @@ +"""End-to-end sandbox tests: Studio modules in isolated no-torch venvs. 
+ +Covers: +- Python 3.12 and 3.13 venv creation (Intel Mac uses 3.12, Apple Silicon/Linux 3.13) +- data_collators.py loads and dataclasses instantiate without torch +- chat_templates.py top-level exec works with stubs for relative imports +- Negative control: prepending 'import torch' fails in no-torch venv +- Negative control: installing torchao (from overrides.txt) fails in no-torch venv +- AST structural checks for top-level torch imports +""" + +from __future__ import annotations + +import ast +import os +import shutil +import subprocess +import sys +import tempfile +import textwrap +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] +DATA_COLLATORS = ( + REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "data_collators.py" +) +CHAT_TEMPLATES = ( + REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "chat_templates.py" +) +FORMAT_CONVERSION = ( + REPO_ROOT / "studio" / "backend" / "utils" / "datasets" / "format_conversion.py" +) + + +def _has_uv() -> bool: + return shutil.which("uv") is not None + + +def _create_venv(venv_dir: Path, python_version: str) -> Path | None: + """Create a uv venv at the given Python version. Returns python path or None.""" + result = subprocess.run( + ["uv", "venv", str(venv_dir), "--python", python_version], + capture_output = True, + ) + if result.returncode != 0: + return None + venv_python = venv_dir / "bin" / "python" + if not venv_python.exists(): + venv_python = venv_dir / "Scripts" / "python.exe" + return venv_python if venv_python.exists() else None + + +@pytest.fixture(params = ["3.12", "3.13"], scope = "module") +def no_torch_venv(request, tmp_path_factory): + """Create a temporary venv at the requested Python version with no torch. + + Parametrized for 3.12 (Intel Mac) and 3.13 (Apple Silicon / Linux). 
+ """ + if not _has_uv(): + pytest.skip("uv not available") + + py_version = request.param + venv_dir = tmp_path_factory.mktemp(f"no_torch_venv_{py_version}") + venv_python = _create_venv(venv_dir, py_version) + if venv_python is None: + pytest.skip(f"Could not create Python {py_version} venv") + + # Verify torch is NOT importable + check = subprocess.run( + [str(venv_python), "-c", "import torch"], + capture_output = True, + ) + assert ( + check.returncode != 0 + ), f"torch should NOT be importable in fresh {py_version} venv" + + return str(venv_python) + + +# ── AST structural checks ───────────────────────────────────────────── + + +class TestDataCollatorsAST: + """Static analysis: data_collators.py has no top-level torch imports.""" + + def test_ast_parse(self): + """data_collators.py must be valid Python syntax.""" + source = DATA_COLLATORS.read_text(encoding = "utf-8") + tree = ast.parse(source, filename = str(DATA_COLLATORS)) + assert tree is not None + + def test_no_top_level_torch_import(self): + """No top-level 'import torch' or 'from torch' statements.""" + source = DATA_COLLATORS.read_text(encoding = "utf-8") + tree = ast.parse(source) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Import): + for alias in node.names: + assert not alias.name.startswith( + "torch" + ), f"Top-level 'import {alias.name}' found at line {node.lineno}" + elif isinstance(node, ast.ImportFrom): + if node.module: + assert not node.module.startswith( + "torch" + ), f"Top-level 'from {node.module}' found at line {node.lineno}" + + +class TestChatTemplatesAST: + """Static analysis: chat_templates.py has no top-level torch imports.""" + + def test_ast_parse(self): + """chat_templates.py must be valid Python syntax.""" + source = CHAT_TEMPLATES.read_text(encoding = "utf-8") + tree = ast.parse(source, filename = str(CHAT_TEMPLATES)) + assert tree is not None + + def test_no_top_level_torch_import(self): + """No top-level 'import torch' or 'from torch' at module 
level.""" + source = CHAT_TEMPLATES.read_text(encoding = "utf-8") + tree = ast.parse(source) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Import): + for alias in node.names: + assert not alias.name.startswith( + "torch" + ), f"Top-level 'import {alias.name}' found at line {node.lineno}" + elif isinstance(node, ast.ImportFrom): + if node.module: + assert not node.module.startswith( + "torch" + ), f"Top-level 'from {node.module}' found at line {node.lineno}" + + def test_torch_imports_only_inside_functions(self): + """All 'from torch' imports must be inside function/method bodies.""" + source = CHAT_TEMPLATES.read_text(encoding = "utf-8") + tree = ast.parse(source) + torch_imports = [] + for node in ast.walk(tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + module = None + if isinstance(node, ast.ImportFrom): + module = node.module + elif isinstance(node, ast.Import): + module = node.names[0].name if node.names else None + if module and module.startswith("torch"): + torch_imports.append(node) + + top_level = set(id(n) for n in ast.iter_child_nodes(tree)) + for imp in torch_imports: + assert id(imp) not in top_level, ( + f"torch import at line {imp.lineno} is at top level" + " (should be inside a function)" + ) + + +# ── data_collators.py: exec + dataclass instantiation in no-torch venv ── + + +class TestDataCollatorsNoTorchVenv: + """Run data_collators.py in an isolated no-torch venv, verify classes load.""" + + def test_exec_in_no_torch_venv(self, no_torch_venv): + """data_collators.py executes in a venv without torch (with loggers stub).""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + print("OK: exec succeeded") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), 
f"data_collators.py failed in no-torch venv:\n{result.stderr.decode()}" + assert b"OK: exec succeeded" in result.stdout + + def test_dataclass_speech_collator_instantiable(self, no_torch_venv): + """DataCollatorSpeechSeq2SeqWithPadding can be instantiated with processor=None.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + obj = DataCollatorSpeechSeq2SeqWithPadding(processor=None) + assert obj.processor is None, "processor should be None" + print("OK: DataCollatorSpeechSeq2SeqWithPadding instantiated") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"DataCollatorSpeechSeq2SeqWithPadding failed:\n{result.stderr.decode()}" + assert b"OK: DataCollatorSpeechSeq2SeqWithPadding instantiated" in result.stdout + + def test_dataclass_deepseek_collator_instantiable(self, no_torch_venv): + """DeepSeekOCRDataCollator can be instantiated with processor=None.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + obj = DeepSeekOCRDataCollator(processor=None) + assert obj.processor is None, "processor should be None" + assert obj.max_length == 2048, "default max_length should be 2048" + assert obj.ignore_index == -100, "default ignore_index should be -100" + print("OK: DeepSeekOCRDataCollator instantiated") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"DeepSeekOCRDataCollator failed:\n{result.stderr.decode()}" + assert b"OK: DeepSeekOCRDataCollator instantiated" in result.stdout + + def test_dataclass_vlm_collator_instantiable(self, no_torch_venv): + 
"""VLMDataCollator can be instantiated with processor=None.""" + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({str(DATA_COLLATORS)!r}).read()) + obj = VLMDataCollator(processor=None) + assert obj.processor is None + assert obj.mask_input_tokens is True, "default mask_input_tokens should be True" + print("OK: VLMDataCollator instantiated") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"VLMDataCollator failed:\n{result.stderr.decode()}" + assert b"OK: VLMDataCollator instantiated" in result.stdout + + +# ── chat_templates.py: exec in no-torch venv ───────────────────────── + + +class TestChatTemplatesNoTorchVenv: + """Run chat_templates.py in an isolated no-torch venv with stubs.""" + + def test_exec_with_stubs(self, no_torch_venv): + """chat_templates.py top-level exec works with stubs for relative imports.""" + code = textwrap.dedent(f"""\ + import sys, types + + # Stub loggers + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})() + sys.modules['loggers'] = loggers + + # Stub relative imports (.format_detection, .model_mappings) + format_detection = types.ModuleType('format_detection') + format_detection.detect_dataset_format = lambda *a, **k: None + format_detection.detect_multimodal_dataset = lambda *a, **k: None + format_detection.detect_custom_format_heuristic = lambda *a, **k: None + sys.modules['format_detection'] = format_detection + + model_mappings = types.ModuleType('model_mappings') + model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}} + sys.modules['model_mappings'] = model_mappings + + # Read and transform the source: replace relative imports with absolute + source = open({str(CHAT_TEMPLATES)!r}).read() + 
source = source.replace('from .format_detection import', 'from format_detection import') + source = source.replace('from .model_mappings import', 'from model_mappings import') + + exec(source) + + # Verify module-level constants are defined + ns = dict(locals()) + assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined after exec" + print("OK: chat_templates.py exec succeeded") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"chat_templates.py failed in no-torch venv:\n{result.stderr.decode()}" + assert b"OK: chat_templates.py exec succeeded" in result.stdout + + def test_default_alpaca_template_defined(self, no_torch_venv): + """DEFAULT_ALPACA_TEMPLATE constant is accessible after exec.""" + code = textwrap.dedent(f"""\ + import sys, types + + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: type('L', (), {{'info': lambda s, m: None, 'warning': lambda s, m: None, 'debug': lambda s, m: None}})() + sys.modules['loggers'] = loggers + + format_detection = types.ModuleType('format_detection') + format_detection.detect_dataset_format = lambda *a, **k: None + format_detection.detect_multimodal_dataset = lambda *a, **k: None + format_detection.detect_custom_format_heuristic = lambda *a, **k: None + sys.modules['format_detection'] = format_detection + + model_mappings = types.ModuleType('model_mappings') + model_mappings.MODEL_TO_TEMPLATE_MAPPER = {{}} + sys.modules['model_mappings'] = model_mappings + + ns = {{}} + source = open({str(CHAT_TEMPLATES)!r}).read() + source = source.replace('from .format_detection import', 'from format_detection import') + source = source.replace('from .model_mappings import', 'from model_mappings import') + exec(source, ns) + + assert 'DEFAULT_ALPACA_TEMPLATE' in ns, "DEFAULT_ALPACA_TEMPLATE not defined" + assert 'Instruction' in ns['DEFAULT_ALPACA_TEMPLATE'], "Template content unexpected" + 
print("OK: DEFAULT_ALPACA_TEMPLATE defined and valid") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"DEFAULT_ALPACA_TEMPLATE check failed:\n{result.stderr.decode()}" + assert b"OK: DEFAULT_ALPACA_TEMPLATE defined and valid" in result.stdout + + +# ── format_conversion.py: AST + runtime tests ──────────────────────── + + +class TestFormatConversionAST: + """Static analysis: format_conversion.py torch imports are guarded.""" + + def test_ast_parse(self): + """format_conversion.py must be valid Python syntax.""" + source = FORMAT_CONVERSION.read_text(encoding = "utf-8") + tree = ast.parse(source, filename = str(FORMAT_CONVERSION)) + assert tree is not None + + def test_no_bare_torch_import_in_functions(self): + """All 'from torch' imports in function bodies must be inside try/except.""" + source = FORMAT_CONVERSION.read_text(encoding = "utf-8") + tree = ast.parse(source) + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + for child in ast.walk(node): + if ( + isinstance(child, ast.ImportFrom) + and child.module + and child.module.startswith("torch") + ): + # This torch import must be inside a Try node + found_in_try = False + for try_node in ast.walk(node): + if isinstance(try_node, ast.Try): + for try_child in ast.walk(try_node): + if try_child is child: + found_in_try = True + break + if found_in_try: + break + assert found_in_try, ( + f"torch import at line {child.lineno} in {node.name}() " + "is not inside a try/except block" + ) + + +class TestFormatConversionNoTorchVenv: + """Run format_conversion.py functions in a no-torch venv.""" + + def test_convert_chatml_to_alpaca_no_torch(self, no_torch_venv): + """convert_chatml_to_alpaca works without torch (via try/except ImportError).""" + code = textwrap.dedent(f"""\ + import sys, types + + # Stub loggers + loggers = types.ModuleType('loggers') + loggers.get_logger = 
lambda n: type('L', (), {{ + 'info': lambda s, m: None, + 'warning': lambda s, m: None, + 'debug': lambda s, m: None, + }})() + sys.modules['loggers'] = loggers + + # Stub datasets.IterableDataset (HF datasets, not torch) + datasets_mod = types.ModuleType('datasets') + datasets_mod.IterableDataset = type('IterableDataset', (), {{}}) + sys.modules['datasets'] = datasets_mod + + # Stub utils.hardware + utils_mod = types.ModuleType('utils') + hardware_mod = types.ModuleType('utils.hardware') + hardware_mod.dataset_map_num_proc = lambda n=None: 1 + utils_mod.hardware = hardware_mod + sys.modules['utils'] = utils_mod + sys.modules['utils.hardware'] = hardware_mod + + # Read and exec format_conversion.py + source = open({str(FORMAT_CONVERSION)!r}).read() + source = source.replace('from .format_detection import', 'from format_detection import') + ns = {{'__name__': '__test__'}} + exec(source, ns) + + # Test convert_chatml_to_alpaca with a simple dataset + class FakeDataset: + def map(self, fn, **kw): + result = fn({{ + 'messages': [[ + {{'role': 'user', 'content': 'Hello'}}, + {{'role': 'assistant', 'content': 'Hi there'}}, + ]] + }}) + return result + + result = ns['convert_chatml_to_alpaca'](FakeDataset()) + assert 'instruction' in result, f"Expected 'instruction' in result, got {{result.keys()}}" + assert result['instruction'] == ['Hello'] + assert result['output'] == ['Hi there'] + print("OK: convert_chatml_to_alpaca works without torch") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"convert_chatml_to_alpaca failed without torch:\n{result.stderr.decode()}" + assert b"OK: convert_chatml_to_alpaca works without torch" in result.stdout + + def test_convert_alpaca_to_chatml_no_torch(self, no_torch_venv): + """convert_alpaca_to_chatml works without torch (via try/except ImportError).""" + code = textwrap.dedent(f"""\ + import sys, types + + loggers = 
types.ModuleType('loggers') + loggers.get_logger = lambda n: type('L', (), {{ + 'info': lambda s, m: None, + 'warning': lambda s, m: None, + 'debug': lambda s, m: None, + }})() + sys.modules['loggers'] = loggers + + datasets_mod = types.ModuleType('datasets') + datasets_mod.IterableDataset = type('IterableDataset', (), {{}}) + sys.modules['datasets'] = datasets_mod + + utils_mod = types.ModuleType('utils') + hardware_mod = types.ModuleType('utils.hardware') + hardware_mod.dataset_map_num_proc = lambda n=None: 1 + utils_mod.hardware = hardware_mod + sys.modules['utils'] = utils_mod + sys.modules['utils.hardware'] = hardware_mod + + source = open({str(FORMAT_CONVERSION)!r}).read() + source = source.replace('from .format_detection import', 'from format_detection import') + ns = {{'__name__': '__test__'}} + exec(source, ns) + + class FakeDataset: + def map(self, fn, **kw): + result = fn({{ + 'instruction': ['Write a poem'], + 'input': [''], + 'output': ['Roses are red'], + }}) + return result + + result = ns['convert_alpaca_to_chatml'](FakeDataset()) + assert 'conversations' in result + convo = result['conversations'][0] + assert convo[0]['role'] == 'user' + assert convo[1]['role'] == 'assistant' + print("OK: convert_alpaca_to_chatml works without torch") + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode == 0 + ), f"convert_alpaca_to_chatml failed without torch:\n{result.stderr.decode()}" + assert b"OK: convert_alpaca_to_chatml works without torch" in result.stdout + + +# ── Negative controls ───────────────────────────────────────────────── + + +class TestNegativeControls: + """Prove the fix is necessary by showing what fails WITHOUT it.""" + + def test_import_torch_prepended_fails(self, no_torch_venv): + """Prepending 'import torch' to data_collators.py causes ModuleNotFoundError.""" + with tempfile.NamedTemporaryFile( + mode = "w", suffix = ".py", delete = False, encoding = 
"utf-8" + ) as f: + f.write("import torch\n") + f.write(DATA_COLLATORS.read_text(encoding = "utf-8")) + temp_file = f.name + + try: + code = textwrap.dedent(f"""\ + import sys, types + loggers = types.ModuleType('loggers') + loggers.get_logger = lambda n: None + sys.modules['loggers'] = loggers + exec(open({temp_file!r}).read()) + """) + result = subprocess.run( + [no_torch_venv, "-c", code], + capture_output = True, + timeout = 30, + ) + assert ( + result.returncode != 0 + ), "Expected failure when 'import torch' is prepended" + assert ( + b"ModuleNotFoundError" in result.stderr + or b"ImportError" in result.stderr + ), f"Expected ImportError, got:\n{result.stderr.decode()}" + finally: + os.unlink(temp_file) + + def test_torchao_install_fails_no_torch_venv(self, no_torch_venv): + """Installing torchao (from overrides.txt) fails in a no-torch venv. + + This proves the overrides.txt skip is necessary for Intel Mac. + """ + result = subprocess.run( + [ + no_torch_venv, + "-m", + "pip", + "install", + "torchao==0.14.0", + "--dry-run", + ], + capture_output = True, + timeout = 60, + ) + if result.returncode != 0: + # torchao install/resolution failed as expected + pass + else: + # pip dry-run may not catch dependency issues; verify torch is missing + check = subprocess.run( + [no_torch_venv, "-c", "import torch"], + capture_output = True, + ) + assert ( + check.returncode != 0 + ), "torch should not be importable -- torchao would fail at runtime" + + def test_direct_torch_import_fails(self, no_torch_venv): + """Direct 'import torch' fails in the no-torch venv.""" + result = subprocess.run( + [no_torch_venv, "-c", "import torch; print('torch loaded')"], + capture_output = True, + timeout = 30, + ) + assert result.returncode != 0, "import torch should fail in no-torch venv" + assert ( + b"ModuleNotFoundError" in result.stderr or b"ImportError" in result.stderr + ) diff --git a/tests/run_all.sh b/tests/run_all.sh index d7fdb38e7..a1516aa6c 100755 --- a/tests/run_all.sh 
+++ b/tests/run_all.sh @@ -6,11 +6,14 @@ TESTS_DIR="$(cd "$(dirname "$0")" && pwd)" echo "=== Bash tests ===" sh "$TESTS_DIR/sh/test_get_torch_index_url.sh" +sh "$TESTS_DIR/sh/test_mac_intel_compat.sh" echo "" echo "=== Python tests ===" python -m pytest "$TESTS_DIR/python/test_install_python_stack.py" -v python -m pytest "$TESTS_DIR/python/test_cross_platform_parity.py" -v +python -m pytest "$TESTS_DIR/python/test_no_torch_filtering.py" -v +python -m pytest "$TESTS_DIR/python/test_studio_import_no_torch.py" -v echo "" echo "All tests passed." diff --git a/tests/sh/test_mac_intel_compat.sh b/tests/sh/test_mac_intel_compat.sh new file mode 100644 index 000000000..d8848fd01 --- /dev/null +++ b/tests/sh/test_mac_intel_compat.sh @@ -0,0 +1,582 @@ +#!/bin/bash +# End-to-end sandbox tests for Mac Intel compatibility and UNSLOTH_NO_TORCH propagation. +# Tests version_ge, arch detection (existing), plus E2E venv creation, torch skip +# via a mock uv shim, and UNSLOTH_NO_TORCH env propagation in install.sh. 
+set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +INSTALL_SH="$SCRIPT_DIR/../../install.sh" +PASS=0 +FAIL=0 + +assert_eq() { + _label="$1"; _expected="$2"; _actual="$3" + if [ "$_actual" = "$_expected" ]; then + echo " PASS: $_label" + PASS=$((PASS + 1)) + else + echo " FAIL: $_label (expected '$_expected', got '$_actual')" + FAIL=$((FAIL + 1)) + fi +} + +assert_contains() { + _label="$1"; _haystack="$2"; _needle="$3" + if echo "$_haystack" | grep -qF "$_needle"; then + echo " PASS: $_label" + PASS=$((PASS + 1)) + else + echo " FAIL: $_label (expected to find '$_needle')" + FAIL=$((FAIL + 1)) + fi +} + +assert_not_contains() { + _label="$1"; _haystack="$2"; _needle="$3" + if echo "$_haystack" | grep -qF "$_needle"; then + echo " FAIL: $_label (found '$_needle' but should not)" + FAIL=$((FAIL + 1)) + else + echo " PASS: $_label" + PASS=$((PASS + 1)) + fi +} + +# ── Extract version_ge function from install.sh ── +_VGE_FILE=$(mktemp) +sed -n '/^version_ge()/,/^}/p' "$INSTALL_SH" > "$_VGE_FILE" + +echo "=== version_ge ===" + +# Basic comparisons +_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13' '3.12' && echo pass || echo fail") +assert_eq "3.13 >= 3.12" "pass" "$_result" + +_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.12' '3.13' && echo pass || echo fail") +assert_eq "3.12 >= 3.13" "fail" "$_result" + +_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13' '3.13' && echo pass || echo fail") +assert_eq "3.13 >= 3.13 (equal)" "pass" "$_result" + +# Patch versions +_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.13.8' '3.13' && echo pass || echo fail") +assert_eq "3.13.8 >= 3.13 (patch > implicit 0)" "pass" "$_result" + +_result=$(bash -c ". '$_VGE_FILE'; version_ge '3.12.0' '3.13.0' && echo pass || echo fail") +assert_eq "3.12.0 >= 3.13.0 (minor less)" "fail" "$_result" + +# UV_MIN_VERSION edge cases +_result=$(bash -c ". 
'$_VGE_FILE'; version_ge '0.7.14' '0.7.14' && echo pass || echo fail") +assert_eq "0.7.14 >= 0.7.14 (exact UV_MIN_VERSION)" "pass" "$_result" + +_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.7.13' '0.7.14' && echo pass || echo fail") +assert_eq "0.7.13 >= 0.7.14 (below minimum)" "fail" "$_result" + +_result=$(bash -c ". '$_VGE_FILE'; version_ge '0.11.1' '0.7.14' && echo pass || echo fail") +assert_eq "0.11.1 >= 0.7.14 (well above)" "pass" "$_result" + +# Major jump +_result=$(bash -c ". '$_VGE_FILE'; version_ge '1.0' '0.99.99' && echo pass || echo fail") +assert_eq "1.0 >= 0.99.99 (major jump)" "pass" "$_result" + +rm -f "$_VGE_FILE" + +echo "" +echo "=== Architecture detection + PYTHON_VERSION ===" + +# Self-contained arch detection snippet matching install.sh logic +_ARCH_SNIPPET=$(mktemp) +cat > "$_ARCH_SNIPPET" << 'SNIPPET' +OS="linux" +if [ "$(uname)" = "Darwin" ]; then + OS="macos" +fi +_ARCH=$(uname -m) +MAC_INTEL=false +if [ "$OS" = "macos" ] && [ "$_ARCH" = "x86_64" ]; then + MAC_INTEL=true +fi +_USER_PYTHON="" +if [ -n "$_USER_PYTHON" ]; then + PYTHON_VERSION="$_USER_PYTHON" +elif [ "$MAC_INTEL" = true ]; then + PYTHON_VERSION="3.12" +else + PYTHON_VERSION="3.13" +fi +echo "$OS $MAC_INTEL $PYTHON_VERSION" +SNIPPET + +# Test: Darwin x86_64 -> macos true 3.12 +_result=$(bash -c ' +uname() { + case "$1" in + -m) echo "x86_64" ;; + *) echo "Darwin" ;; + esac +} +export -f uname +'"source '$_ARCH_SNIPPET'") +assert_eq "Darwin x86_64 -> macos true 3.12" "macos true 3.12" "$_result" + +# Test: Darwin arm64 -> macos false 3.13 +_result=$(bash -c ' +uname() { + case "$1" in + -m) echo "arm64" ;; + *) echo "Darwin" ;; + esac +} +export -f uname +'"source '$_ARCH_SNIPPET'") +assert_eq "Darwin arm64 -> macos false 3.13" "macos false 3.13" "$_result" + +# Test: Linux x86_64 -> linux false 3.13 +_result=$(bash -c ' +uname() { + case "$1" in + -m) echo "x86_64" ;; + *) echo "Linux" ;; + esac +} +export -f uname +'"source '$_ARCH_SNIPPET'") +assert_eq "Linux x86_64 -> 
linux false 3.13" "linux false 3.13" "$_result" + +# Test: Linux aarch64 -> linux false 3.13 +_result=$(bash -c ' +uname() { + case "$1" in + -m) echo "aarch64" ;; + *) echo "Linux" ;; + esac +} +export -f uname +'"source '$_ARCH_SNIPPET'") +assert_eq "Linux aarch64 -> linux false 3.13" "linux false 3.13" "$_result" + +rm -f "$_ARCH_SNIPPET" + +echo "" +echo "=== get_torch_index_url on Darwin ===" + +# Extract get_torch_index_url and replace hardcoded nvidia-smi path +_FUNC_FILE=$(mktemp) +_FAKE_SMI_DIR=$(mktemp -d) +sed -n '/^get_torch_index_url()/,/^}/p' "$INSTALL_SH" \ + | sed "s|/usr/bin/nvidia-smi|$_FAKE_SMI_DIR/nvidia-smi-absent|g" \ + > "$_FUNC_FILE" + +# Build a minimal tools directory +_TOOLS_DIR=$(mktemp -d) +for _cmd in grep sed head sh bash cat; do + _real=$(command -v "$_cmd" 2>/dev/null || true) + [ -n "$_real" ] && ln -sf "$_real" "$_TOOLS_DIR/$_cmd" +done + +# Create a mock uname that returns Darwin +_MOCK_UNAME_DIR=$(mktemp -d) +cat > "$_MOCK_UNAME_DIR/uname" << 'MOCK_UNAME' +#!/bin/sh +case "$1" in + -s) echo "Darwin" ;; + -m) echo "arm64" ;; + *) echo "Darwin" ;; +esac +MOCK_UNAME +chmod +x "$_MOCK_UNAME_DIR/uname" + +# Mock nvidia-smi that returns CUDA version (to prove macOS ignores it) +_GPU_DIR=$(mktemp -d) +cat > "$_GPU_DIR/nvidia-smi" << 'MOCK_SMI' +#!/bin/sh +cat <<'SMI_OUT' ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.6 | ++-----------------------------------------------------------------------------------------+ +SMI_OUT +MOCK_SMI +chmod +x "$_GPU_DIR/nvidia-smi" + +# Test: Darwin always returns cpu (even with nvidia-smi present) +_result=$(PATH="$_GPU_DIR:$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". 
'$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
+assert_eq "Darwin -> cpu (even with nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
+
+# Test: Darwin without nvidia-smi also returns cpu
+_result=$(PATH="$_MOCK_UNAME_DIR:$_TOOLS_DIR" bash -c ". '$_FUNC_FILE'; get_torch_index_url" 2>/dev/null)
+assert_eq "Darwin -> cpu (no nvidia-smi)" "https://download.pytorch.org/whl/cpu" "$_result"
+
+rm -f "$_FUNC_FILE"
+rm -rf "$_FAKE_SMI_DIR" "$_TOOLS_DIR" "$_MOCK_UNAME_DIR" "$_GPU_DIR"
+
+echo ""
+echo "=== UNSLOTH_NO_TORCH propagation ==="
+
+# Verify UNSLOTH_NO_TORCH is passed to setup.sh in BOTH the --local and non-local branches.
+_local_count=$(grep -c 'UNSLOTH_NO_TORCH=' "$INSTALL_SH" | head -1)
+if [ "$_local_count" -ge 2 ]; then
+ echo " PASS: UNSLOTH_NO_TORCH appears in >= 2 setup.sh invocations ($_local_count found)"
+ PASS=$((PASS + 1))
+else
+ echo " FAIL: UNSLOTH_NO_TORCH should appear in >= 2 setup.sh invocations (found $_local_count)"
+ FAIL=$((FAIL + 1))
+fi
+
+# Verify the value passed is "$SKIP_TORCH" (the unified variable, not MAC_INTEL)
+_skip_torch_count=$(grep 'UNSLOTH_NO_TORCH="\$SKIP_TORCH"' "$INSTALL_SH" | wc -l)
+if [ "$_skip_torch_count" -ge 2 ]; then
+ echo " PASS: UNSLOTH_NO_TORCH=\"\$SKIP_TORCH\" in both branches ($_skip_torch_count found)"
+ PASS=$((PASS + 1))
+else
+ echo " FAIL: UNSLOTH_NO_TORCH=\"\$SKIP_TORCH\" should appear in >= 2 branches (found $_skip_torch_count)"
+ FAIL=$((FAIL + 1))
+fi
+
+# Verify MAC_INTEL is set to true when Intel Mac is detected
+_mac_intel_set=$(grep -c 'MAC_INTEL=true' "$INSTALL_SH" || true)  # grep -c exits 1 on zero matches; || true keeps set -e from aborting before the FAIL branch can report
+if [ "$_mac_intel_set" -ge 1 ]; then
+ echo " PASS: MAC_INTEL=true is set in install.sh"
+ PASS=$((PASS + 1))
+else
+ echo " FAIL: MAC_INTEL=true not found in install.sh"
+ FAIL=$((FAIL + 1))
+fi
+
+# Verify the PyTorch skip message exists (now covers both --no-torch and Intel Mac)
+if grep -q 'Skipping PyTorch' "$INSTALL_SH"; then
+ echo " PASS: PyTorch skip message found"
+ PASS=$((PASS + 1))
+else
+ echo " FAIL: 
PyTorch skip message not found"
+    FAIL=$((FAIL + 1))
+fi
+
+# Verify SKIP_TORCH unified variable exists
+if grep -q 'SKIP_TORCH=true' "$INSTALL_SH"; then
+    echo " PASS: SKIP_TORCH=true assignment found"
+    PASS=$((PASS + 1))
+else
+    echo " FAIL: SKIP_TORCH=true not found in install.sh"
+    FAIL=$((FAIL + 1))
+fi
+
+echo ""
+echo "=== E2E: venv creation at Python 3.12 (simulated Intel Mac) ==="
+
+# Actually create a uv venv at Python 3.12 to verify the path works
+if command -v uv >/dev/null 2>&1; then
+    _VENV_DIR=$(mktemp -d)
+    _uv_result=$(uv venv "$_VENV_DIR/test_venv" --python 3.12 2>&1) && _uv_rc=0 || _uv_rc=$?
+    if [ "$_uv_rc" -eq 0 ]; then
+        echo " PASS: uv venv created at Python 3.12"
+        PASS=$((PASS + 1))
+
+        # Verify Python version inside the venv
+        _py_ver=$("$_VENV_DIR/test_venv/bin/python" -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
+        assert_eq "venv Python is 3.12" "3.12" "$_py_ver"
+
+        # Verify torch is NOT available (fresh venv has no torch)
+        if "$_VENV_DIR/test_venv/bin/python" -c "import torch" 2>/dev/null; then
+            echo " FAIL: torch should NOT be importable in fresh 3.12 venv"
+            FAIL=$((FAIL + 1))
+        else
+            echo " PASS: torch not importable in fresh 3.12 venv (expected for Intel Mac)"
+            PASS=$((PASS + 1))
+        fi
+    else
+        echo " SKIP: Could not create Python 3.12 venv (python 3.12 not available)"
+    fi
+    rm -rf "$_VENV_DIR"
+else
+    echo " SKIP: uv not available, cannot test venv creation"
+fi
+
+echo ""
+echo "=== E2E: torch install skipped when SKIP_TORCH=true (mock uv shim) ==="
+
+# Create a mock uv that logs all calls instead of running them
+# (unquoted heredoc delimiter: $_UV_LOG expands now; \$* stays literal for run time)
+_MOCK_UV_DIR=$(mktemp -d)
+_UV_LOG="$_MOCK_UV_DIR/uv_calls.log"
+touch "$_UV_LOG"
+cat > "$_MOCK_UV_DIR/uv" << MOCK_UV_EOF
+#!/bin/sh
+echo "UV_CALL: \$*" >> "$_UV_LOG"
+MOCK_UV_EOF
+chmod +x "$_MOCK_UV_DIR/uv"
+
+# Simulates the torch install decision from install.sh using SKIP_TORCH
+# (quoted delimiter: $SKIP_TORCH is resolved from the environment at run time)
+_TORCH_BLOCK=$(mktemp)
+cat > "$_TORCH_BLOCK" << 'TORCH_EOF'
+# Simulates the torch install decision from install.sh
+TORCH_INDEX_URL="https://download.pytorch.org/whl/cpu"
+_VENV_PY="/fake/python"
+if [ "$SKIP_TORCH" = true ]; then
+    echo "==> Skipping PyTorch (--no-torch or Intel Mac x86_64)."
+else
+    echo "==> Installing PyTorch ($TORCH_INDEX_URL)..."
+    uv pip install --python "$_VENV_PY" "torch>=2.4,<2.11.0" torchvision torchaudio \
+        --index-url "$TORCH_INDEX_URL"
+fi
+TORCH_EOF
+
+# Test: SKIP_TORCH=true -> torch install should be SKIPPED (no uv calls)
+> "$_UV_LOG" # clear log
+_torch_output=$(SKIP_TORCH=true PATH="$_MOCK_UV_DIR:$PATH" bash "$_TORCH_BLOCK" 2>&1)
+assert_contains "SKIP_TORCH=true prints skip message" "$_torch_output" "Skipping PyTorch"
+if [ -s "$_UV_LOG" ]; then
+    echo " FAIL: uv was called when SKIP_TORCH=true (should be skipped)"
+    echo " Log: $(cat "$_UV_LOG")"
+    FAIL=$((FAIL + 1))
+else
+    echo " PASS: no uv pip install torch when SKIP_TORCH=true"
+    PASS=$((PASS + 1))
+fi
+
+# Test: SKIP_TORCH=false -> torch install should EXECUTE (uv called with torch)
+> "$_UV_LOG" # clear log
+_torch_output=$(SKIP_TORCH=false PATH="$_MOCK_UV_DIR:$PATH" bash "$_TORCH_BLOCK" 2>&1)
+assert_contains "SKIP_TORCH=false prints install message" "$_torch_output" "Installing PyTorch"
+if grep -q "torch" "$_UV_LOG"; then
+    echo " PASS: uv pip install torch called when SKIP_TORCH=false"
+    PASS=$((PASS + 1))
+else
+    echo " FAIL: uv pip install torch NOT called when SKIP_TORCH=false"
+    FAIL=$((FAIL + 1))
+fi
+
+rm -f "$_TORCH_BLOCK"
+rm -rf "$_MOCK_UV_DIR"
+
+echo ""
+echo "=== E2E: UNSLOTH_NO_TORCH env propagation (dynamic test) ==="
+
+# Simulates the setup.sh invocation using SKIP_TORCH
+_ENV_BLOCK=$(mktemp)
+cat > "$_ENV_BLOCK" << 'ENV_EOF'
+# Simulates the setup.sh invocation block from install.sh
+PACKAGE_NAME="unsloth"
+_REPO_ROOT="/fake/repo"
+SETUP_SH="/fake/setup.sh"
+
+if [ "$STUDIO_LOCAL_INSTALL" = true ]; then
+    SKIP_STUDIO_BASE=1 \
+    STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \
+    STUDIO_LOCAL_INSTALL=1 \
+    STUDIO_LOCAL_REPO="$_REPO_ROOT" \
+    UNSLOTH_NO_TORCH="$SKIP_TORCH" \
+    env | grep "^UNSLOTH_NO_TORCH="
+else
+    SKIP_STUDIO_BASE=1 \
+    STUDIO_PACKAGE_NAME="$PACKAGE_NAME" \
+    UNSLOTH_NO_TORCH="$SKIP_TORCH" \
+    env | grep "^UNSLOTH_NO_TORCH="
+fi
+ENV_EOF
+
+# Test: SKIP_TORCH=true -> UNSLOTH_NO_TORCH=true in env
+_env_result=$(SKIP_TORCH=true STUDIO_LOCAL_INSTALL=false bash "$_ENV_BLOCK" 2>&1)
+assert_eq "non-local: UNSLOTH_NO_TORCH=true when SKIP_TORCH=true" "UNSLOTH_NO_TORCH=true" "$_env_result"
+
+# Test: SKIP_TORCH=false -> UNSLOTH_NO_TORCH=false in env
+_env_result=$(SKIP_TORCH=false STUDIO_LOCAL_INSTALL=false bash "$_ENV_BLOCK" 2>&1)
+assert_eq "non-local: UNSLOTH_NO_TORCH=false when SKIP_TORCH=false" "UNSLOTH_NO_TORCH=false" "$_env_result"
+
+# Test: local install path also propagates
+_env_result=$(SKIP_TORCH=true STUDIO_LOCAL_INSTALL=true bash "$_ENV_BLOCK" 2>&1)
+assert_eq "local: UNSLOTH_NO_TORCH=true when SKIP_TORCH=true" "UNSLOTH_NO_TORCH=true" "$_env_result"
+
+_env_result=$(SKIP_TORCH=false STUDIO_LOCAL_INSTALL=true bash "$_ENV_BLOCK" 2>&1)
+assert_eq "local: UNSLOTH_NO_TORCH=false when SKIP_TORCH=false" "UNSLOTH_NO_TORCH=false" "$_env_result"
+
+rm -f "$_ENV_BLOCK"
+
+echo ""
+echo "=== --python override flag ==="
+
+# Test: flag parsing extracts version correctly
+_PARSE_BLOCK=$(mktemp)
+cat > "$_PARSE_BLOCK" << 'PARSE_EOF'
+_USER_PYTHON=""
+_next_is_python=false
+_next_is_package=false
+STUDIO_LOCAL_INSTALL=false
+PACKAGE_NAME="unsloth"
+for arg in "$@"; do
+    if [ "$_next_is_package" = true ]; then PACKAGE_NAME="$arg"; _next_is_package=false; continue; fi
+    if [ "$_next_is_python" = true ]; then _USER_PYTHON="$arg"; _next_is_python=false; continue; fi
+    case "$arg" in
+        --local) STUDIO_LOCAL_INSTALL=true ;;
+        --package) _next_is_package=true ;;
+        --python) _next_is_python=true ;;
+    esac
+done
+if [ "$_next_is_python" = true ]; then echo "ERROR"; exit 1; fi
+echo "$_USER_PYTHON"
+PARSE_EOF
+
+_result=$(bash "$_PARSE_BLOCK" --python 3.12)
+assert_eq "--python 3.12 parsed" "3.12" "$_result"
+
+_result=$(bash "$_PARSE_BLOCK" --local --python 3.11)
+assert_eq "--local --python 3.11 parsed" "3.11" "$_result"
+
+_result=$(bash "$_PARSE_BLOCK" --python 3.12 --local --package foo)
+assert_eq "--python with --local --package" "3.12" "$_result"
+
+_result=$(bash "$_PARSE_BLOCK" 2>&1) # no --python
+assert_eq "no --python -> empty" "" "$_result"
+
+_rc=0
+bash "$_PARSE_BLOCK" --python >/dev/null 2>&1 || _rc=$?
+assert_eq "--python without arg -> error" "1" "$_rc"
+
+rm -f "$_PARSE_BLOCK"
+
+# Test: --python overrides auto-detected version in PYTHON_VERSION resolution
+_RESOLVE_BLOCK=$(mktemp)
+cat > "$_RESOLVE_BLOCK" << 'RESOLVE_EOF'
+_USER_PYTHON="$1"
+MAC_INTEL="$2"
+if [ -n "$_USER_PYTHON" ]; then
+    PYTHON_VERSION="$_USER_PYTHON"
+elif [ "$MAC_INTEL" = true ]; then
+    PYTHON_VERSION="3.12"
+else
+    PYTHON_VERSION="3.13"
+fi
+echo "$PYTHON_VERSION"
+RESOLVE_EOF
+
+_result=$(bash "$_RESOLVE_BLOCK" "3.11" "true")
+assert_eq "--python 3.11 overrides Intel Mac 3.12" "3.11" "$_result"
+
+_result=$(bash "$_RESOLVE_BLOCK" "3.12" "false")
+assert_eq "--python 3.12 overrides default 3.13" "3.12" "$_result"
+
+_result=$(bash "$_RESOLVE_BLOCK" "" "true")
+assert_eq "no override -> Intel Mac gets 3.12" "3.12" "$_result"
+
+_result=$(bash "$_RESOLVE_BLOCK" "" "false")
+assert_eq "no override -> non-Intel gets 3.13" "3.13" "$_result"
+
+rm -f "$_RESOLVE_BLOCK"
+
+# Test: --python flag exists in install.sh
+if grep -q '\-\-python)' "$INSTALL_SH"; then
+    echo " PASS: --python case exists in install.sh"
+    PASS=$((PASS + 1))
+else
+    echo " FAIL: --python case not found in install.sh"
+    FAIL=$((FAIL + 1))
+fi
+
+# Test: _USER_PYTHON guards exist for stale-venv and 3.13.8 checks
+_user_py_guards=$(grep -c '_USER_PYTHON' "$INSTALL_SH")
+if [ "$_user_py_guards" -ge 4 ]; then
+    echo " PASS: _USER_PYTHON referenced >= 4 times in install.sh (flag + resolution + guards)"
+    PASS=$((PASS + 1))
+else
+    echo " FAIL: _USER_PYTHON should appear >= 4 times (found $_user_py_guards)"
+    FAIL=$((FAIL + 1))
+fi
+
+echo ""
+echo "=== --no-torch flag parsing ==="
+
+# Test: --no-torch sets _NO_TORCH_FLAG=true
+_FLAG_SNIPPET=$(mktemp)
+cat > "$_FLAG_SNIPPET" << 'SNIPPET'
+_NO_TORCH_FLAG=false
+_next_is_package=false
+STUDIO_LOCAL_INSTALL=false
+PACKAGE_NAME="unsloth"
+for arg in "$@"; do
+    if [ "$_next_is_package" = true ]; then
+        PACKAGE_NAME="$arg"
+        _next_is_package=false
+        continue
+    fi
+    case "$arg" in
+        --local) STUDIO_LOCAL_INSTALL=true ;;
+        --package) _next_is_package=true ;;
+        --no-torch) _NO_TORCH_FLAG=true ;;
+    esac
+done
+echo "$_NO_TORCH_FLAG"
+SNIPPET
+
+_result=$(bash "$_FLAG_SNIPPET" --no-torch)
+assert_eq "--no-torch sets flag to true" "true" "$_result"
+
+_result=$(bash "$_FLAG_SNIPPET")
+assert_eq "no flags -> flag is false" "false" "$_result"
+
+_result=$(bash "$_FLAG_SNIPPET" --local --no-torch)
+assert_eq "--local --no-torch both work" "true" "$_result"
+
+_result=$(bash "$_FLAG_SNIPPET" --no-torch --package custom-pkg)
+assert_eq "--no-torch with --package works" "true" "$_result"
+
+rm -f "$_FLAG_SNIPPET"
+
+echo ""
+echo "=== SKIP_TORCH unification ==="
+
+# Test: SKIP_TORCH is set to true when --no-torch flag is set (even without MAC_INTEL)
+_SKIP_SNIPPET=$(mktemp)
+cat > "$_SKIP_SNIPPET" << 'SNIPPET'
+MAC_INTEL=false
+_NO_TORCH_FLAG=$1
+SKIP_TORCH=false
+if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then
+    SKIP_TORCH=true
+fi
+echo "$SKIP_TORCH"
+SNIPPET
+
+_result=$(bash "$_SKIP_SNIPPET" true)
+assert_eq "--no-torch flag alone sets SKIP_TORCH=true" "true" "$_result"
+
+_result=$(bash "$_SKIP_SNIPPET" false)
+assert_eq "no flag, no MAC_INTEL -> SKIP_TORCH=false" "false" "$_result"
+
+# Test: MAC_INTEL=true alone also sets SKIP_TORCH=true
+_SKIP_SNIPPET2=$(mktemp)
+cat > "$_SKIP_SNIPPET2" << 'SNIPPET'
+MAC_INTEL=true
+_NO_TORCH_FLAG=false
+SKIP_TORCH=false
+if [ "$_NO_TORCH_FLAG" = true ] || [ "$MAC_INTEL" = true ]; then
+    SKIP_TORCH=true
+fi
+echo "$SKIP_TORCH"
+SNIPPET
+
+_result=$(bash "$_SKIP_SNIPPET2")
+assert_eq "MAC_INTEL=true alone sets SKIP_TORCH=true" "true" "$_result"
+
+rm -f "$_SKIP_SNIPPET" "$_SKIP_SNIPPET2"
+
+echo ""
+echo "=== CPU hint printing ==="
+
+# Verify the CPU hint is present in install.sh source
+if grep -q 'No NVIDIA GPU detected' "$INSTALL_SH"; then
+    echo " PASS: CPU hint message found in install.sh"
+    PASS=$((PASS + 1))
+else
+    echo " FAIL: CPU hint message not found in install.sh"
+    FAIL=$((FAIL + 1))
+fi
+
+if grep -q '\-\-no-torch' "$INSTALL_SH"; then
+    echo " PASS: --no-torch appears in install.sh"
+    PASS=$((PASS + 1))
+else
+    echo " FAIL: --no-torch not found in install.sh"
+    FAIL=$((FAIL + 1))
+fi
+
+echo ""
+echo "Results: $PASS passed, $FAIL failed"
+[ "$FAIL" -eq 0 ] || exit 1