unsloth/scripts/notebook_validator.py
Daniel Han bfb5c2872c CI(notebooks): cross-repo validator for unslothai/notebooks
New PR-time + scheduled workflow that walks every nb/, kaggle/, and
original_template/ notebook in unslothai/notebooks and statically
validates the install cells and user-facing code against:

  - googlecolab/backend-info pip-freeze.gpu.txt (Colab oracle, refreshed
    on every run; fallback snapshot committed under scripts/data/).
  - PyPI metadata for transitive constraint resolution.
  - Hardcoded torch/torchcodec ABI table.
  - Hardcoded peft/torchao floor table.
  - The live unsloth + trl API surface, introspected under
    tests/_zoo_aggressive_cuda_spoof.py so the api job runs on a
    GPU-less ubuntu-latest runner.

Catches the bug classes from notebooks#258 / #260 / #261 / #264 / #221
and commit 51b1462 mechanically:

  R-INST-001  forbid git+ HEAD installs (notebooks#221)
  R-INST-002  --no-deps + transitive constraint violation
  R-INST-003  peft 0.19+ requires torchao 0.16.0+ (notebooks#258)
  R-INST-004  torch <-> torchcodec ABI mismatch (notebooks#261a)
  R-INST-005  --no-deps transformers + Colab tokenizers drift
              (notebooks#261b / #264)
  R-INST-006  forbid !!pip
  R-API-003   adamw_torch_fused -> adamw_8bit hint (warning)
  R-API-004   notebook references symbols outside live unsloth surface
  R-EXC-001   DONT_UPDATE_EXCEPTIONS notebooks must satisfy the same
              policy clauses as generated notebooks (notebooks#260)
  R-DRIFT-001 update_all_notebooks.py emits no diff (commit 51b1462)
  R-CONV-001  notebook_to_python.py converts every .ipynb cleanly

Files:
  .github/workflows/notebooks-ci.yml          PR-time + cron + dispatch
  scripts/notebook_validator.py               1148 LOC, single-file
  scripts/notebook_to_python.py               battle-tested converter
  scripts/data/colab_pip_freeze.gpu.txt       fallback snapshot
  scripts/data/colab_to_cpu_pin.json          cu128 -> CPU wheel map
  tests/notebooks/test_validator_fixtures.py  21 golden tests, all green

CPU-only by design. The api-introspect job follows the existing
consolidated-tests-ci spoof pattern (lines 309/417/536/626/826/1081/
1586/1998 of consolidated-tests-ci.yml). The smoke-install job is
opt-in via workflow_dispatch and stubs torchcodec since no CPU wheel
exists.

Validated on the live unslothai/notebooks@7af0ac0f tree: every fixture
test passes, exceptions check is silent, lint surfaces 27 errors + 6
warnings on real notebooks (mix of #258-class regressions in 6 nb/
notebooks the previous template fixes did not reach, plus 14
git+-HEAD installs in hand-tuned exception notebooks).
2026-05-07 11:42:57 +00:00

1148 lines
40 KiB
Python

#!/usr/bin/env python3
# coding: utf-8
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team.
"""
Static + lightweight-dynamic validator for unslothai/notebooks.
Built to catch the bug classes that landed in (at minimum):
- unslothai/notebooks#258 (Colab torchao 0.10 vs peft 0.19 floor)
- unslothai/notebooks#260 (DONT_UPDATE_EXCEPTIONS coverage drift)
- unslothai/notebooks#261 (torch/torchcodec ABI; --no-deps tokenizers)
- unslothai/notebooks#264 (transformers/tokenizers window with --no-deps)
- unslothai/notebooks#221 (removed unsloth APIs in user cells, git+ install)
- unslothai/notebooks commit 51b1462 (template/notebook drift)
CPU-only by design: never imports torch / unsloth at module load. The
api subcommand introspects unsloth under the existing
tests/_zoo_aggressive_cuda_spoof.py harness (PR #5312) so it works on
ubuntu-latest without a GPU.
Usage:
python scripts/notebook_validator.py drift --notebooks-dir <dir>
python scripts/notebook_validator.py convert --notebooks-dir <dir> --out _converted
python scripts/notebook_validator.py lint --notebooks-dir <dir> [--colab-pin <file>]
python scripts/notebook_validator.py exceptions --notebooks-dir <dir>
python scripts/notebook_validator.py api --converted-dir _converted --surface _api_surface.json
python scripts/notebook_validator.py all --notebooks-dir <dir>
python scripts/notebook_validator.py refresh-colab --out scripts/data/colab_pip_freeze.gpu.txt
"""
from __future__ import annotations
import argparse
import ast
import dataclasses
import json
import os
import pathlib
import re
import shlex
import subprocess
import sys
import textwrap
import time
import urllib.error
import urllib.request
from typing import Any, Iterable, Iterator
HERE = pathlib.Path(__file__).resolve().parent
DATA_DIR = HERE / "data"
PYPI_CACHE_DIR = DATA_DIR / "pypi_cache"
COLAB_PIP_FREEZE_URL = (
"https://raw.githubusercontent.com/googlecolab/backend-info/main/pip-freeze.gpu.txt"
)
COLAB_FALLBACK_FILE = DATA_DIR / "colab_pip_freeze.gpu.txt"
# ----- Compat tables. PRs add rows as new releases land. ----- #
# torch.minor -> set of compatible torchcodec.minor strings.
# Source: pytorch/torchcodec compatibility matrix on its README.
TORCH_TORCHCODEC: dict[str, set[str]] = {
"2.10": {"0.10"},
"2.9": {"0.7", "0.8", "0.9"},
"2.8": {"0.6"},
"2.7": {"0.3", "0.4", "0.5"},
"2.6": {"0.2", "0.3"},
"2.5": {"0.1", "0.2"},
}
# When peft >= trigger is on the resolved set, torchao >= floor must also be.
PEFT_TORCHAO_FLOOR: list[dict[str, str]] = [
{"trigger_peft": "0.19", "torchao_floor": "0.16.0"},
]
# git+ allowlist: install lines that legitimately fetch from GitHub. Anything
# else flags R-INST-001.
GIT_PLUS_ALLOWLIST = (
"github.com/SparkAudio/Spark-TTS",
"github.com/state-spaces/mamba",
"github.com/Dao-AILab/causal-conv1d",
"github.com/unslothai/unsloth-zoo",
"github.com/unslothai/unsloth",
)
# ----- Findings ----- #
@dataclasses.dataclass
class Finding:
rule: str
file: str
cell: int | None = None
line: int | None = None
severity: str = "error" # error | warning
message: str = ""
hint: str = ""
def to_dict(self) -> dict[str, Any]:
return dataclasses.asdict(self)
# ----- Notebook walking ----- #
def iter_notebooks(
notebooks_dir: pathlib.Path, include_templates: bool = False
) -> Iterator[pathlib.Path]:
"""Yield user-facing .ipynb files under nb/ and kaggle/. Pass
include_templates=True to also walk original_template/ (used by the
convert subcommand which doesn't lint install cells)."""
subs = ("nb", "kaggle")
if include_templates:
subs = ("nb", "kaggle", "original_template")
candidates = []
for sub in subs:
d = notebooks_dir / sub
if d.is_dir():
for p in sorted(d.glob("*.ipynb")):
candidates.append(p)
seen = set()
for p in candidates:
if p.resolve() in seen:
continue
seen.add(p.resolve())
yield p
def load_notebook(path: pathlib.Path) -> dict[str, Any]:
return json.loads(path.read_text(encoding = "utf-8"))
def cell_source(cell: dict[str, Any]) -> str:
src = cell.get("source", "")
if isinstance(src, list):
return "".join(src)
return src
def code_cells(nb: dict[str, Any]) -> list[tuple[int, str]]:
out = []
for i, c in enumerate(nb.get("cells", [])):
if c.get("cell_type") == "code":
out.append((i, cell_source(c)))
return out
def install_cells(nb: dict[str, Any]) -> list[tuple[int, str]]:
"""Heuristic: any code cell that contains a `pip install`, `pip uninstall`
or `uv pip install` shell command, or a top-line `%%capture` magic."""
out = []
for i, src in code_cells(nb):
first = src.lstrip().splitlines()[:1]
if first and first[0].strip().startswith("%%capture"):
out.append((i, src))
continue
if re.search(
r"^[ \t]*!\s*(uv\s+)?pip\s+(install|uninstall)\b", src, re.MULTILINE
):
out.append((i, src))
return out
# Notebook target environment. The Colab oracle (pip-freeze.gpu.txt) only
# applies to notebooks that actually run on Colab; AMD-Dev-Cloud,
# Kaggle, HuggingFace-Course, and DGX-Spark notebooks have their own
# preinstalled environments and the Colab-vs-cell rules are not
# applicable to them.
def target_environment(notebook_name: str) -> str:
parts = pathlib.PurePath(notebook_name).parts
base = parts[-1] if parts else notebook_name
parent = parts[-2] if len(parts) >= 2 else ""
if parent == "kaggle" or base.startswith("Kaggle-"):
return "kaggle"
if base.startswith("AMD-") or "_AMD_" in base:
return "amd"
if base.startswith("HuggingFace Course-") or base.startswith("HuggingFace_Course-"):
return "colab" # HF Course notebooks still run on Colab.
if "DGX_Spark" in base:
return "dgx_spark"
return "colab"
# ----- Pip-freeze parsing ----- #
PINNED_RE = re.compile(r"^\s*([A-Za-z0-9._-]+)\s*==\s*([^\s;#]+)")
def parse_pip_freeze(path: pathlib.Path) -> dict[str, str]:
"""Return {name_lower: version_str_with_local_version}."""
out: dict[str, str] = {}
if not path.is_file():
return out
for line in path.read_text(encoding = "utf-8").splitlines():
if not line.strip() or line.startswith("#"):
continue
m = PINNED_RE.match(line)
if m:
out[m.group(1).lower()] = m.group(2)
return out
def normalise_version(v: str) -> str:
"""Strip +cu128 / +cpu / -dev local-version metadata."""
return re.split(r"[+\-]", v, maxsplit = 1)[0]
def version_minor(v: str) -> str:
parts = normalise_version(v).split(".")
return ".".join(parts[:2]) if len(parts) >= 2 else parts[0]
def cmp_versions(a: str, b: str) -> int:
"""Return -1/0/+1. Compares dotted numeric components only."""
def to_tuple(v: str) -> tuple[int, ...]:
return tuple(int(x) for x in re.findall(r"\d+", normalise_version(v)))
ta, tb = to_tuple(a), to_tuple(b)
if ta < tb:
return -1
if ta > tb:
return 1
return 0
# ----- Install-cell parsing ----- #
@dataclasses.dataclass
class PipInvocation:
tool: str # "pip" | "uv-pip"
flags: set[str] # {'--no-deps', '--upgrade', '--force-reinstall', ...}
packages: list[str] # raw package specifiers (e.g. 'transformers==5.5.0')
raw: str
line_no: int = 0
PIP_LINE_RE = re.compile(
r"^\s*!\s*(?P<tool>(?:uv\s+)?pip)\s+(?:install|uninstall)\b(?P<rest>.*)$",
re.IGNORECASE,
)
NON_PKG_FLAG_TAKES_VAL = {
"-r",
"--requirement",
"-c",
"--constraint",
"-i",
"--index-url",
"--extra-index-url",
"--find-links",
"-e",
"--editable",
"--target",
"--prefix",
}
def parse_pip_line(line: str, line_no: int = 0) -> PipInvocation | None:
m = PIP_LINE_RE.match(line)
if not m:
return None
tool = "uv-pip" if "uv" in m.group("tool") else "pip"
rest = m.group("rest")
# Strip trailing comment.
rest = re.split(r"(?<!\S)#", rest, maxsplit = 1)[0]
try:
tokens = shlex.split(rest, posix = True)
except ValueError:
# f-string interpolation like {xformers}: replace braces with placeholders.
rest_safe = re.sub(r"\{[^}]+\}", "PLACEHOLDER", rest)
try:
tokens = shlex.split(rest_safe, posix = True)
except ValueError:
return None
flags: set[str] = set()
packages: list[str] = []
skip_next = False
for t in tokens:
if skip_next:
skip_next = False
continue
if t in NON_PKG_FLAG_TAKES_VAL:
flags.add(t)
skip_next = True
continue
if t.startswith("-"):
flags.add(t)
continue
if t in ("install", "uninstall"):
continue
packages.append(t)
return PipInvocation(
tool = tool, flags = flags, packages = packages, raw = line, line_no = line_no
)
def _glue_line_continuations(text: str) -> list[tuple[int, str]]:
"""Return (logical_line_no, joined_text) for each logical line, treating
a trailing backslash as a continuation. Logical line numbers point at the
first physical line of each logical line."""
out: list[tuple[int, str]] = []
buf = ""
start = 0
for i, raw in enumerate(text.splitlines(), start = 1):
if buf == "":
start = i
if raw.rstrip().endswith("\\"):
buf += raw.rstrip()[:-1] + " "
else:
buf += raw
out.append((start, buf))
buf = ""
if buf:
out.append((start, buf))
return out
def iter_pip_invocations(install_cell: str) -> Iterator[PipInvocation]:
for line_no, line in _glue_line_continuations(install_cell):
inv = parse_pip_line(line, line_no)
if inv is not None:
yield inv
# Spec parsing: only what we need (no full PEP 440).
SPEC_RE = re.compile(r"^(?P<name>[A-Za-z0-9._-]+)(?:\[[^\]]*\])?(?P<rest>.*)$")
OP_VERSION_RE = re.compile(r"(==|>=|<=|!=|~=|>|<)\s*([0-9][^,;\s]*)")
@dataclasses.dataclass
class SpecParts:
name: str
pins: list[tuple[str, str]] # list of (op, version)
raw: str
def parse_spec(spec: str) -> SpecParts | None:
spec = spec.strip().strip('"').strip("'")
if not spec or spec.startswith("-") or "://" in spec:
return None
m = SPEC_RE.match(spec)
if not m:
return None
name = m.group("name").lower()
rest = m.group("rest")
pins = OP_VERSION_RE.findall(rest)
return SpecParts(name = name, pins = pins, raw = spec)
def explicit_pin(spec: SpecParts) -> str | None:
for op, ver in spec.pins:
if op == "==":
return ver
return None
# ----- PyPI metadata cache ----- #
def pypi_metadata(name: str, version: str) -> dict[str, Any] | None:
PYPI_CACHE_DIR.mkdir(parents = True, exist_ok = True)
safe = re.sub(r"[^A-Za-z0-9._-]", "_", f"{name.lower()}__{version}")
path = PYPI_CACHE_DIR / f"{safe}.json"
if path.is_file():
try:
return json.loads(path.read_text())
except json.JSONDecodeError:
pass
url = f"https://pypi.org/pypi/{name}/{version}/json"
try:
with urllib.request.urlopen(url, timeout = 10) as r:
data = json.loads(r.read())
except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError):
return None
path.write_text(json.dumps(data))
return data
def transitive_constraint(
name: str, version: str, target: str
) -> tuple[str | None, list[str]]:
"""Return (raw_specifier_string_or_None, list_of_(op,version) tuples)
for the constraint that `name==version` places on `target`.
"""
md = pypi_metadata(name, version)
if not md:
return None, []
info = md.get("info", {}) or {}
requires = info.get("requires_dist") or []
target_l = target.lower()
for req in requires:
# Examples: 'tokenizers (<=0.23.0,>=0.22.0)', 'tokenizers <=0.23.0,>=0.22.0',
# 'tokenizers (>=0.22.0,<=0.23.0); python_version >= "3.9"'
head = req.split(";", 1)[0].strip()
m = re.match(r"^([A-Za-z0-9._-]+)\s*\(?([^)]*)?\)?\s*$", head)
if not m:
continue
if m.group(1).lower() != target_l:
continue
spec = (m.group(2) or "").strip()
return spec, OP_VERSION_RE.findall(spec)
return None, []
def constraint_satisfied(version: str, ops: list[tuple[str, str]]) -> bool:
if not ops:
return True
for op, v in ops:
c = cmp_versions(version, v)
if op == "==":
if c != 0:
return False
elif op == ">=":
if c < 0:
return False
elif op == "<=":
if c > 0:
return False
elif op == ">":
if c <= 0:
return False
elif op == "<":
if c >= 0:
return False
elif op == "!=":
if c == 0:
return False
return True
# ----- Resolved set ----- #
def resolved_set(install_cell: str, colab: dict[str, str]) -> dict[str, str]:
"""Merge install-cell explicit constraints with Colab pip-freeze. Cell
wins.
Resolution order per package, when more than one form is present:
1. Exact `==V` pin in any install line (definitive).
2. Upper-bound `<=V` constraint (pip picks the highest
allowed; that's V).
3. Colab pip-freeze fallback.
The lower-bound `>=V` is intentionally NOT reflected here — a `>=V`
by itself doesn't change the resolved version when a higher
Colab-preinstalled version is already in scope. (R-INST-003 calls
`_install_cell_lower_bound` separately to model that case.)
"""
out = dict(colab)
pinned: set[str] = set()
upper_bounds: dict[str, str] = {}
for inv in iter_pip_invocations(install_cell):
for raw in inv.packages:
sp = parse_spec(raw)
if sp is None:
continue
for op, ver in sp.pins:
if op == "==":
out[sp.name] = ver
pinned.add(sp.name)
elif op == "<=" and sp.name not in pinned:
if (
sp.name not in upper_bounds
or cmp_versions(ver, upper_bounds[sp.name]) < 0
):
upper_bounds[sp.name] = ver
# Apply upper bounds where Colab's preinstall violates them.
for name, ub in upper_bounds.items():
if name in pinned:
continue
existing = out.get(name)
if existing is None or cmp_versions(existing, ub) > 0:
out[name] = ub
return out
# ----- Rules ----- #
def rule_inst_001_git_plus(
install_cell: str, file: str, cell_idx: int
) -> list[Finding]:
findings: list[Finding] = []
for inv in iter_pip_invocations(install_cell):
if any("git+" in p for p in inv.packages) or "git+" in inv.raw:
if any(allowed in inv.raw for allowed in GIT_PLUS_ALLOWLIST):
continue
findings.append(
Finding(
rule = "R-INST-001",
file = file,
cell = cell_idx,
line = inv.line_no,
severity = "error",
message = "install line uses `git+` (volatile, not pinned to a release)",
hint = f"replace with a `pip install foo==X.Y.Z` from PyPI; allow-list is {GIT_PLUS_ALLOWLIST}",
)
)
return findings
def rule_inst_002_no_deps_transitive(
install_cell: str, colab: dict[str, str], file: str, cell_idx: int
) -> list[Finding]:
findings: list[Finding] = []
res = resolved_set(install_cell, colab)
for inv in iter_pip_invocations(install_cell):
if "--no-deps" not in inv.flags:
continue
for raw in inv.packages:
sp = parse_spec(raw)
if sp is None:
continue
v = explicit_pin(sp)
if v is None:
continue
# Check transitive constraints on a curated short list of pkgs we
# care about (transformers/peft/trl/accelerate/torchao/torchcodec).
for target in (
"tokenizers",
"torchao",
"accelerate",
"datasets",
"huggingface-hub",
"huggingface_hub",
):
spec_str, ops = transitive_constraint(sp.name, v, target)
if not ops:
continue
resolved_target = res.get(target.replace("_", "-"), res.get(target))
if resolved_target is None:
continue
if not constraint_satisfied(resolved_target, ops):
findings.append(
Finding(
rule = "R-INST-002",
file = file,
cell = cell_idx,
line = inv.line_no,
severity = "error",
message = f"`--no-deps {sp.name}=={v}` leaves transitive `{target}` unpinned: resolved {resolved_target} violates {sp.name}'s requirement {spec_str!r}",
hint = f'add `"{target}>={ops[0][1]},<={ops[-1][1]}"` (or the exact window from the metadata) to the same install line',
)
)
return findings
def _install_cell_lower_bound(install_cell: str, target: str) -> str | None:
"""Return the highest LOWER bound that any install line places on `target`,
or None if no constraint is present. Treats `==V` as both lower and upper.
Used by R-INST-003: a `pip install torchao>=0.16.0` line is enough to
satisfy a `torchao>=0.16.0` floor even though it's not a `==` pin."""
best: str | None = None
for inv in iter_pip_invocations(install_cell):
for raw in inv.packages:
sp = parse_spec(raw)
if sp is None or sp.name != target:
continue
for op, ver in sp.pins:
if op in ("==", ">="):
if best is None or cmp_versions(ver, best) > 0:
best = ver
return best
def rule_inst_003_peft_torchao(
install_cell: str, colab: dict[str, str], file: str, cell_idx: int
) -> list[Finding]:
findings: list[Finding] = []
res = resolved_set(install_cell, colab)
peft_v = res.get("peft")
if not peft_v:
return findings
torchao_explicit = _install_cell_lower_bound(install_cell, "torchao")
torchao_resolved = torchao_explicit or res.get("torchao")
for floor in PEFT_TORCHAO_FLOOR:
if cmp_versions(peft_v, floor["trigger_peft"]) >= 0:
if (
torchao_resolved is None
or cmp_versions(torchao_resolved, floor["torchao_floor"]) < 0
):
findings.append(
Finding(
rule = "R-INST-003",
file = file,
cell = cell_idx,
severity = "error",
message = f"resolved peft=={peft_v} requires torchao>={floor['torchao_floor']}; install cell asserts torchao={torchao_resolved or '(none)'}",
hint = f'add `!pip install --no-deps --upgrade "torchao>={floor["torchao_floor"]}"` to the install cell',
)
)
return findings
def rule_inst_004_torchcodec_torch(
install_cell: str, colab: dict[str, str], file: str, cell_idx: int
) -> list[Finding]:
findings: list[Finding] = []
res = resolved_set(install_cell, colab)
torch_v = res.get("torch")
codec_v = res.get("torchcodec")
if not torch_v or not codec_v:
return findings
t_minor = version_minor(torch_v)
c_minor = version_minor(codec_v)
allowed = TORCH_TORCHCODEC.get(t_minor)
if allowed is None:
return findings # unknown torch minor — don't flag
if c_minor not in allowed:
findings.append(
Finding(
rule = "R-INST-004",
file = file,
cell = cell_idx,
severity = "error",
message = f"torch=={torch_v} (minor {t_minor}) is incompatible with torchcodec=={codec_v} (minor {c_minor}); compatible minors: {sorted(allowed)}",
hint = f"pin `torchcodec=={sorted(allowed)[-1]}` (or remove the explicit pin and let pip resolve)",
)
)
return findings
def rule_inst_005_transformers_tokenizers(
install_cell: str, colab: dict[str, str], file: str, cell_idx: int
) -> list[Finding]:
"""Fires only when transformers is installed with `--no-deps`. Without
`--no-deps`, pip resolves the correct tokenizers transitively, so the
rule would be a false positive (this is the case for older notebooks
that pin `transformers==4.51.3` but rely on pip's transitive resolver).
The rule targets the exact pattern PR #261b / #264 fixed:
`pip install --no-deps transformers==X` next to a Colab preinstall
`tokenizers` outside transformers's window."""
findings: list[Finding] = []
res = resolved_set(install_cell, colab)
tf = res.get("transformers")
tok = res.get("tokenizers")
if not tf or tok is None:
return findings
# Find the install line that pins transformers and check for --no-deps.
transformers_line_no_deps = False
for inv in iter_pip_invocations(install_cell):
for raw in inv.packages:
sp = parse_spec(raw)
if sp is None or sp.name != "transformers":
continue
if explicit_pin(sp) is None:
continue
if "--no-deps" in inv.flags:
transformers_line_no_deps = True
break
if transformers_line_no_deps:
break
if not transformers_line_no_deps:
return findings
spec_str, ops = transitive_constraint("transformers", tf, "tokenizers")
if not ops:
return findings
if not constraint_satisfied(tok, ops):
findings.append(
Finding(
rule = "R-INST-005",
file = file,
cell = cell_idx,
severity = "error",
message = f"`--no-deps transformers=={tf}` skips pip's transitive resolver; resolved tokenizers={tok} violates {spec_str}",
hint = f'pin `"tokenizers{spec_str}"` (or the matching window) on the same `--no-deps` line',
)
)
return findings
_RE_DOUBLE_BANG = re.compile(r"^[ \t]*!{2,}\s*pip\b", re.MULTILINE)
def rule_inst_006_double_bang(
install_cell: str, file: str, cell_idx: int
) -> list[Finding]:
findings: list[Finding] = []
for m in _RE_DOUBLE_BANG.finditer(install_cell):
line_no = install_cell.count("\n", 0, m.start()) + 1
findings.append(
Finding(
rule = "R-INST-006",
file = file,
cell = cell_idx,
line = line_no,
severity = "warning",
message = "double-bang `!!pip` runs in a subshell; almost always a typo for `!pip`",
hint = "use a single `!`",
)
)
return findings
# ----- AST-level rules over user-facing cells ----- #
class _APIScanner(ast.NodeVisitor):
"""Scan user-facing code cells for known deprecated patterns. R-API-001
(`for_training`/`for_inference`) is intentionally absent: those helpers
are still part of the live unsloth surface as of 2026-05; PR #221 removed
the calls cosmetically from Vision notebooks but did not deprecate the
methods. R-API-004 (live API surface diff) catches actual removals
dynamically without us hand-coding them."""
def __init__(self, file: str, cell_idx: int):
self.file = file
self.cell_idx = cell_idx
self.findings: list[Finding] = []
def visit_Call(self, node: ast.Call) -> None:
# SFTConfig with suboptimal optim (R-API-003).
# NOTE: PR #221 also stripped `gradient_checkpointing` /
# `gradient_checkpointing_kwargs` from a handful of vision notebooks,
# but those kwargs are still accepted by live TRL (verified against
# trl==0.25.1 in the unsloth workspace) so removing them was
# cosmetic, not a deprecation. We do NOT flag them. R-API-004 (live
# API surface diff in the api subcommand) is the right way to catch
# actual TRL signature drift.
if isinstance(node.func, ast.Name) and node.func.id == "SFTConfig":
for kw in node.keywords:
if (
kw.arg == "optim"
and isinstance(kw.value, ast.Constant)
and kw.value.value == "adamw_torch_fused"
):
self.findings.append(
Finding(
rule = "R-API-003",
file = self.file,
cell = self.cell_idx,
line = kw.value.lineno,
severity = "warning",
message = "`optim='adamw_torch_fused'` is suboptimal under Unsloth's memory-efficient training",
hint = 'use `optim="adamw_8bit"` (or `"paged_adamw_8bit"` for GRPO)',
)
)
self.generic_visit(node)
def scan_user_cells(nb: dict[str, Any], file: str) -> list[Finding]:
findings: list[Finding] = []
install_idxs = {i for i, _ in install_cells(nb)}
for i, src in code_cells(nb):
if i in install_idxs:
continue
try:
tree = ast.parse(src)
except SyntaxError:
continue
scanner = _APIScanner(file = file, cell_idx = i)
scanner.visit(tree)
findings.extend(scanner.findings)
return findings
# ----- DONT_UPDATE_EXCEPTIONS coverage ----- #
POLICY_CLAUSES_DEFAULT = [
# (id, regex, applies_to_predicate_on_install_cell_text)
(
"torchao-floor",
re.compile(r"torchao>=0\.16\.0"),
lambda cell: bool(re.search(r"\bpeft\b", cell)),
),
(
"tokenizers-window",
re.compile(r"tokenizers>=0\.22\.0,<=0\.23\.0"),
lambda cell: bool(re.search(r"--no-deps[^\n]*transformers==", cell)),
),
]
def extract_policy_clauses(
update_script: pathlib.Path,
) -> list[tuple[str, re.Pattern[str], Any]]:
"""Best-effort: scan update_all_notebooks.py for canonical phrases used by
multiple templates. Falls back to POLICY_CLAUSES_DEFAULT.
Today we use POLICY_CLAUSES_DEFAULT directly; the regex form is
intentionally permissive so a template-side reword (e.g. comment changes)
doesn't cause false positives. New clauses become 1-line PRs to this list.
"""
return list(POLICY_CLAUSES_DEFAULT)
def rule_l12_exceptions_coverage(notebooks_dir: pathlib.Path) -> list[Finding]:
findings: list[Finding] = []
update_script = notebooks_dir / "update_all_notebooks.py"
exceptions = _extract_dont_update_exceptions(update_script)
clauses = extract_policy_clauses(update_script)
for name in exceptions:
path = notebooks_dir / "nb" / name
if not path.is_file():
continue
nb = load_notebook(path)
for idx, cell in install_cells(nb):
for cid, pat, applies in clauses:
if not applies(cell):
continue
if not pat.search(cell):
findings.append(
Finding(
rule = "R-EXC-001",
file = str(path),
cell = idx,
severity = "error",
message = f"DONT_UPDATE_EXCEPTIONS notebook missing policy clause `{cid}` (pattern {pat.pattern!r})",
hint = f"add the matching install line; the regenerator can't reach this notebook",
)
)
return findings
def _extract_dont_update_exceptions(update_script: pathlib.Path) -> list[str]:
if not update_script.is_file():
return []
src = update_script.read_text(encoding = "utf-8")
m = re.search(r"DONT_UPDATE_EXCEPTIONS\s*=\s*\[(.*?)\]", src, re.DOTALL)
if not m:
return []
out: list[str] = []
for line in m.group(1).splitlines():
m2 = re.match(r'\s*"([^"]+\.ipynb)"', line)
if m2:
out.append(m2.group(1))
return out
# ----- Drift ----- #
def cmd_drift(args: argparse.Namespace) -> int:
nbdir = pathlib.Path(args.notebooks_dir).resolve()
update_script = nbdir / "update_all_notebooks.py"
if not update_script.is_file():
print(f"FAIL: {update_script} not found", file = sys.stderr)
return 2
# Stash any pre-existing dirty state, run the updater, diff, restore.
head = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd = nbdir)
.decode()
.strip()
)
subprocess.run(
["git", "-C", str(nbdir), "stash", "--include-untracked"],
check = False,
capture_output = True,
)
try:
proc = subprocess.run(
[sys.executable, str(update_script)],
cwd = nbdir,
capture_output = True,
text = True,
timeout = 600,
)
except subprocess.TimeoutExpired:
print("FAIL: update_all_notebooks.py timed out (>600s)", file = sys.stderr)
return 2
if proc.returncode != 0:
print(
f"FAIL: update_all_notebooks.py exited {proc.returncode}", file = sys.stderr
)
sys.stderr.write(proc.stderr[-2000:])
return 2
diff_proc = subprocess.run(
["git", "-C", str(nbdir), "diff", "--stat"], capture_output = True, text = True
)
findings: list[Finding] = []
if diff_proc.stdout.strip():
for line in diff_proc.stdout.splitlines():
findings.append(
Finding(
rule = "R-DRIFT-001",
file = line.strip(),
severity = "error",
message = "generator-vs-checked-in drift",
hint = "run `python update_all_notebooks.py` and commit the diff",
)
)
# Restore.
subprocess.run(
["git", "-C", str(nbdir), "checkout", "."], check = False, capture_output = True
)
subprocess.run(
["git", "-C", str(nbdir), "stash", "pop"], check = False, capture_output = True
)
_emit(findings)
return 0 if not findings else 1
# ----- Convert ----- #
def cmd_convert(args: argparse.Namespace) -> int:
nbdir = pathlib.Path(args.notebooks_dir).resolve()
out = pathlib.Path(args.out).resolve()
out.mkdir(parents = True, exist_ok = True)
converter = HERE / "notebook_to_python.py"
if not converter.is_file():
print(f"FAIL: {converter} not found", file = sys.stderr)
return 2
# Convert in batches; the script accepts multiple notebooks at once.
notebooks = list(iter_notebooks(nbdir, include_templates = True))
failed: list[Finding] = []
BATCH = 32
for i in range(0, len(notebooks), BATCH):
chunk = notebooks[i : i + BATCH]
proc = subprocess.run(
[sys.executable, str(converter), "-o", str(out), *map(str, chunk)],
capture_output = True,
text = True,
)
if proc.returncode != 0:
for nb in chunk:
failed.append(
Finding(
rule = "R-CONV-001",
file = str(nb),
severity = "error",
message = "notebook_to_python.py failed for this notebook",
hint = proc.stderr[-200:].strip(),
)
)
print(
f"converted {len(notebooks) - len(failed)}/{len(notebooks)} notebooks to {out}"
)
_emit(failed)
return 0 if not failed else 1
# ----- Lint (combined) ----- #
def cmd_lint(args: argparse.Namespace) -> int:
nbdir = pathlib.Path(args.notebooks_dir).resolve()
colab_path = (
pathlib.Path(args.colab_pin).resolve()
if args.colab_pin
else COLAB_FALLBACK_FILE
)
colab = parse_pip_freeze(colab_path)
if not colab:
print(
f"WARN: Colab pip-freeze empty / missing at {colab_path}; using empty oracle",
file = sys.stderr,
)
findings: list[Finding] = []
notebooks = list(iter_notebooks(nbdir))
for path in notebooks:
try:
nb = load_notebook(path)
except (json.JSONDecodeError, OSError) as e:
findings.append(
Finding(
rule = "R-CONV-002",
file = str(path),
severity = "error",
message = f"notebook unreadable: {e}",
)
)
continue
rel = str(path.relative_to(nbdir))
env = target_environment(rel)
# The Colab oracle is the source of truth ONLY for Colab notebooks.
# Other targets (amd / kaggle / dgx_spark) have their own runtime
# preinstall sets that aren't tracked here yet, so we apply the
# environment-agnostic rules and skip the Colab-specific ones.
oracle = colab if env == "colab" else {}
cells = install_cells(nb)
# Per-cell rules: forbid-pattern checks scoped to a single line.
for idx, cell in cells:
findings += rule_inst_001_git_plus(cell, rel, idx)
findings += rule_inst_006_double_bang(cell, rel, idx)
# Whole-notebook rules: a notebook's install steps are sometimes split
# across multiple cells (initial install + post-install bumps). Merge
# all install cells before resolving compat against Colab.
merged = "\n".join(c for _, c in cells)
if env == "colab" and merged:
first_cell = cells[0][0] if cells else None
findings += rule_inst_003_peft_torchao(merged, oracle, rel, first_cell)
findings += rule_inst_004_torchcodec_torch(merged, oracle, rel, first_cell)
findings += rule_inst_005_transformers_tokenizers(
merged, oracle, rel, first_cell
)
if not args.no_pypi:
findings += rule_inst_002_no_deps_transitive(
merged, oracle, rel, first_cell
)
findings += scan_user_cells(nb, rel)
_emit(findings)
return 0 if not any(f.severity == "error" for f in findings) else 1
# ----- Exceptions coverage ----- #
def cmd_exceptions(args: argparse.Namespace) -> int:
findings = rule_l12_exceptions_coverage(pathlib.Path(args.notebooks_dir).resolve())
_emit(findings)
return 0 if not findings else 1
# ----- API surface scan ----- #
def cmd_api(args: argparse.Namespace) -> int:
surface_path = pathlib.Path(args.surface).resolve()
if not surface_path.is_file():
print(
f"FAIL: {surface_path} not found; run dump-api-surface first",
file = sys.stderr,
)
return 2
surface = json.loads(surface_path.read_text())
converted = pathlib.Path(args.converted_dir).resolve()
findings: list[Finding] = []
fast_models = (
set(surface.get("FastVisionModel", []))
| set(surface.get("FastLanguageModel", []))
| set(surface.get("FastModel", []))
)
for py in sorted(converted.glob("*.py")):
try:
tree = ast.parse(py.read_text(encoding = "utf-8"))
except SyntaxError:
continue
for node in ast.walk(tree):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
base = node.func.value
if isinstance(base, ast.Name) and base.id in (
"FastVisionModel",
"FastLanguageModel",
"FastModel",
):
surface_set = set(surface.get(base.id, []))
if surface_set and node.func.attr not in surface_set:
findings.append(
Finding(
rule = "R-API-004",
file = str(py.name),
line = node.lineno,
severity = "error",
message = f"`{base.id}.{node.func.attr}` is not in the live API surface for the pinned unsloth tag",
hint = "check the unsloth changelog for a renamed/removed API",
)
)
_emit(findings)
return 0 if not findings else 1
# ----- Orchestrator ----- #
def cmd_all(args: argparse.Namespace) -> int:
rcs: list[int] = []
rcs.append(cmd_drift(argparse.Namespace(notebooks_dir = args.notebooks_dir)))
rcs.append(
cmd_lint(
argparse.Namespace(
notebooks_dir = args.notebooks_dir,
colab_pin = args.colab_pin,
no_pypi = args.no_pypi,
)
)
)
rcs.append(cmd_exceptions(argparse.Namespace(notebooks_dir = args.notebooks_dir)))
return 0 if all(rc == 0 for rc in rcs) else 1
def cmd_refresh_colab(args: argparse.Namespace) -> int:
"""Pull the latest Colab pip-freeze.gpu.txt and write to disk."""
out = pathlib.Path(args.out).resolve()
out.parent.mkdir(parents = True, exist_ok = True)
try:
with urllib.request.urlopen(COLAB_PIP_FREEZE_URL, timeout = 15) as r:
data = r.read()
except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError) as e:
print(f"FAIL: could not fetch {COLAB_PIP_FREEZE_URL}: {e}", file = sys.stderr)
return 2
out.write_bytes(data)
print(f"wrote {len(data)} bytes to {out}")
return 0
# ----- Helpers ----- #
def _emit(findings: list[Finding]) -> None:
n_err = sum(1 for f in findings if f.severity == "error")
n_warn = sum(1 for f in findings if f.severity == "warning")
for f in findings:
print(json.dumps(f.to_dict(), separators = (",", ":")))
print(f"# total: {n_err} errors, {n_warn} warnings", file = sys.stderr)
def main(argv: list[str] | None = None) -> int:
p = argparse.ArgumentParser(prog = "notebook_validator")
sub = p.add_subparsers(dest = "cmd", required = True)
pa = sub.add_parser("drift")
pa.add_argument("--notebooks-dir", required = True)
pa = sub.add_parser("convert")
pa.add_argument("--notebooks-dir", required = True)
pa.add_argument("--out", required = True)
pa = sub.add_parser("lint")
pa.add_argument("--notebooks-dir", required = True)
pa.add_argument("--colab-pin", default = None)
pa.add_argument(
"--no-pypi",
action = "store_true",
help = "skip rules that require live PyPI metadata fetches",
)
pa = sub.add_parser("exceptions")
pa.add_argument("--notebooks-dir", required = True)
pa = sub.add_parser("api")
pa.add_argument("--converted-dir", required = True)
pa.add_argument("--surface", required = True)
pa = sub.add_parser("all")
pa.add_argument("--notebooks-dir", required = True)
pa.add_argument("--colab-pin", default = None)
pa.add_argument("--no-pypi", action = "store_true")
pa = sub.add_parser("refresh-colab")
pa.add_argument("--out", default = str(COLAB_FALLBACK_FILE))
args = p.parse_args(argv)
return {
"drift": cmd_drift,
"convert": cmd_convert,
"lint": cmd_lint,
"exceptions": cmd_exceptions,
"api": cmd_api,
"all": cmd_all,
"refresh-colab": cmd_refresh_colab,
}[args.cmd](args)
if __name__ == "__main__":
sys.exit(main())