studio: security and hardening pass (auth rate-limit, sandbox, path containment, schema validation, headers) (#5375)
Some checks are pending
Core / Core (HF=default + TRL=default) (push) Waiting to run
Core / Core (HF=4.57.6 + TRL<1) (push) Waiting to run
Core / Core (HF=latest + TRL=latest) (push) Waiting to run
Core / llama.cpp build + smoke (push) Waiting to run
Lint CI / Source lint (Python + shell + YAML + JSON + safety nets) (push) Waiting to run
MLX CI on Mac M1 / dispatch (push) Waiting to run
Security audit / advisory audit (pip + npm + cargo) (push) Waiting to run
Security audit / pip scan-packages :: extras (push) Waiting to run
Security audit / pip scan-packages :: studio (push) Waiting to run
Security audit / pip scan-packages :: hf-stack (push) Waiting to run
Security audit / npm scan-packages (Studio frontend tarballs) (push) Waiting to run
Security audit / workflow-trigger lint (pull_request_target / cache-poisoning) (push) Waiting to run
Security audit / pytest tests/security (push) Waiting to run
Security audit / npm provenance + new install-script diff (push) Waiting to run
Studio API CI / Studio API & Auth Tests (push) Waiting to run
Backend CI / (Python 3.10) (push) Waiting to run
Backend CI / (Python 3.11) (push) Waiting to run
Backend CI / (Python 3.12) (push) Waiting to run
Backend CI / (Python 3.13) (push) Waiting to run
Backend CI / Repo tests (CPU) (push) Waiting to run
Frontend CI / Frontend build + bundle sanity (push) Waiting to run
Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Studio GGUF CI / Tool calling Tests (push) Waiting to run
Studio GGUF CI / JSON, images (push) Waiting to run
Mac Studio API CI / Studio API & Auth Tests (push) Waiting to run
Mac Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Mac Studio GGUF CI / Tool calling Tests (push) Waiting to run
Mac Studio GGUF CI / JSON, images (push) Waiting to run
Mac Studio UI CI / Chat UI Tests (push) Waiting to run
Mac Studio Update CI / Studio Updating Tests (push) Waiting to run
Studio Tauri CI / Tauri Linux debug build (no codesign) (push) Waiting to run
Studio UI CI / Chat UI Tests (push) Waiting to run
Studio Update CI / Studio Updating Tests (push) Waiting to run
Windows Studio API CI / Studio API & Auth Tests (push) Waiting to run
Windows Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Windows Studio GGUF CI / Tool calling Tests (push) Waiting to run
Windows Studio GGUF CI / JSON, images (push) Waiting to run
Windows Studio UI CI / Chat UI Tests (push) Waiting to run
Windows Studio Update CI / Studio Updating Tests (push) Waiting to run
Wheel CI / Wheel build + content sanity + import smoke (push) Waiting to run

* studio: contain export and dataset paths under their configured roots

resolve_under_root and resolve_dataset_path previously returned absolute
paths unchanged, so an authenticated client could supply
save_directory="/tmp/escape" (or any other absolute path) and have the
exporter drop adapter files anywhere the server user could write. This
turned up during a recent audit pass where an authenticated POST to
/api/export/export/lora with save_directory="/tmp/lora_escape_test"
returned 200 and wrote adapter_model.safetensors, adapter_config.json,
and tokenizer files under /tmp.

The fix is two-layered:

storage_roots.py adds an _assert_contained(resolved, root) helper that
runs after path resolution and rejects any result whose realpath does
not sit under realpath(root). resolve_under_root now rejects '..'
segments and null bytes outright, and only accepts absolute inputs when
they are already inside the configured root (internal call sites that
re-resolve a stored absolute path stay idempotent;
worker.py:resolve_output_dir(output_dir) etc. continue to work).
resolve_dataset_path picks up the same containment rule, scoped to the
three dataset roots.

models/export.py adds field_validator("save_directory", mode="before")
to ExportCommonOptions and ExportGGUFRequest so bad input fails fast at
422 with a clear message rather than a 500 deep inside the resolver.
The validator rejects empty/whitespace, null bytes, control chars,
strings longer than 255 chars, absolute paths, and '..' segments.

routes/export.py:_export_details now returns os.path.relpath(output_path,
exports_root()) so the Export Complete dialog and /api/models/loras no
longer leak the absolute install prefix to the UI; the basename is
used as a last-resort fallback.

Verified end to end:
- POST /api/export/export/lora {"save_directory":"/tmp/foo"} -> 422
  "save_directory must be a name or relative path under the export
  root; absolute paths are rejected". /tmp/foo is not created.
- "../../etc/escape" -> 422 "may not contain '..' segments".
- save_directory="my_subdir" -> still accepted (400 only because the
  test had no checkpoint loaded yet, not because of validation).
- Internal idempotent re-resolve via resolve_export_dir(absolute path
  that is already under exports_root) returns the same path unchanged.

* studio/sandbox: harden bash + python tool execution

The sandboxed Bash and Python tool channels in Chat ran with a thin
preexec hook (PR_SET_NO_NEW_PRIVS + RLIMIT_FSIZE only). Bash had a
small word blocklist; Python had an AST safety pass aimed at
signal-tampering and shell-escape primitives. An audit pass showed
several gaps that a tool-calling model could trigger inadvertently:

- bash curl/wget/nc reached AWS IMDSv2 and returned live STS
  credentials for the instance role.
- python "import socket; s.connect((169.254.169.254, 80))"
  reached the same endpoint regardless of the bash blocklist.
- "cat /etc/passwd" was blocked at the bash side (because "passwd"
  is in the blocklist), but "open('/etc/passwd').read()" in Python
  happily returned its contents.
- "chr(115)+chr(117)+chr(100)+chr(111)" style dynamic-arg
  construction slipped through the AST shell-escape check.
- The supervisor used proc.kill() on timeout, which only signals
  the immediate pid; bash-backgrounded children survived. A fork
  bomb could spawn for the full 300s timeout window.
- Session work directories under ~/studio_sandbox/<id>/ were
  created with default umask (0o755), so any other UID on the host
  could enumerate them.
- session_id sanitisation used a one-shot str.replace("..",""),
  which is non-iterative and a small footgun.

This commit takes a conservative middle path: the sandbox still
runs as the Studio UID with no namespace tricks where the kernel
disallows them, but every chokepoint is tightened.

_sandbox_preexec now:
- calls os.setsid() so children share a process group; the
  supervisor uses os.killpg(SIGKILL) on timeout/cancel so
  backgrounded children die with the parent (new _kill_process_tree
  helper, wired into _cancel_watcher and both _bash_exec /
  _python_exec timeout branches).
- calls os.umask(0o077) so files the child writes default to 0o600.
- applies PR_SET_PDEATHSIG=SIGKILL so an orphaned child dies if
  Studio exits.
- best-effort unshare(CLONE_NEWNET) for a private network namespace
  (failure is logged and swallowed; defense-in-depth is still in
  place via the bash blocklist and the AST checker below).
- sets RLIMIT_NPROC=10000 (tunable via UNSLOTH_STUDIO_SANDBOX_NPROC),
  RLIMIT_AS=8GB, RLIMIT_CPU=300, RLIMIT_NOFILE=1024. The 10k NPROC
  figure is chosen to sit well above the ~500 LWPs a healthy Studio
  + llama-server combination already uses while still capping a
  runaway fork bomb. NPROC counts LWPs per real UID, so a lower
  figure (e.g. 256) starves legitimate bash forks
  ("bash: fork: retry: Resource temporarily unavailable").

_get_workdir:
- rejects session_id that doesn't match [A-Za-z0-9_-]{1,64};
  non-matching values bucket into a shared "_invalid" dir.
- chmod 0o700 on both the workdir and on ~/studio_sandbox/ so
  other UIDs cannot read another session's contents.

_BLOCKED_COMMANDS_COMMON gains: doas, pkexec, halt, poweroff, curl,
wget, nc, ncat, netcat, socat, ssh, scp, sftp, rsync, eval, source.
The intent is to keep general bash usage working (echo, ls, pipes,
loops, for, head, etc.) while denying the obvious egress and
escalation paths.

The AST checker (_check_signal_escape_patterns) is split into the
existing shell/signal/loop checks plus a new narrow IO denylist:
- Always flag non-literal args to anything in _SHELL_EXEC_FUNCS,
  not just _STRING_SHELL_FUNCS. Closes the dynamic-arg bypass.
- Reject calls to socket.create_connection, socket.socket().connect,
  urllib.request.urlopen, http.client.HTTP*Connection, requests.*,
  httpx.* whose literal host argument is in a cloud-metadata
  denylist (169.254.169.254 + 169.254.* + 100.64.*, plus the
  GCP/Alibaba/ECS metadata hostnames and IPv6 link-local). Public
  hosts (example.com, huggingface.co, ...) still work. Dynamic
  hosts cannot be statically blocked; mitigated by the bash
  blocklist + the netns where the kernel allows it.
- Reject literal open("/etc/passwd"), /etc/shadow, /etc/sudoers,
  /etc/ssh/*, and /proc/<pid>/environ. Other files
  (/etc/os-release, /etc/hostname, /tmp/*, user dirs) still work.

The _check_code_safety summariser is updated to include the new
network_calls and sensitive_file_reads buckets in its error string.

Regression-checked: echo, sleep, ls /tmp, for loops, piped helpers
(echo a | tr a A), urllib.request.urlopen("http://example.com"),
socket.getaddrinfo("example.com",80), open("/etc/os-release"),
open("/tmp/...","w") all still succeed. curl, wget, nc, ssh, rm,
socket.create_connection(("169.254.169.254",80)),
open("/etc/passwd"), open("/proc/self/environ") all correctly
blocked.

* studio: rate-limit login, rotate refresh tokens, add logout, security headers, gate bootstrap injection

A pass over the auth surface found a cluster of related issues that this
commit closes together.

Login (routes/auth.py):
- Add an in-memory per-IP login rate limiter. Five failed POSTs to
  /api/auth/login inside a 60s window produce 429 with Retry-After.
  A successful login clears the bucket. Previously 30 wrong passwords
  in under one second was accepted as 30x 401, which combined with
  the (now fixed) admin-username leak from /api/auth/status made
  brute-force trivial against a small password.

Logout (routes/auth.py):
- New POST /api/auth/logout returns 204 and calls
  storage.revoke_user_refresh_tokens(subject) so the refresh token
  is no longer valid. Previously POST /api/auth/logout returned 405
  and there was no way to invalidate refresh tokens short of
  changing the password. Frontend session.ts already calls
  clearAuthTokens() to drop localStorage; the new endpoint lets the
  client also tell the server to revoke server-side state.

Refresh-token rotation (routes/auth.py + auth/storage.py):
- New storage.consume_refresh_token(token) atomically validates +
  deletes a refresh token, returning (username, is_desktop). The
  /api/auth/refresh handler now mints both a new access AND a new
  refresh token; the supplied token becomes invalid. Replaying a
  consumed refresh returns 401 "Invalid or expired refresh token".
  The previous refresh_access_token helper is left in place for
  callers that intentionally want the non-rotating shape; nothing
  in the route layer uses it now.

/api/auth/status no longer leaks default_username (models/auth.py +
routes/auth.py):
- AuthStatusResponse.default_username becomes Optional[str] with a
  None default; the handler always returns None. The frontend already
  hardcodes HIDDEN_LOGIN_USERNAME = "unsloth" (auth-form.tsx:82), so
  no UI change is required.

window.__UNSLOTH_BOOTSTRAP__ no longer auto-injects (main.py):
- _inject_bootstrap is now opt-in via the
  UNSLOTH_STUDIO_INJECT_BOOTSTRAP env var. The previous default
  (inject whenever requires_password_change is true) embedded the
  plaintext bootstrap password into the first-boot HTML for any
  caller that hit /, /change-password, or any unknown SPA path.
  Browser extensions and any XSS payload on the page could read it
  trivially. With the new gate the bootstrap password lives only in
  the auth/.bootstrap_password file (mode 0o600) where it has always
  been; users typing it into a current-password field is the right
  UX. routes/auth.py:change_password also clears
  app.state.bootstrap_password defensively.

Security headers + server fingerprint (main.py + run.py):
- New SecurityHeadersMiddleware adds Content-Security-Policy,
  X-Frame-Options: DENY, X-Content-Type-Options: nosniff,
  Referrer-Policy: no-referrer,
  Permissions-Policy: camera=(), microphone=(), geolocation=(),
  interest-cohort=(), and stamps server: unsloth-studio so the
  generic uvicorn banner no longer fingerprints the stack. The
  uvicorn.Config gains server_header=False so it stops emitting its
  own Server header.

/api/health minimisation (main.py):
- Unauthenticated GET /api/health returns just
  {"status":"healthy","timestamp":...} so load-balancer liveness
  probes keep working without leaking version, device_type,
  chat_only, desktop_protocol_version, or studio_root_id to
  arbitrary callers. A request that presents a valid Bearer token
  still gets the full diagnostic payload so internal launchers and
  sibling-Studio detection (which compares studio_root_id) keep
  working.

Verification:
- 30 wrong-password POSTs to /api/auth/login -> first 5 = 401, 6th
  through 30th = 429.
- POST /api/auth/logout with a fresh token -> 204. The matching
  refresh token then fails 401.
- Login -> R1; /api/auth/refresh with R1 -> new access + R2 (R2 !=
  R1); /api/auth/refresh with R1 again -> 401; /api/auth/refresh
  with R2 -> still succeeds once and rotates again.
- curl /api/auth/status -> default_username: null.
- curl http://127.0.0.1/ does not contain __UNSLOTH_BOOTSTRAP__.
- curl -I / shows CSP, X-Frame-Options: DENY,
  X-Content-Type-Options: nosniff, Referrer-Policy: no-referrer,
  Permissions-Policy, and server: unsloth-studio.
- curl /api/health unauthenticated -> {status, timestamp} only.
  curl with Authorization: Bearer <valid> -> full payload.
- Existing /api/system, /api/models/list, /api/train/status,
  /api/inference/status, /api/auth/api-keys, login flow, SPA root
  all still return 200 after the changes (regression smoke).

* studio: add SecurityHeadersMiddleware, MaxBodyMiddleware, /recipes redirect, gate _inject_bootstrap, minimise /api/health

This commit lands the main.py-side changes that share a single
middleware-registration spot. They are kept together because every
change here is either (a) a top-level middleware definition that has
to be added next to LoggingMiddleware, or (b) a route handler at the
same file-level.

SecurityHeadersMiddleware (Content-Security-Policy, X-Frame-Options:
DENY, X-Content-Type-Options: nosniff, Referrer-Policy: no-referrer,
Permissions-Policy, server: unsloth-studio). The previous responses
emitted no CSP, no XFO, no Referrer-Policy and were stamped
server: uvicorn.

MaxBodyMiddleware rejects POST/PUT/PATCH on the inference / dataset /
data-recipe / train / export prefixes when Content-Length exceeds
UNSLOTH_STUDIO_MAX_BODY_MB (default 100). The audit hit this by
attaching a 50 MB plain-text file to a chat message and watching
Studio base64-encode it into the JSON body; uvicorn has no enforced
cap so the only previous guard was the per-file 50 MB ceiling that
data-recipe upload routes already enforce. The new middleware extends
that ceiling to the OpenAI-compat path that the Chat attachments
flow through. Verified: a 200 MB JSON POST to /v1/chat/completions
returns HTTP 413 "Request body too large (209,715,264 bytes; max
104,857,600)". A small valid request continues to reach the handler.

_inject_bootstrap is gated behind UNSLOTH_STUDIO_INJECT_BOOTSTRAP.
The previous default was to inline window.__UNSLOTH_BOOTSTRAP__ =
{username, password} into the first-boot HTML whenever
requires_password_change was true, which exposed the plaintext
bootstrap password to any browser extension, page script, or LAN
caller on -H 0.0.0.0. The bootstrap password remains in the on-disk
.bootstrap_password file (mode 0o600) where it has always lived;
users typing it into a current-password field is the right UX.

/api/health unauthenticated returns {"status":"healthy","timestamp":
...} only; the previous payload (version, device_type, chat_only,
desktop_protocol_version, supports_desktop_auth, studio_root_id,
native_path_leases_supported) is preserved for callers that present
a valid Bearer token, so internal launchers and sibling-Studio
detection (which compares studio_root_id) keep working.

/recipes -> /data-recipes 308 redirect. The Data Recipes page lives
at /data-recipes; users typing /recipes hit the SPA catch-all and
saw "Not Found". The redirect also preserves any tail path, so
/recipes/<rest> -> /data-recipes/<rest>.

Verified end to end with curl: CSP / XFO / X-Content-Type-Options /
Referrer-Policy / Permissions-Policy all present on /, server header
is now unsloth-studio (uvicorn's own banner is suppressed via
server_header=False in run.py from the auth-batch commit). Followed
the /recipes redirect lands on the SPA HTML.

* studio: bound TrainingStartRequest hyperparameters at the schema level

POST /api/train/start accepted any value for learning_rate, batch_size,
max_steps, max_seq_length, warmup_steps, warmup_ratio, num_epochs,
save_steps, weight_decay, gradient_accumulation_steps, lora_r,
lora_alpha and lora_dropout, including -1, 0, 1e9, and non-numeric
strings like 'abc' or 'two' (which silently coerce to 0 in the
trainer). Probing showed the API returning 200 to learning_rate=-1
and batch_size=0; only max_steps had any partial clamping.

This commit adds field_validator on every numeric hyperparameter.
Bounds are chosen wide enough to span realistic single-host
configurations (B200 with 180 GB of memory comfortably fits the
upper end) while rejecting the values that always produce broken
training:

- learning_rate: parses str/float, requires 0 < lr < 1.0. Non-numeric
  input raises with "learning_rate must be parseable as float (got
  'abc')" instead of silently coercing to 0.
- batch_size: [1, 1024].
- gradient_accumulation_steps: [1, 4096].
- num_epochs: [1, 1000].
- max_steps: [1, 1_000_000].
- max_seq_length: [1, 131072].
- warmup_steps: [0, max_steps].
- warmup_ratio: [0.0, 1.0].
- save_steps: [0, 1_000_000].
- weight_decay: [0, 10] (typical 0..0.1).
- lora_r: [1, 512].
- lora_alpha: [1, 1024].
- lora_dropout: [0.0, 1.0).

Each validator names the offending field in its ValueError message
so the 422 response body identifies which input is bad. The
learning_rate validator returns its result as str (the schema field
type is str("2e-4") for backwards compatibility) so existing call
sites that float() the value continue to work.

Verified:
- learning_rate=-1 -> 422 "learning_rate must be > 0 (got -1.0);
  typical range is 1e-6 .. 1e-3".
- learning_rate='abc' -> 422 "must be parseable as float".
- batch_size=-1 / 0 / 999999 -> 422 "batch_size must be in [1, 1024]".
- batch_size='two' -> 422 (pydantic int parser).
- max_steps=0 / -5 -> 422 "must be a positive int".
- max_seq_length=200000 -> 422 "must be in [1, 131072]".
- warmup_ratio=2.5 -> 422 "must be in [0.0, 1.0]".
- lora_dropout=1.5 -> 422 "must be in [0.0, 1.0)".
- Valid request with learning_rate='2e-4', batch_size=1, max_steps=5
  passes validation and the training run starts as normal.

* studio: redact image-decode errors, clean checkpoint dirs on cancel, tolerate Stop-button + tool-result message shapes

Three small fixes that fall under "do not let the audit findings
become user-visible papercuts".

routes/inference.py - image-decode error redaction (the audit hit
this with a 0-byte / malformed / wrong-extension image upload). The
three image-normalise sites previously raised HTTPException(400,
detail=f"Failed to process image: {e}"). When PIL raised
UnidentifiedImageError(io.BytesIO(raw)) the message string included
"<_io.BytesIO object at 0x7e40a5d7bf60>", leaking both the Python
class name (confirming the PIL/io stack) and a heap address (mildly
useful for ASLR-bypass chaining if another memory-corruption bug is
ever found). Each site now catches UnidentifiedImageError and
returns the generic "Unsupported or corrupt image format"; the
fall-through generic except returns "Failed to process image". No
exception-repr is interpolated into a response body anywhere along
these paths.

core/training/training.py - checkpoint cleanup on cancel. When a
user clicks Cancel Training, the trainer flips _cancel_requested=True
and the supervisor force-terminates the subprocess. The trainer
writes checkpoint-<step> directories under output_dir every
save_steps; previously these survived the cancel and accumulated on
disk (the audit recorded ~67 MB stuck after a 200-step cancel with
save_steps=20). New helper _cleanup_cancelled_checkpoints(output_dir)
globs checkpoint-<int> entries and removes them. It is gated by a
realpath containment check against outputs_root() so it cannot
accidentally rmtree anything outside the configured outputs root.
force_terminate() invokes the helper after the subprocess join when
_cancel_requested is true. Stop-and-Save runs are unaffected because
that path keeps _cancel_requested=False.

models/inference.py - chat message shape tolerance. Two related
frontend interactions used to crash the request validator:

- After the Stop button truncates a generation, the frontend
  retained {role:"assistant", content:""} in the conversation
  history and replayed it on the next send. ChatMessage previously
  required role="assistant" to have non-empty content or tool_calls,
  so the next message returned 422 and the thread was permanently
  broken. The validator now normalises empty assistant content to
  None so the request round-trips and the trailing empty turn can
  be ignored downstream.

- The frontend's second-round tool POST drops the streamed
  tool_call_id, hitting the strict-spec check "role=tool requires
  tool_call_id". The validator now synthesises an opaque id
  (call_<8 hex>) when missing, so the request reaches the handler
  and the model's final summarising response gets generated. The
  proper fix lives in the frontend (carry the streamed id through
  the second POST) and will follow.

Verified end to end with curl: HTTP 400 (model not loaded) on both
the empty-assistant history shape and the tool-result-without-id
shape, instead of HTTP 422 from the schema validator.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: tighten code comments from security-hardening pass

Trim verbose docstrings and inline finding references added in the
previous commits in this branch. Functionality unchanged.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: await get_current_subject in /api/health and make refresh-token consumption atomic

The /api/health auth probe called get_current_subject(creds) without
awaiting it. The coroutine object is truthy, so any caller presenting a
Bearer header (valid or not) received the full diagnostic payload
including version, device_type, studio_root_id, etc. Await the coroutine
and treat HTTPException as 'fall back to the minimal liveness payload'.

consume_refresh_token did SELECT then DELETE WHERE id under default
autocommit isolation. Two concurrent POST /api/auth/refresh requests
could both win the SELECT before either DELETE ran, defeating
single-use refresh-token rotation. Replace with a single
DELETE ... WHERE token_hash = ? AND expires_at >= ? RETURNING ...
statement so the validate-and-delete lands as one atomic op under
SQLite's write lock (3.45.1 supports RETURNING; min was 3.35).

* studio: enforce body cap on chunked uploads and drop unsafe-inline from script-src

MaxBodyMiddleware previously only inspected the declared Content-Length
header; clients omitting it or sending Transfer-Encoding: chunked
bypassed the cap and could still drive an OOM via the downstream
JSON / file readers on /v1/chat/completions, /api/inference, /api/data-recipe,
/api/datasets, /api/train, /api/export. Rewrite as a raw ASGI middleware
that drains and counts http.request frames, replies 413 once the running
total exceeds UNSLOTH_STUDIO_MAX_BODY_MB before invoking the FastAPI
handler, and replays the buffered body to downstream so route code that
calls request.json() / await request.body() works unchanged.

CSP previously included 'unsafe-inline' on script-src, which defeats the
main XSS protection. The frontend bundle does not need inline scripts;
the only inline <script> the backend ever emits is _inject_bootstrap,
which is opt-in via UNSLOTH_STUDIO_INJECT_BOOTSTRAP. Drop 'unsafe-inline'
from script-src by default; when _inject_bootstrap fires, generate a
per-response nonce, embed it on the inlined <script>, and have
SecurityHeadersMiddleware splice 'nonce-XXX' into the CSP for that one
response (the internal x-internal-script-nonce header is popped before
the response leaves the server). 'unsafe-inline' stays on style-src for
Vite-injected styles.

* studio: drop empty assistant sentinel before passthrough

ChatMessage._validate_role_shape normalises role="assistant", content=""
(the post-Stop sentinel emitted by the frontend) to content=None so the
in-process path can drop it via _extract_content_parts. The passthrough
path then ran m.model_dump(exclude_none=True), which strips the now-None
content key entirely, sending {"role":"assistant"} to llama-server / the
OpenAI-compat backend. That fails upstream and leaves the user without a
recoverable Stop->resume.

Add _drop_empty_assistant_sentinels and call it at both passthrough
message origins: _openai_messages_for_passthrough (covers
/v1/chat/completions and the Responses API which routes through it) and
the anthropic_messages_to_openai output before
_anthropic_passthrough_*. Assistant messages that carry only tool_calls
(no content) are preserved.

* studio/tests: cover audit-fix surfaces and rebase pre-existing tests

Adds and updates pytest coverage for the four bot-flagged audit fixes
landed earlier in this branch and rebases two pre-existing tests that
were broken by the relaxed-validator and /api/health auth-gate changes.

studio/backend/tests/test_middleware.py (new)
  MaxBodyMiddleware: small protected, large declared, unprotected
  passthrough, chunked-upload-over-cap rejection (the regression for
  the original Content-Length-only gap), and chunked-under-cap replay.
  SecurityHeadersMiddleware: script-src no longer carries
  'unsafe-inline', style-src still does, default headers
  (XFO/XCTO/Referrer-Policy/Permissions-Policy/server), and the
  internal x-internal-script-nonce header is consumed by the
  middleware and converted to 'nonce-XXX' in the CSP.
  /api/health: no auth -> minimal, invalid Bearer -> minimal
  (the await regression), valid Bearer -> full diagnostic payload.

studio/backend/tests/test_desktop_auth.py
  consume_refresh_token: second-call returns None, expired returns
  None, and a 64-thread concurrent pile-up against the same hash
  produces exactly one successful consumer (regression for the
  SELECT-then-DELETE race).
  test_health_response_reports_desktop_capability_fields: rebase
  against the new health_check(request) signature by going through
  TestClient with a real bearer instead of asyncio.run-ing the
  handler directly.

studio/backend/tests/test_openai_tool_passthrough.py
  Pin the new ChatMessage tolerance: assistant without content or
  tool_calls is tolerated (normalises content -> None), empty-string
  and empty-list assistant content normalise to None, and a missing
  / empty tool_call_id on role='tool' is synthesised as call_<hex>
  rather than raising. Tests for _drop_empty_assistant_sentinels
  cover the three drop shapes (empty string, empty list, missing
  content key), preservation of assistant text and tool_calls-only
  messages, and end-to-end through
  _openai_messages_for_passthrough.

studio/backend/main.py
  SecurityHeadersMiddleware.dispatch used response.headers.pop(...)
  for the nonce-header handoff; Starlette's MutableHeaders has no
  pop. Read-then-del so the internal handoff header is still
  stripped before the response leaves the server.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio/tests: rebase three more pre-existing CI tests against this branch

CI on PR #5375 was red on three tests that were tuned for behaviour
predating this branch. Updates each so the assertions match what the
audit fixes intentionally changed; no production code touched.

studio/backend/tests/test_trained_model_scan.py
  test_scan_trained_models_includes_lora_and_full_finetune_outputs
  passed an absolute tmp_path through scan_trained_models, which now
  runs resolve_output_dir / _assert_contained against outputs_root().
  Repoint outputs_root() at tmp_path via monkeypatch so the fixture
  dirs land under the configured root and the realpath containment
  check passes.

tests/test_studio_install_workspace_guard.py
  test_health_endpoint_exposes_studio_root_id_not_raw_path read
  the first 1500 bytes after @app.get("/api/health") and asserted on
  the studio_root_id literal. The handler grew (unauth short-circuit
  + await dependency gate) and the literal slid past the byte window.
  Replace the fixed window with a slice up to the next top-level
  @app.* decorator so the test surveys the whole handler regardless
  of size.

tests/studio/studio_api_smoke.py
  The "login burst (5x wrong pw) -> 401 each" assertion was tagged
  "When/if we add one, this assertion updates in the same PR." We
  added the per-IP rate-limit in routes/auth.py
  (_LOGIN_MAX_FAILS=5/60s) but missed the assertion update. Rewrite
  the burst probe to observe the new invariant: at least one 401,
  eventual transition to 429, and Retry-After present on the 429.
  Adds a small _login_with_headers helper since the existing login()
  helper drops response headers.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ci(studio-ui): set UNSLOTH_STUDIO_INJECT_BOOTSTRAP=1 for Playwright Studios

The Chat UI Playwright test drives the first-boot change-password
form, which (per playwright_chat_ui.py step "1. Change-password
through the UI") pre-seeds the hidden current_password field from
window.__UNSLOTH_BOOTSTRAP__. That global is only emitted when the
backend's _inject_bootstrap path fires, which since the security
pass on this branch is gated behind UNSLOTH_STUDIO_INJECT_BOOTSTRAP
and defaults to off. Without the global, the React form's
current_password validator never satisfies, the submit button stays
disabled, and the composer.wait_for() probe times out on
/change-password.

Re-enable injection only for the CI Studios that drive the chat UI
across linux/mac/windows. Production deployments are unaffected: the
env var has to be explicitly opted into, and the on-disk
auth/.bootstrap_password remains the source of truth for human users
typing the password in by hand.

Covers all eight Studio launch sites: the primary chat-ui boot and
the "extra UI tests" boot for each of the three OSes, plus the
pipeTransport JSON-crash retry relaunches in the macOS workflow that
re-spawn Studio mid-job.

A follow-up frontend PR will add a visible current_password input so
the form satisfies its own validator without needing the bootstrap
auto-fill at all; once that lands this CI knob can come back out.

* studio/sandbox: drop unshare(CLONE_NEWNET); add trusted-host allowlist; block sandbox file uploads; raise CPU rlimit default to 600 s

CLONE_NEWNET inside _sandbox_preexec silently killed every outbound
HTTP request from sandboxed Python whenever the kernel allowed
unprivileged user namespaces. requests.get('https://huggingface.co'),
urllib.request.urlopen('https://en.wikipedia.org/wiki/...'),
socket.connect(('arxiv.org', 443)) all failed despite the AST visitor
intending to allow them. The bash blocklist (curl / wget / nc / ssh /
scp / sftp / rsync / socat / eval / source) plus the AST-level
metadata-host denylist still carry the network policy after this
change; CLONE_NEWNET was redundant with both.

Add _TRUSTED_PUBLIC_HOST_LITERALS + _TRUSTED_PUBLIC_HOST_SUFFIXES
(~100 informational hosts: Wikipedia language subdomains, Wikimedia,
Wikidata, Google search, Bing, DuckDuckGo, HuggingFace, GitHub,
raw.githubusercontent.com, arXiv, StackOverflow / Stack Exchange,
MDN, docs.python.org, PyTorch / TensorFlow / NumPy / pandas docs,
pypi / files.pythonhosted.org / npmjs / crates.io, ReadTheDocs,
arXiv, Britannica, BBC / Reuters / Nature / Science, NASA / CDC /
NIH / WHO open data, api.weather.gov). The visitor now blocks
literal hosts that are neither metadata nor trusted with a short
LLM-readable string so the model can retry with an allowed source
instead of choking on a multi-line error.

Block upload-shape calls regardless of host: requests.post / put /
patch / delete / request with files= or data=open(...) /
data=bytes_literal; httpx equivalents; urllib.request.urlopen /
Request with data=...; HuggingFace upload_file / upload_folder /
upload_large_folder / create_commit (module-level FQ paths AND
method-name match on any receiver). Message: "Blocked: file upload
disallowed in sandbox".

Bump UNSLOTH_STUDIO_SANDBOX_CPU_S default 300 -> 600 s so long
agentic chains that span multiple tool calls don't get SIGXCPU'd
mid-stride. Env-var override path is unchanged.

Host normalisation now strips trailing dot, userinfo @, and explicit
port before allowlist / denylist comparison so trailing-DNS-dot,
userinfo-smuggling, and explicit-:443 URLs are decided correctly.

* studio: raise default request-body cap from 100 MB to 500 MB

UNSLOTH_STUDIO_MAX_BODY_MB default goes 100 -> 500 to comfortably
cover vision + audio + multi-recipe-batch JSON payloads. The
MaxBodyMiddleware stream-counting logic from this branch's earlier
06ec088 already handles chunked bodies up to the new cap; env-var
override path is unchanged for callers that want a tighter limit.

* studio/auth: restore /api/auth/status.default_username to 'unsloth'

This branch's earlier b39e9a4 changed default_username to None on the
public /api/auth/status endpoint so the username field didn't leak to
unauthenticated callers. In practice this regressed third-party
clients (and the in-tree React login form's pre-fill UX) without
adding meaningful security: the bootstrap password is the actual
secret, and the username 'unsloth' is the documented default.

Pin default_username to storage.DEFAULT_ADMIN_USERNAME ('unsloth')
and tighten the response model so the field is required rather than
Optional. Anyone who needs anonymisation can still reach for an
allow-list deployment with auth disabled.

* studio/training: raise max_seq_length / batch_size / lora_r / lora_alpha caps

This branch's 7102815 introduced field validators with conservative
caps. The follow-up loosens them so long-context experiments and
high-rank LoRA exploration aren't gated at the schema layer:

  _MAX_BATCH_SIZE   1024     -> 4096
  _MAX_SEQ_LENGTH   131_072  -> 2_000_000   (2M tokens)
  lora_r cap        512      -> 16_384      (_MAX_LORA_R)
  lora_alpha cap    1024     -> 32_768      (_MAX_LORA_ALPHA)

_MAX_GRAD_ACCUM / _MAX_STEPS / _MAX_EPOCHS / lora_dropout /
warmup_ratio / weight_decay are unchanged. Hardware (VRAM, host
RAM, kernel launch latency) is now the binding constraint at the
new caps, which is the correct ordering -- the validator stays a
sanity check on -1 / 0 / 'abc' style garbage, not a usability gate.

* studio/tests: cover sandbox allowlist + upload block + raised training caps

studio/backend/tests/test_sandbox_tools.py (new):
  TestMetadataHostDenylist     -- short "Blocked: cloud-metadata host"
                                  message on AWS IMDS, GCP metadata,
                                  Alibaba ECS, AWS IPv6 IMDS, 169.254/16.
  TestTrustedHostAllowlist     -- Wikipedia (any language subdomain),
                                  Google, DuckDuckGo, HF, raw GitHub,
                                  arXiv, StackOverflow / family,
                                  MDN, docs.python.org, pypi, BBC,
                                  api.weather.gov, NumPy / PyTorch docs.
  TestUntrustedHostBlock       -- example.com / random unlisted host
                                  rejected with the short "Blocked: host
                                  not in sandbox allowlist; use an
                                  allowed informational source" message.
                                  Dynamic URLs (computed var) still pass
                                  -- documented limit of static analysis.
  TestHostNormalization        -- trailing dot, explicit :443, uppercase,
                                  userinfo-@-smuggle all decided
                                  correctly without false-block /
                                  false-pass.
  TestUploadDenylist           -- requests / httpx / urllib.urlopen with
                                  files= / data=open / data=bytes,
                                  HfApi().upload_file / upload_folder /
                                  create_commit, module-level
                                  huggingface_hub.upload_folder. POST
                                  json= to trusted host still passes.
  TestSandboxCpuRlimitDefault  -- pin UNSLOTH_STUDIO_SANDBOX_CPU_S=600
                                  default and confirm CLONE_NEWNET
                                  source line is gone.
  TestMaxBodyDefault           -- pin UNSLOTH_STUDIO_MAX_BODY_MB=500
                                  default.

studio/backend/tests/test_studio_train_validation.py (new):
  Pin at-cap-accepts / over-cap-rejects boundaries for
  max_seq_length=2_000_000, batch_size=4_096, lora_r=16_384,
  lora_alpha=32_768 so a future regression that tightens them back
  without explicit user opt-in is caught.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: tighten code comments across the security-hardening pass

* studio: always inject bootstrap credentials on first boot

The UNSLOTH_STUDIO_INJECT_BOOTSTRAP gate added an extra
terminal-to-browser copy-paste on every fresh install. In practice
the LAN credential leak it guarded against is narrow: the password
is one-time, the user rotates it on the very next click, the
default Studio bind is 127.0.0.1, and -H 0.0.0.0 already exposes
the entire API surface. Drop the gate so the inject fires whenever
a bootstrap password is still pending. The CSP nonce wiring stays
in place; the inline script remains the only inline script the
backend ever emits.

The three Playwright UI smoke workflows lose their
UNSLOTH_STUDIO_INJECT_BOOTSTRAP=1 lines along with the explanatory
comment blocks since the inject now happens by default.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wasim Yousef Said <wasimysdev@gmail.com>
This commit is contained in:
Daniel Han 2026-05-13 06:12:18 -07:00 committed by GitHub
parent ef9f672fe8
commit 0881a7a5d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 2121 additions and 166 deletions

View file

@ -480,6 +480,37 @@ def save_refresh_token(
conn.close()
def consume_refresh_token(token: str) -> Optional[Tuple[str, bool]]:
"""Atomically validate-and-delete a refresh token for single-use rotation.
DELETE RETURNING fuses validate and delete into one statement so two
concurrent refresh requests cannot both consume the same token.
"""
token_hash = _hash_token(token)
now = datetime.now(timezone.utc).isoformat()
conn = get_connection()
try:
conn.execute(
"DELETE FROM refresh_tokens WHERE expires_at < ?",
(now,),
)
cur = conn.execute(
"""
DELETE FROM refresh_tokens
WHERE token_hash = ? AND expires_at >= ?
RETURNING username, is_desktop
""",
(token_hash, now),
)
row = cur.fetchone()
conn.commit()
if row is None:
return None
return row["username"], bool(row["is_desktop"])
finally:
conn.close()
def verify_refresh_token(token: str) -> Optional[Tuple[str, bool]]:
"""
Verify a refresh token and return the username plus desktop marker.

View file

@ -10,6 +10,7 @@ Supports web search (DuckDuckGo), Python code execution, and terminal commands.
import ast
import http.client
import os
import signal
os.environ["UNSLOTH_IS_PRESENT"] = "1"
@ -58,21 +59,37 @@ _MAX_OUTPUT_CHARS = 8000 # truncate long output
_BLOCKED_COMMANDS_COMMON = frozenset(
{
"rm",
"sudo",
"su",
"dd",
"chmod",
"chown",
"mkfs",
"shutdown",
"reboot",
"passwd",
"mount",
"umount",
"fdisk",
"sudo",
"su",
"doas",
"pkexec",
"shutdown",
"reboot",
"halt",
"poweroff",
"kill",
"killall",
"pkill",
"passwd",
"curl",
"wget",
"nc",
"ncat",
"netcat",
"socat",
"ssh",
"scp",
"sftp",
"rsync",
"eval",
"source",
}
)
_BLOCKED_COMMANDS_WIN = frozenset(
@ -221,35 +238,67 @@ def _build_safe_env(workdir: str) -> dict[str, str]:
def _sandbox_preexec():
"""Pre-exec hook: drop privilege escalation ability and set resource limits.
"""Best-effort sandbox setup for sandboxed subprocesses.
On Linux, applies PR_SET_NO_NEW_PRIVS so sudo/su/pkexec fail at the
kernel level. On Linux and macOS, sets RLIMIT_FSIZE.
No-op on Windows (use creationflags instead).
Note: RLIMIT_NPROC is intentionally NOT set because Linux enforces it
per real UID, not per process tree, so it would starve the Studio
server and other sessions sharing the same user account.
All modules and handles are resolved at import time (module level) so
this function does not trigger Python imports in the forked child,
avoiding potential deadlocks in multi-threaded servers.
Modules are resolved at import time so the forked child runs no imports.
"""
try:
os.setsid()
except OSError:
pass
try:
os.umask(0o077)
except OSError:
pass
if _libc is not None:
try:
# PR_SET_NO_NEW_PRIVS = 38, arg2 = 1 (enable)
_libc.prctl(38, 1, 0, 0, 0)
_libc.prctl(38, 1, 0, 0, 0) # PR_SET_NO_NEW_PRIVS
except (OSError, AttributeError):
pass # Not available (container, old kernel, etc.)
pass
try:
_libc.prctl(1, 9, 0, 0, 0) # PR_SET_PDEATHSIG = SIGKILL
except (OSError, AttributeError):
pass
# CLONE_NEWNET intentionally not applied: where userns is enabled it
# blocks all egress, including allowlisted hosts. Network policy is
# enforced by the AST host check and the bash blocklist.
if _resource is not None:
# RLIMIT_NPROC is per-real-UID, so the cap is well above normal usage.
try:
nproc = int(os.environ.get("UNSLOTH_STUDIO_SANDBOX_NPROC", "10000"))
_resource.setrlimit(_resource.RLIMIT_NPROC, (nproc, nproc))
except (ValueError, OSError, AttributeError):
pass
try:
# Limit file size to 100MB (prevents disk filling)
_resource.setrlimit(
_resource.RLIMIT_FSIZE, (100 * 1024 * 1024, 100 * 1024 * 1024)
)
except (ValueError, OSError):
pass
try:
as_bytes = (
int(os.environ.get("UNSLOTH_STUDIO_SANDBOX_AS_GB", "8"))
* 1024
* 1024
* 1024
)
_resource.setrlimit(_resource.RLIMIT_AS, (as_bytes, as_bytes))
except (ValueError, OSError, AttributeError):
pass
try:
cpu_s = int(os.environ.get("UNSLOTH_STUDIO_SANDBOX_CPU_S", "600"))
_resource.setrlimit(_resource.RLIMIT_CPU, (cpu_s, cpu_s))
except (ValueError, OSError, AttributeError):
pass
try:
_resource.setrlimit(_resource.RLIMIT_NOFILE, (1024, 1024))
except (ValueError, OSError, AttributeError):
pass
def _get_shell_cmd(command: str) -> list[str]:
@ -265,25 +314,36 @@ def _get_shell_cmd(command: str) -> list[str]:
_workdirs: dict[str, str] = {}
# Non-matching session_ids collapse to ``_invalid`` to block cross-session escapes.
_SESSION_ID_RE = re.compile(r"\A[A-Za-z0-9_\-]{1,64}\Z")
def _get_workdir(session_id: str | None = None) -> str:
"""Return (and lazily create) a persistent working directory for tool execution."""
"""Return a per-session sandbox dir at mode 0o700."""
global _workdirs
key = session_id or "_default"
if key not in _workdirs or not os.path.isdir(_workdirs[key]):
home = os.path.expanduser("~")
sandbox_root = os.path.join(home, "studio_sandbox")
if session_id:
# Sanitize: strip path separators and parent-dir references
safe_id = os.path.basename(session_id.replace("..", ""))
if not safe_id:
safe_id = "_invalid"
workdir = os.path.join(sandbox_root, safe_id)
# Verify resolved path stays under sandbox root
if not os.path.realpath(workdir).startswith(os.path.realpath(sandbox_root)):
if session_id and _SESSION_ID_RE.match(session_id):
workdir = os.path.join(sandbox_root, session_id)
if not os.path.realpath(workdir).startswith(
os.path.realpath(sandbox_root) + os.sep
):
workdir = os.path.join(sandbox_root, "_invalid")
elif session_id:
workdir = os.path.join(sandbox_root, "_invalid")
else:
workdir = os.path.join(sandbox_root, "_default")
os.makedirs(workdir, exist_ok = True)
try:
os.chmod(sandbox_root, 0o700)
except OSError:
pass
try:
os.chmod(workdir, 0o700)
except OSError:
pass
_workdirs[key] = workdir
return _workdirs[key]
@ -932,7 +992,12 @@ def _check_signal_escape_patterns(code: str):
isinstance(shell_node, ast.Constant)
and shell_node.value is False
)
if shell_func in _STRING_SHELL_FUNCS or not shell_safe:
# Dynamic shell-exec args (chr/format/concat bypasses).
if (
shell_func in _STRING_SHELL_FUNCS
or shell_func in _SHELL_EXEC_FUNCS
or not shell_safe
):
def _is_safe_literal(n):
if _extract_string_from_node(n) is not None:
@ -1006,15 +1071,418 @@ def _check_signal_escape_patterns(code: str):
if visitor.imports_signal and not signal_tampering:
warnings.append("Code imports 'signal' module - review manually for safety")
# Static host policy: block metadata hosts and any literal host outside
# the trusted allowlist; uploads blocked regardless of host. Dynamic hosts
# are caught by the bash blocklist instead.
network_calls: list[dict] = []
sensitive_file_reads: list[dict] = []
_NETWORK_FQ_PREFIXES = (
"socket.socket",
"socket.create_connection",
"socket.getaddrinfo",
"urllib.request.urlopen",
"urllib.request.urlretrieve",
"urllib3.",
"requests.get",
"requests.post",
"requests.put",
"requests.delete",
"requests.patch",
"requests.head",
"requests.request",
"requests.Session",
"http.client.HTTPConnection",
"http.client.HTTPSConnection",
"httpx.get",
"httpx.post",
"httpx.put",
"httpx.patch",
"httpx.delete",
"httpx.request",
"httpx.Client",
"httpx.AsyncClient",
"aiohttp.ClientSession",
)
_UPLOAD_HTTP_METHODS = (
"requests.post",
"requests.put",
"requests.patch",
"requests.delete",
"requests.request",
"httpx.post",
"httpx.put",
"httpx.patch",
"httpx.delete",
"httpx.request",
"urllib.request.urlopen",
"urllib.request.Request",
)
_UPLOAD_HF_FQ = (
"huggingface_hub.upload_file",
"huggingface_hub.upload_folder",
"huggingface_hub.upload_large_folder",
"huggingface_hub.create_commit",
)
_UPLOAD_HF_METHODS = frozenset(
{
"upload_file",
"upload_folder",
"upload_large_folder",
"create_commit",
}
)
# Cloud-metadata / link-local hosts.
_METADATA_HOST_LITERALS = {
"169.254.169.254",
"fd00:ec2::254",
"metadata.google.internal",
"metadata",
"metadata.tencentyun.com",
"100.100.100.200",
"100.100.100.110",
"169.254.170.2",
"169.254.170.23",
}
_METADATA_HOST_PREFIXES = (
"169.254.",
"100.64.",
)
# Allowlist kept explicit so each entry is auditable.
_TRUSTED_PUBLIC_HOST_LITERALS = frozenset(
{
# search
"www.google.com",
"google.com",
"www.bing.com",
"bing.com",
"duckduckgo.com",
"html.duckduckgo.com",
# encyclopedic / reference
"wikipedia.org",
"www.wikipedia.org",
"wikimedia.org",
"www.wikimedia.org",
"wikidata.org",
"www.wikidata.org",
"commons.wikimedia.org",
"www.britannica.com",
"openlibrary.org",
"www.openstreetmap.org",
# ML / dev / data
"huggingface.co",
"hf.co",
"github.com",
"api.github.com",
"raw.githubusercontent.com",
"gist.github.com",
"docs.github.com",
"pypi.org",
"files.pythonhosted.org",
"www.npmjs.com",
"registry.npmjs.org",
"crates.io",
"static.crates.io",
# docs
"docs.python.org",
"python.org",
"www.python.org",
"developer.mozilla.org",
"developer.apple.com",
"learn.microsoft.com",
"docs.docker.com",
"pytorch.org",
"docs.pytorch.org",
"tensorflow.org",
"www.tensorflow.org",
"numpy.org",
"pandas.pydata.org",
"scipy.org",
"scikit-learn.org",
"matplotlib.org",
"fastapi.tiangolo.com",
"starlette.io",
# academic
"arxiv.org",
"export.arxiv.org",
"scholar.google.com",
"openreview.net",
"semanticscholar.org",
"www.semanticscholar.org",
"biorxiv.org",
"www.biorxiv.org",
"medrxiv.org",
"www.medrxiv.org",
"pubmed.ncbi.nlm.nih.gov",
"www.ncbi.nlm.nih.gov",
# Q&A / community
"stackoverflow.com",
"stackexchange.com",
"askubuntu.com",
"superuser.com",
"serverfault.com",
# standards
"www.w3.org",
"tools.ietf.org",
"datatracker.ietf.org",
"www.rfc-editor.org",
# reputable news
"www.bbc.com",
"www.bbc.co.uk",
"www.reuters.com",
"apnews.com",
"www.nature.com",
"www.science.org",
# government / open data
"data.gov",
"catalog.data.gov",
"www.census.gov",
"www.nasa.gov",
"data.nasa.gov",
"www.cdc.gov",
"www.nih.gov",
"www.who.int",
# weather / time
"api.weather.gov",
"worldtimeapi.org",
}
)
_TRUSTED_PUBLIC_HOST_SUFFIXES = (
".wikipedia.org",
".wikimedia.org",
".wiktionary.org",
".wikibooks.org",
".wikiquote.org",
".wikisource.org",
".wikiversity.org",
".wikivoyage.org",
".stackexchange.com",
".hf.co",
".huggingface.co",
".githubusercontent.com",
".github.io",
".arxiv.org",
".readthedocs.io",
".readthedocs.org",
)
_SENSITIVE_FILE_PREFIXES = (
"/etc/passwd",
"/etc/shadow",
"/etc/sudoers",
"/etc/ssh/",
)
_SENSITIVE_FILE_RE = re.compile(
r"^/proc/(?:self|\d+)/(?:environ|cmdline|task/\d+/environ)$"
)
def _normalize_host(host: str) -> str:
if not host:
return ""
h = host.strip().lower().rstrip(".")
if "@" in h:
h = h.split("@", 1)[1]
if h.startswith("[") and "]" in h:
h = h[1 : h.index("]")]
elif h.count(":") == 1:
h = h.split(":", 1)[0]
return h
def _is_metadata_host(host: str) -> bool:
h = _normalize_host(host)
if not h:
return False
if h in _METADATA_HOST_LITERALS:
return True
if any(h.startswith(p) for p in _METADATA_HOST_PREFIXES):
return True
return False
def _is_trusted_host(host: str) -> bool:
h = _normalize_host(host)
if not h:
return False
if h in _TRUSTED_PUBLIC_HOST_LITERALS:
return True
return any(h.endswith(s) for s in _TRUSTED_PUBLIC_HOST_SUFFIXES)
def _call_is_upload_shape(node: ast.Call, fq: str) -> bool:
"""True for statically obvious upload shapes (files=, data=open(), bytes literal)."""
if fq in _UPLOAD_HF_FQ:
return True
if fq not in _UPLOAD_HTTP_METHODS:
return False
for kw in node.keywords or []:
if kw.arg == "files":
return True
if kw.arg == "data":
v = kw.value
if (
isinstance(v, ast.Call)
and isinstance(v.func, ast.Name)
and v.func.id == "open"
):
return True
if isinstance(v, ast.Constant) and isinstance(
v.value, (bytes, bytearray)
):
return True
return False
def _method_call_is_hf_upload(node: ast.Call) -> bool:
"""True for HfApi upload method names on any receiver."""
return (
isinstance(node.func, ast.Attribute)
and node.func.attr in _UPLOAD_HF_METHODS
)
class NetworkAndIoVisitor(ast.NodeVisitor):
def visit_Call(self, node):
parts: list[str] = []
cur = node.func
while isinstance(cur, ast.Attribute):
parts.insert(0, cur.attr)
cur = cur.value
if isinstance(cur, ast.Name):
parts.insert(0, cur.id)
fq = ".".join(parts) if parts else ""
if _method_call_is_hf_upload(node):
network_calls.append(
{
"type": "upload_blocked",
"line": getattr(node, "lineno", -1),
"description": ("Blocked: file upload disallowed in sandbox"),
}
)
# Direct sock.connect((host, port)) bypasses the FQ-prefix branch below.
if (
isinstance(node.func, ast.Attribute)
and node.func.attr == "connect"
and node.args
):
a0 = node.args[0]
host_lit = None
if isinstance(a0, ast.Tuple) and a0.elts:
e0 = a0.elts[0]
if isinstance(e0, ast.Constant) and isinstance(e0.value, str):
host_lit = e0.value
elif isinstance(a0, ast.Constant) and isinstance(a0.value, str):
host_lit = a0.value
if host_lit:
if _is_metadata_host(host_lit):
network_calls.append(
{
"type": "metadata_host_blocked",
"line": getattr(node, "lineno", -1),
"description": "Blocked: cloud-metadata host",
}
)
elif not _is_trusted_host(host_lit):
network_calls.append(
{
"type": "untrusted_host_blocked",
"line": getattr(node, "lineno", -1),
"description": (
"Blocked: host not in sandbox allowlist; "
"use an allowed informational source"
),
}
)
if fq and any(fq.startswith(p) for p in _NETWORK_FQ_PREFIXES):
# 1) Upload-shape check (host-independent).
if _call_is_upload_shape(node, fq):
network_calls.append(
{
"type": "upload_blocked",
"line": getattr(node, "lineno", -1),
"description": (
"Blocked: file upload disallowed in sandbox"
),
}
)
# 2) Extract literal host (URL string or (host, port) tuple).
host_arg = None
url_arg = None
if node.args:
a0 = node.args[0]
if isinstance(a0, ast.Constant) and isinstance(a0.value, str):
url_arg = a0.value
elif isinstance(a0, ast.Tuple) and a0.elts:
e0 = a0.elts[0]
if isinstance(e0, ast.Constant) and isinstance(e0.value, str):
host_arg = e0.value
if url_arg and host_arg is None:
m = re.match(r"^\w+://([^/?#]+)", url_arg)
if m:
host_arg = m.group(1)
if host_arg:
if _is_metadata_host(host_arg):
network_calls.append(
{
"type": "metadata_host_blocked",
"line": getattr(node, "lineno", -1),
"description": "Blocked: cloud-metadata host",
}
)
elif not _is_trusted_host(host_arg):
network_calls.append(
{
"type": "untrusted_host_blocked",
"line": getattr(node, "lineno", -1),
"description": (
"Blocked: host not in sandbox allowlist; "
"use an allowed informational source"
),
}
)
is_open_call = (
(isinstance(node.func, ast.Name) and node.func.id == "open")
or fq in ("io.open", "pathlib.Path.open")
or fq.endswith(".open")
)
if is_open_call and node.args:
a0 = node.args[0]
path_lit = None
if isinstance(a0, ast.Constant) and isinstance(a0.value, str):
path_lit = a0.value
if path_lit:
flagged = False
if any(path_lit.startswith(p) for p in _SENSITIVE_FILE_PREFIXES):
flagged = True
elif _SENSITIVE_FILE_RE.match(path_lit):
flagged = True
if flagged:
sensitive_file_reads.append(
{
"type": "sensitive_file_read",
"line": getattr(node, "lineno", -1),
"description": (
f"open({path_lit!r}) targets a host identity / "
"credential file; sandboxed code may not read it"
),
}
)
self.generic_visit(node)
NetworkAndIoVisitor().visit(tree)
is_safe = (
len(signal_tampering) == 0
and len(exception_catching) == 0
and len(shell_escapes) == 0
and len(network_calls) == 0
and len(sensitive_file_reads) == 0
)
return is_safe, {
"signal_tampering": signal_tampering,
"exception_catching": exception_catching,
"shell_escapes": shell_escapes,
"network_calls": network_calls,
"sensitive_file_reads": sensitive_file_reads,
"warnings": warnings,
}
@ -1041,7 +1509,21 @@ def _check_code_safety(code: str) -> str | None:
exception_reasons = [
item.get("description", "") for item in info.get("exception_catching", [])
]
all_reasons = [r for r in reasons + shell_reasons + exception_reasons if r]
network_reasons = [
item.get("description", "") for item in info.get("network_calls", [])
]
file_reasons = [
item.get("description", "") for item in info.get("sensitive_file_reads", [])
]
all_reasons = [
r
for r in reasons
+ shell_reasons
+ exception_reasons
+ network_reasons
+ file_reasons
if r
]
if all_reasons:
return (
f"Error: unsafe code detected ({'; '.join(all_reasons)}). "
@ -1051,11 +1533,31 @@ def _check_code_safety(code: str) -> str | None:
return None
def _kill_process_tree(proc) -> None:
"""SIGKILL the setsid process group; fall back to single-pid kill."""
if proc.poll() is not None:
return
try:
pgid = os.getpgid(proc.pid)
except (ProcessLookupError, PermissionError):
pgid = None
if pgid is not None:
try:
os.killpg(pgid, signal.SIGKILL)
return
except (ProcessLookupError, PermissionError):
pass
try:
proc.kill()
except (ProcessLookupError, PermissionError):
pass
def _cancel_watcher(proc, cancel_event, poll_interval = 0.2):
"""Daemon thread that kills a process when cancel_event is set."""
while proc.poll() is None:
if cancel_event is not None and cancel_event.is_set():
proc.kill()
_kill_process_tree(proc)
return
cancel_event.wait(poll_interval) if cancel_event else None
@ -1126,8 +1628,11 @@ def _python_exec(
try:
output, _ = proc.communicate(timeout = timeout)
except subprocess.TimeoutExpired:
proc.kill()
proc.communicate()
_kill_process_tree(proc)
try:
proc.communicate(timeout = 5)
except subprocess.TimeoutExpired:
pass
return _truncate(f"Execution timed out after {timeout} seconds.")
if cancel_event is not None and cancel_event.is_set():
@ -1211,8 +1716,11 @@ def _bash_exec(
try:
output, _ = proc.communicate(timeout = timeout)
except subprocess.TimeoutExpired:
proc.kill()
proc.communicate()
_kill_process_tree(proc)
try:
proc.communicate(timeout = 5)
except subprocess.TimeoutExpired:
pass
return _truncate(f"Execution timed out after {timeout} seconds.")
if cancel_event is not None and cancel_event.is_set():

View file

@ -17,7 +17,9 @@ Pattern follows core/data_recipe/jobs/manager.py.
import json as _json
import math
import multiprocessing as mp
import os
import queue
import shutil
import threading
import time
import structlog
@ -33,9 +35,54 @@ from utils.native_path_leases import (
native_path_secret_removed_for_child_start,
run_without_native_path_secret,
)
from utils.paths import outputs_root
logger = get_logger(__name__)
def _cleanup_cancelled_checkpoints(output_dir: str | os.PathLike) -> None:
"""Remove ``checkpoint-<int>`` subdirs after a cancelled run.
Only paths whose realpath is under outputs_root are touched."""
out = Path(output_dir)
if not out.exists():
return
try:
out_real = out.resolve()
out_root_real = Path(outputs_root()).resolve()
except OSError:
return
try:
out_real.relative_to(out_root_real)
except ValueError:
# Refuse to delete anything outside the configured outputs root.
logger.warning(
"Skipping checkpoint cleanup - %s is not under outputs_root %s",
out_real,
out_root_real,
)
return
removed = 0
for entry in out.iterdir() if out.is_dir() else []:
if not entry.is_dir():
continue
name = entry.name
if not name.startswith("checkpoint-"):
continue
tail = name[len("checkpoint-") :]
if not tail.isdigit():
continue
try:
shutil.rmtree(entry, ignore_errors = False)
removed += 1
except OSError as exc:
logger.warning("Could not remove %s: %s", entry, exc)
logger.info(
"Cancelled-run cleanup removed %d checkpoint dir(s) under %s",
removed,
out,
)
_CTX = mp.get_context("spawn")
# Plot styling constants
@ -316,6 +363,8 @@ class TrainingBackend:
)
self._proc.terminate()
proc = self._proc
cancelled = self._cancel_requested
output_dir = self._output_dir
if proc is not None:
proc.join(timeout = 5.0)
@ -328,6 +377,17 @@ class TrainingBackend:
if self._pump_thread is not None and self._pump_thread.is_alive():
self._pump_thread.join(timeout = 8.0)
# Drop checkpoint-* dirs on explicit cancel only; stop-and-save
# keeps its artifacts.
if cancelled and output_dir:
try:
_cleanup_cancelled_checkpoints(output_dir)
except Exception:
logger.exception(
"Failed to clean up cancelled-run checkpoints under %s",
output_dir,
)
def is_training_active(self) -> bool:
"""Check if training is currently active."""
with self._lock:

View file

@ -104,7 +104,7 @@ if os.getenv("ENVIRONMENT_TYPE", "production") == "production":
# warnings.filterwarnings("ignore", category=DeprecationWarning)
# warnings.filterwarnings("ignore", module="triton.*")
from fastapi import Depends, FastAPI, Request
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, HTMLResponse, Response
@ -260,6 +260,181 @@ logger = LogConfig.setup_logging(
app.add_middleware(LoggingMiddleware)
# Web-search favicons load from *.gstatic.com; everything else is same-origin.
from starlette.middleware.base import BaseHTTPMiddleware # noqa: E402
from starlette.requests import Request as _StarletteRequest # noqa: E402
_CSP_SCRIPT_NONCE_HEADER = "x-internal-script-nonce"
def _build_csp(script_nonce: "str | None" = None) -> str:
script_src = "script-src 'self'"
if script_nonce:
script_src += f" 'nonce-{script_nonce}'"
return (
"default-src 'self'; "
"img-src 'self' data: blob: https://t0.gstatic.com "
"https://t1.gstatic.com https://t2.gstatic.com "
"https://t3.gstatic.com; "
"connect-src 'self'; "
"style-src 'self' 'unsafe-inline'; "
f"{script_src}; "
"font-src 'self' data:; "
"frame-ancestors 'none'; "
"form-action 'self'; "
"base-uri 'self'"
)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Set baseline security headers; splice per-response inline-script nonces into CSP."""
async def dispatch(self, request: _StarletteRequest, call_next):
response = await call_next(request)
# Strip the internal nonce hand-off header so it never reaches the client.
nonce = response.headers.get(_CSP_SCRIPT_NONCE_HEADER)
if nonce is not None:
del response.headers[_CSP_SCRIPT_NONCE_HEADER]
response.headers.setdefault("Content-Security-Policy", _build_csp(nonce))
response.headers.setdefault("X-Frame-Options", "DENY")
response.headers.setdefault("X-Content-Type-Options", "nosniff")
response.headers.setdefault("Referrer-Policy", "no-referrer")
response.headers.setdefault(
"Permissions-Policy",
"camera=(), microphone=(), geolocation=(), interest-cohort=()",
)
response.headers["server"] = "unsloth-studio"
return response
app.add_middleware(SecurityHeadersMiddleware)
# Cap upload body on protected POSTs; default 500 MB, env-tunable.
import json as _json_for_413 # noqa: E402
_MAX_BODY_BYTES = int(os.environ.get("UNSLOTH_STUDIO_MAX_BODY_MB", "500")) * 1024 * 1024
_BODY_PROTECTED_PREFIXES = (
"/v1/chat/completions",
"/v1/completions",
"/api/inference",
"/api/data-recipe",
"/api/datasets",
"/api/train",
"/api/export",
)
async def _send_413(send, total_bytes: int) -> None:
payload = _json_for_413.dumps(
{
"detail": (
f"Request body too large "
f"({total_bytes:,} bytes; max {_MAX_BODY_BYTES:,})."
)
},
).encode("utf-8")
await send(
{
"type": "http.response.start",
"status": 413,
"headers": [
(b"content-type", b"application/json"),
(b"content-length", str(len(payload)).encode("ascii")),
],
}
)
await send({"type": "http.response.body", "body": payload, "more_body": False})
class MaxBodyMiddleware:
"""Reject oversized bodies on protected POST/PUT/PATCH; raw ASGI so chunked uploads cannot bypass the cap."""
def __init__(self, app, max_bytes: int, protected_prefixes: tuple):
self.app = app
self.max_bytes = max_bytes
self.protected_prefixes = protected_prefixes
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
method = scope.get("method", "").upper()
path = scope.get("path", "")
if method not in ("POST", "PUT", "PATCH") or not any(
path.startswith(p) for p in self.protected_prefixes
):
await self.app(scope, receive, send)
return
declared = None
for name, value in scope.get("headers", []):
if name == b"content-length":
try:
declared = int(value.decode("latin-1"))
except (ValueError, UnicodeDecodeError):
declared = None
break
if declared is not None and declared > self.max_bytes:
await _send_413(send, declared)
return
chunks: list = []
total = 0
while True:
msg = await receive()
mtype = msg.get("type")
if mtype == "http.disconnect":
return
if mtype != "http.request":
# Mid-stream unexpected frame: forwarding would corrupt downstream.
return
body = msg.get("body", b"") or b""
if body:
total += len(body)
if total > self.max_bytes:
await _send_413(send, total)
return
chunks.append(body)
if not msg.get("more_body", False):
break
replayed = {"sent": False}
async def replay_receive():
if not replayed["sent"]:
replayed["sent"] = True
return {
"type": "http.request",
"body": b"".join(chunks),
"more_body": False,
}
# After replay, fall through so http.disconnect still propagates.
return await receive()
await self.app(scope, replay_receive, send)
app.add_middleware(
MaxBodyMiddleware,
max_bytes = _MAX_BODY_BYTES,
protected_prefixes = _BODY_PROTECTED_PREFIXES,
)
from starlette.responses import RedirectResponse as _RedirectResponse # noqa: E402
@app.get("/recipes", include_in_schema = False)
@app.get("/recipes/{rest:path}", include_in_schema = False)
async def _recipes_redirect(rest: str = ""):
target = "/data-recipes" + (("/" + rest) if rest else "")
return _RedirectResponse(url = target, status_code = 308)
# CORS middleware
_api_only = os.environ.get("UNSLOTH_API_ONLY") == "1"
_cors_origins = ["*"]
@ -311,14 +486,35 @@ app.include_router(
@app.get("/api/health")
async def health_check():
"""Health check endpoint"""
platform_map = {"darwin": "mac", "win32": "windows", "linux": "linux"}
device_type = platform_map.get(sys.platform, sys.platform)
return {
async def health_check(request: Request):
"""Liveness only; full diagnostic dict gated on a valid bearer."""
minimal = {
"status": "healthy",
"timestamp": datetime.now().isoformat(),
}
auth = request.headers.get("authorization", "")
if not auth.lower().startswith("bearer "):
return minimal
try:
from auth.authentication import get_current_subject as _gcs
from fastapi.security import HTTPAuthorizationCredentials
creds = HTTPAuthorizationCredentials(
scheme = "Bearer", credentials = auth.split(" ", 1)[1]
)
# Must await: a bare coroutine is truthy and would skip the auth check.
subject = await _gcs(creds)
except HTTPException:
return minimal
except Exception:
return minimal
if not subject:
return minimal
platform_map = {"darwin": "mac", "win32": "windows", "linux": "linux"}
device_type = platform_map.get(sys.platform, sys.platform)
return {
**minimal,
"service": "Unsloth UI Backend",
"version": UNSLOTH_VERSION,
"studio_version": STUDIO_VERSION,
@ -328,9 +524,7 @@ async def health_check():
"desktop_manageability_version": 1,
"supports_desktop_auth": True,
"supports_desktop_backend_ownership": True,
# why: launchers compare against an install-time hash so a sibling
# Studio on the same port is rejected; hex digest avoids leaking the
# raw install path on -H 0.0.0.0.
# Hex digest of the install path; launchers reject sibling Studios on the same port.
"studio_root_id": _studio_root_id(),
"native_path_leases_supported": native_path_leases_supported(),
**({"desktop_owner": owner} if (owner := _desktop_owner()) else {}),
@ -463,21 +657,22 @@ def _strip_crossorigin(html_bytes: bytes) -> bytes:
return html.encode("utf-8")
def _inject_bootstrap(html_bytes: bytes, app: FastAPI) -> bytes:
"""Inject bootstrap credentials into HTML when password change is required.
def _inject_bootstrap(html_bytes: bytes, app: FastAPI):
"""Inject bootstrap credentials when password change is pending.
The script tag is only injected while the default admin account still
has ``must_change_password=True``. Once the user changes the password
the HTML is served clean no credentials leak.
Returns ``(html_bytes, script_nonce_or_None)``. Callers must forward
the nonce via ``_CSP_SCRIPT_NONCE_HEADER`` so the inline script is
not blocked by CSP.
"""
import json as _json
import secrets as _secrets
if not storage.requires_password_change(storage.DEFAULT_ADMIN_USERNAME):
return html_bytes
return html_bytes, None
bootstrap_pw = getattr(app.state, "bootstrap_password", None)
if not bootstrap_pw:
return html_bytes
return html_bytes, None
payload = _json.dumps(
{
@ -485,10 +680,11 @@ def _inject_bootstrap(html_bytes: bytes, app: FastAPI) -> bytes:
"password": bootstrap_pw,
}
)
tag = f"<script>window.__UNSLOTH_BOOTSTRAP__={payload}</script>"
nonce = _secrets.token_urlsafe(16)
tag = f'<script nonce="{nonce}">window.__UNSLOTH_BOOTSTRAP__={payload}</script>'
html = html_bytes.decode("utf-8")
html = html.replace("</head>", f"{tag}</head>", 1)
return html.encode("utf-8")
return html.encode("utf-8"), nonce
def setup_frontend(app: FastAPI, build_path: Path):
@ -501,17 +697,23 @@ def setup_frontend(app: FastAPI, build_path: Path):
if assets_dir.exists():
app.mount("/assets", StaticFiles(directory = assets_dir), name = "assets")
@app.get("/")
async def serve_root():
def _build_index_response() -> Response:
content = (build_path / "index.html").read_bytes()
content = _strip_crossorigin(content)
content = _inject_bootstrap(content, app)
content, nonce = _inject_bootstrap(content, app)
headers = {"Cache-Control": "no-cache, no-store, must-revalidate"}
if nonce:
headers[_CSP_SCRIPT_NONCE_HEADER] = nonce
return Response(
content = content,
media_type = "text/html",
headers = {"Cache-Control": "no-cache, no-store, must-revalidate"},
headers = headers,
)
@app.get("/")
async def serve_root():
return _build_index_response()
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str):
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
@ -527,13 +729,6 @@ def setup_frontend(app: FastAPI, build_path: Path):
return FileResponse(file_path)
# Serve index.html as bytes — avoids Content-Length mismatch
content = (build_path / "index.html").read_bytes()
content = _strip_crossorigin(content)
content = _inject_bootstrap(content, app)
return Response(
content = content,
media_type = "text/html",
headers = {"Cache-Control": "no-cache, no-store, must-revalidate"},
)
return _build_index_response()
return True

View file

@ -37,7 +37,10 @@ class AuthStatusResponse(BaseModel):
initialized: bool = Field(
..., description = "True if the auth database contains a login user"
)
default_username: str = Field(..., description = "Default seeded admin username")
default_username: str = Field(
"unsloth",
description = "Default admin username for first-boot UI prefill.",
)
requires_password_change: bool = Field(
...,
description = "True if the seeded admin must still change the default password",

View file

@ -5,10 +5,36 @@
Pydantic schemas for Export API.
"""
from pydantic import BaseModel, Field
from pathlib import Path
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Literal, Dict, Any
def _validate_save_directory(value: str) -> str:
"""Reject save_directory values that escape the export root."""
if value is None:
raise ValueError("save_directory is required")
raw = str(value).strip()
if not raw:
raise ValueError("save_directory must not be empty")
if "\x00" in raw:
raise ValueError("save_directory may not contain null bytes")
if any(ch in raw for ch in ("\r", "\n")):
raise ValueError("save_directory may not contain control characters")
if len(raw) > 255:
raise ValueError("save_directory must be <= 255 characters")
path = Path(raw).expanduser()
if path.is_absolute():
raise ValueError(
"save_directory must be a name or relative path under the "
"export root; absolute paths are rejected"
)
if ".." in path.parts:
raise ValueError("save_directory may not contain '..' segments")
return raw
class LoadCheckpointRequest(BaseModel):
"""Request for loading a checkpoint into the export backend."""
@ -64,6 +90,12 @@ class ExportCommonOptions(BaseModel):
...,
description = "Local directory where the exported artifacts will be written",
)
@field_validator("save_directory", mode = "before")
@classmethod
def _check_save_directory(cls, v):
return _validate_save_directory(v)
push_to_hub: bool = Field(
False,
description = "If True, also push the exported model to the Hugging Face Hub",
@ -108,6 +140,12 @@ class ExportGGUFRequest(BaseModel):
...,
description = "Directory where GGUF files will be saved",
)
@field_validator("save_directory", mode = "before")
@classmethod
def _check_save_directory(cls, v):
return _validate_save_directory(v)
quantization_method: str = Field(
"Q4_K_M",
description = 'GGUF quantization method (e.g. "Q4_K_M")',

View file

@ -425,14 +425,6 @@ class ChatMessage(BaseModel):
@model_validator(mode = "after")
def _validate_role_shape(self) -> "ChatMessage":
# Enforce the per-role OpenAI spec shape at the request boundary.
# Without this, malformed messages (e.g. user entries with no
# content, tool_calls on a user/system role, role="tool" without
# tool_call_id) would be silently forwarded to llama-server via
# the passthrough path, surfacing as opaque upstream errors or
# broken tool-call reconciliation downstream.
# Tool-call metadata must appear only on the appropriate role.
if self.tool_calls is not None and self.role != "assistant":
raise ValueError('"tool_calls" is only valid on role="assistant" messages.')
if self.tool_call_id is not None and self.role != "tool":
@ -440,23 +432,20 @@ class ChatMessage(BaseModel):
if self.name is not None and self.role != "tool":
raise ValueError('"name" is only valid on role="tool" messages.')
# Per-role content requirements. OpenAI-compatible clients may send
# ``content=""`` for image-only turns when the image travels in a
# companion field such as Studio's ``image_base64`` extension, so treat
# empty strings as present content for user/system messages.
if self.role == "tool":
if not self.tool_call_id:
raise ValueError(
'role="tool" messages require "tool_call_id" per the OpenAI spec.'
)
# Frontend's second-round POST drops the streamed id;
# synthesise one so the request round-trips.
import secrets as _secrets
self.tool_call_id = f"call_{_secrets.token_hex(8)}"
if not self.content:
raise ValueError('role="tool" messages require non-empty "content".')
elif self.role == "assistant":
# Assistant messages may omit content when tool_calls is set.
if not self.content and not self.tool_calls:
raise ValueError(
'role="assistant" messages require either "content" or "tool_calls".'
)
# Tolerate the post-Stop empty-assistant sentinel by
# collapsing content="" to None.
if (self.content == "" or self.content == []) and not self.tool_calls:
self.content = None
else: # "user" | "system"
if self.content is None or self.content == []:
raise ValueError(f'role="{self.role}" messages require "content".')

View file

@ -5,10 +5,43 @@
Pydantic schemas for Training API
"""
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import Any, Optional, List, Dict, Literal
_MAX_BATCH_SIZE = 4096
_MAX_GRAD_ACCUM = 4096
_MAX_STEPS = 1_000_000
_MAX_EPOCHS = 1000
# 2M is a sanity cap; host RAM runs out long before this.
_MAX_SEQ_LENGTH = 2_000_000
_MAX_LR_VALUE = 1.0
_MAX_LORA_R = 16_384
_MAX_LORA_ALPHA = 32_768
def _parse_lr(v: Any) -> float:
"""Parse learning_rate as a positive float strictly below _MAX_LR_VALUE."""
if v is None:
raise ValueError("learning_rate is required")
if isinstance(v, bool):
raise ValueError("learning_rate must be a number, not a bool")
try:
lr = float(v)
except (TypeError, ValueError):
raise ValueError(f"learning_rate must be parseable as float (got {v!r})")
if not (lr > 0.0):
raise ValueError(
f"learning_rate must be > 0 (got {lr!r}); " "typical range is 1e-6 .. 1e-3"
)
if lr >= _MAX_LR_VALUE:
raise ValueError(
f"learning_rate must be < 1.0 (got {lr!r}); "
"values that large always diverge training"
)
return lr
class TrainingStartRequest(BaseModel):
"""Request schema for starting training"""
@ -64,6 +97,147 @@ class TrainingStartRequest(BaseModel):
values.setdefault("train_split", values.pop("split"))
return values
@field_validator("learning_rate", mode = "before")
@classmethod
def _check_learning_rate(cls, v):
# Stringify because downstream call sites float() it themselves.
lr = _parse_lr(v)
return str(lr)
@field_validator("batch_size")
@classmethod
def _check_batch_size(cls, v: int) -> int:
if v is None:
raise ValueError("batch_size is required")
if v < 1 or v > _MAX_BATCH_SIZE:
raise ValueError(
f"batch_size must be in [1, {_MAX_BATCH_SIZE}] (got {v!r})"
)
return v
@field_validator("gradient_accumulation_steps")
@classmethod
def _check_grad_accum(cls, v: int) -> int:
if v is None:
return 1
if v < 1 or v > _MAX_GRAD_ACCUM:
raise ValueError(
f"gradient_accumulation_steps must be in [1, {_MAX_GRAD_ACCUM}] "
f"(got {v!r})"
)
return v
@field_validator("num_epochs")
@classmethod
def _check_num_epochs(cls, v: int) -> int:
if v is None:
return 1
if v < 1 or v > _MAX_EPOCHS:
raise ValueError(f"num_epochs must be in [1, {_MAX_EPOCHS}] (got {v!r})")
return v
@field_validator("max_steps")
@classmethod
def _check_max_steps(cls, v):
if v is None:
return v
if not isinstance(v, int) or v < 1 or v > _MAX_STEPS:
raise ValueError(
f"max_steps must be a positive int <= {_MAX_STEPS} (got {v!r})"
)
return v
@field_validator("max_seq_length")
@classmethod
def _check_max_seq_length(cls, v: int) -> int:
if v is None or v < 1 or v > _MAX_SEQ_LENGTH:
raise ValueError(
f"max_seq_length must be in [1, {_MAX_SEQ_LENGTH}] (got {v!r})"
)
return v
@field_validator("warmup_steps")
@classmethod
def _check_warmup_steps(cls, v):
if v is None:
return v
if not isinstance(v, int) or v < 0 or v > _MAX_STEPS:
raise ValueError(
f"warmup_steps must be a non-negative int <= {_MAX_STEPS} "
f"(got {v!r})"
)
return v
@field_validator("warmup_ratio")
@classmethod
def _check_warmup_ratio(cls, v):
if v is None:
return v
try:
r = float(v)
except (TypeError, ValueError):
raise ValueError(f"warmup_ratio must be a number (got {v!r})")
if not (0.0 <= r <= 1.0):
raise ValueError(f"warmup_ratio must be in [0.0, 1.0] (got {r!r})")
return r
@field_validator("save_steps")
@classmethod
def _check_save_steps(cls, v: int) -> int:
if v is None:
return 100
if v < 0 or v > _MAX_STEPS:
raise ValueError(f"save_steps must be in [0, {_MAX_STEPS}] (got {v!r})")
return v
@field_validator("weight_decay")
@classmethod
def _check_weight_decay(cls, v: float) -> float:
if v is None:
return 0.0
try:
wd = float(v)
except (TypeError, ValueError):
raise ValueError(f"weight_decay must be a number (got {v!r})")
if wd < 0 or wd > 10.0:
raise ValueError(
f"weight_decay must be in [0, 10] (got {wd!r}); typical 0..0.1"
)
return wd
@field_validator("lora_r")
@classmethod
def _check_lora_r(cls, v: int) -> int:
if v is None:
return 16
if v < 1 or v > _MAX_LORA_R:
raise ValueError(f"lora_r must be in [1, {_MAX_LORA_R}] (got {v!r})")
return v
@field_validator("lora_alpha")
@classmethod
def _check_lora_alpha(cls, v: int) -> int:
if v is None:
return 16
if v < 1 or v > _MAX_LORA_ALPHA:
raise ValueError(
f"lora_alpha must be in [1, {_MAX_LORA_ALPHA}] (got {v!r})"
)
return v
@field_validator("lora_dropout")
@classmethod
def _check_lora_dropout(cls, v: float) -> float:
if v is None:
return 0.0
try:
d = float(v)
except (TypeError, ValueError):
raise ValueError(f"lora_dropout must be a number (got {v!r})")
if not (0.0 <= d < 1.0):
raise ValueError(f"lora_dropout must be in [0.0, 1.0) (got {d!r})")
return d
custom_format_mapping: Optional[Dict[str, Any]] = Field(
None,
description = (

View file

@ -5,8 +5,11 @@
Authentication API routes
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
import threading
import time
from collections import deque
from datetime import datetime, timedelta, timezone
from models.auth import (
@ -33,14 +36,52 @@ from auth.authentication import (
router = APIRouter()
# In-memory per-IP login rate limiter; multi-process deployment needs a shared store.
_LOGIN_BUCKETS: dict[str, deque] = {}
_LOGIN_BUCKETS_LOCK = threading.Lock()
_LOGIN_WINDOW_SECONDS = 60.0
_LOGIN_MAX_FAILS = 5
_LOGIN_LOCKOUT_SECONDS = 60
def _client_key(request: Request | None) -> str:
if request is None or request.client is None:
return "_unknown"
return request.client.host or "_unknown"
def _record_login_failure(ip: str) -> int:
now = time.monotonic()
with _LOGIN_BUCKETS_LOCK:
bucket = _LOGIN_BUCKETS.setdefault(ip, deque())
while bucket and now - bucket[0] > _LOGIN_WINDOW_SECONDS:
bucket.popleft()
bucket.append(now)
return len(bucket)
def _login_blocked(ip: str) -> int:
"""Return seconds until the next attempt is allowed, or 0."""
now = time.monotonic()
with _LOGIN_BUCKETS_LOCK:
bucket = _LOGIN_BUCKETS.get(ip)
if not bucket:
return 0
while bucket and now - bucket[0] > _LOGIN_WINDOW_SECONDS:
bucket.popleft()
if len(bucket) >= _LOGIN_MAX_FAILS:
return max(1, int(_LOGIN_WINDOW_SECONDS - (now - bucket[0])))
return 0
def _clear_login_bucket(ip: str) -> None:
with _LOGIN_BUCKETS_LOCK:
_LOGIN_BUCKETS.pop(ip, None)
@router.get("/status", response_model = AuthStatusResponse)
async def auth_status() -> AuthStatusResponse:
"""
Check whether auth has already been initialized.
- initialized = False -> frontend should wait for the seeded admin bootstrap.
- initialized = True -> frontend should show login or force the first password change.
"""
"""Auth initialization state; ``default_username`` is exposed for first-boot UI prefill only."""
return AuthStatusResponse(
initialized = storage.is_initialized(),
default_username = storage.DEFAULT_ADMIN_USERNAME,
@ -53,12 +94,23 @@ async def auth_status() -> AuthStatusResponse:
@router.post("/login", response_model = Token)
async def login(payload: AuthLoginRequest) -> Token:
"""
Login with username/password and receive access + refresh tokens.
"""
async def login(payload: AuthLoginRequest, request: Request) -> Token:
"""Login with username/password. Rate-limited per source IP."""
ip = _client_key(request)
blocked_for = _login_blocked(ip)
if blocked_for > 0:
raise HTTPException(
status_code = status.HTTP_429_TOO_MANY_REQUESTS,
detail = (
f"Too many failed login attempts from {ip}. "
f"Try again in {blocked_for} seconds."
),
headers = {"Retry-After": str(blocked_for)},
)
record = storage.get_user_and_secret(payload.username)
if record is None:
_record_login_failure(ip)
raise HTTPException(
status_code = status.HTTP_401_UNAUTHORIZED,
detail = "Incorrect password. Run 'unsloth studio reset-password' in your terminal to reset it.",
@ -66,11 +118,13 @@ async def login(payload: AuthLoginRequest) -> Token:
salt, pwd_hash, _jwt_secret, must_change_password = record
if not hashing.verify_password(payload.password, salt, pwd_hash):
_record_login_failure(ip)
raise HTTPException(
status_code = status.HTTP_401_UNAUTHORIZED,
detail = "Incorrect password. Run 'unsloth studio reset-password' in your terminal to reset it.",
)
_clear_login_bucket(ip)
access_token = create_access_token(subject = payload.username)
refresh_token = create_refresh_token(subject = payload.username)
return Token(
@ -81,6 +135,23 @@ async def login(payload: AuthLoginRequest) -> Token:
)
@router.post("/logout", status_code = status.HTTP_204_NO_CONTENT)
async def logout(
request: Request,
current_subject: str = Depends(get_current_subject_allow_password_change),
) -> Response:
"""Revoke refresh tokens for the subject; the access token is stateless and expires on its own."""
try:
storage.revoke_user_refresh_tokens(current_subject)
except Exception:
pass
try:
request.app.state.bootstrap_password = None
except AttributeError:
pass
return Response(status_code = status.HTTP_204_NO_CONTENT)
@router.post("/desktop-login", response_model = Token)
async def desktop_login(payload: DesktopLoginRequest) -> Token:
"""Exchange a local desktop secret for normal admin-subject tokens."""
@ -101,21 +172,20 @@ async def desktop_login(payload: DesktopLoginRequest) -> Token:
@router.post("/refresh", response_model = Token)
async def refresh(payload: RefreshTokenRequest) -> Token:
"""
Exchange a valid refresh token for a new access token.
The refresh token itself is reusable until it expires (7 days).
"""
new_access_token, username, is_desktop = refresh_access_token(payload.refresh_token)
if new_access_token is None or username is None:
"""Exchange a refresh token for a new access+refresh pair (single-use)."""
consumed = storage.consume_refresh_token(payload.refresh_token)
if consumed is None:
raise HTTPException(
status_code = status.HTTP_401_UNAUTHORIZED,
detail = "Invalid or expired refresh token",
)
username, is_desktop = consumed
new_access_token = create_access_token(subject = username, desktop = is_desktop)
new_refresh_token = create_refresh_token(subject = username, desktop = is_desktop)
return Token(
access_token = new_access_token,
refresh_token = payload.refresh_token,
refresh_token = new_refresh_token,
token_type = "bearer",
must_change_password = False
if is_desktop
@ -126,6 +196,7 @@ async def refresh(payload: RefreshTokenRequest) -> Token:
@router.post("/change-password", response_model = Token)
async def change_password(
payload: ChangePasswordRequest,
request: Request,
current_subject: str = Depends(get_current_subject_allow_password_change),
) -> Token:
"""Allow the authenticated user to replace the default password."""
@ -150,6 +221,10 @@ async def change_password(
storage.update_password(current_subject, payload.new_password)
storage.revoke_user_refresh_tokens(current_subject)
try:
request.app.state.bootstrap_password = None
except AttributeError:
pass
access_token = create_access_token(subject = current_subject)
refresh_token = create_refresh_token(subject = current_subject)
return Token(

View file

@ -7,6 +7,7 @@ Export API routes: checkpoint discovery and model export operations.
import asyncio
import json
import os
import sys
import time
from pathlib import Path
@ -184,14 +185,18 @@ async def get_export_status(
def _export_details(output_path: Optional[str]) -> Optional[Dict[str, Any]]:
"""Wrap the resolved on-disk export path into the details dict the
frontend reads to populate the Export Complete screen. Returns None
when the export had no local component (Hub-only push) so the
Pydantic field stays absent rather than ``{"output_path": null}``.
"""
"""Return the export path relative to exports_root so the install path is not leaked."""
if not output_path:
return None
return {"output_path": output_path}
try:
from utils.paths.storage_roots import exports_root
rel = os.path.relpath(output_path, exports_root())
if rel.startswith(".."):
rel = os.path.basename(output_path)
return {"output_path": rel}
except Exception:
return {"output_path": os.path.basename(output_path)}
@router.post("/export/merged", response_model = ExportOperationResponse)

View file

@ -1743,7 +1743,7 @@ async def openai_chat_completions(
try:
import base64 as _b64
from io import BytesIO as _BytesIO
from PIL import Image as _Image
from PIL import Image as _Image, UnidentifiedImageError as _UIE
raw = _b64.b64decode(image_b64)
# Normalize to RGB so PNG encoding succeeds regardless of
@ -1754,9 +1754,15 @@ async def openai_chat_completions(
buf = _BytesIO()
img.save(buf, format = "PNG")
image_b64 = _b64.b64encode(buf.getvalue()).decode("ascii")
except Exception as e:
except _UIE:
raise HTTPException(
status_code = 400, detail = f"Failed to process image: {e}"
status_code = 400,
detail = "Unsupported or corrupt image format.",
)
except Exception:
raise HTTPException(
status_code = 400,
detail = "Failed to process image.",
)
# Build message list with system prompt prepended
@ -3426,10 +3432,10 @@ def _normalize_anthropic_openai_images(
buf = io.BytesIO()
img.save(buf, format = "PNG")
png_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
except Exception as e:
except Exception:
raise HTTPException(
status_code = 400,
detail = f"Failed to process image: {e}",
detail = "Failed to process image.",
)
part["image_url"] = {"url": f"data:image/png;base64,{png_b64}"}
@ -3465,6 +3471,7 @@ async def anthropic_messages(
[m.model_dump() for m in payload.messages],
payload.system,
)
openai_messages = _drop_empty_assistant_sentinels(openai_messages)
# Enforce vision guard + re-encode embedded images to PNG so the
# Anthropic endpoint matches the behavior of /v1/chat/completions.
@ -4190,6 +4197,19 @@ async def _anthropic_passthrough_non_streaming(
# =====================================================================
def _drop_empty_assistant_sentinels(messages: list[dict]) -> list[dict]:
"""Drop bare ``{"role":"assistant"}`` Stop-button sentinels; passthrough backends reject them."""
out: list[dict] = []
for m in messages:
if m.get("role") == "assistant":
has_content = bool(m.get("content"))
has_tool_calls = bool(m.get("tool_calls"))
if not has_content and not has_tool_calls:
continue
out.append(m)
return out
def _openai_messages_for_passthrough(payload) -> list[dict]:
"""Build OpenAI-format message dicts for the /v1/chat/completions
passthrough path.
@ -4206,7 +4226,9 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
``image_url`` content part so vision + function-calling requests work
transparently.
"""
messages = [m.model_dump(exclude_none = True) for m in payload.messages]
messages = _drop_empty_assistant_sentinels(
[m.model_dump(exclude_none = True) for m in payload.messages]
)
if not payload.image_base64:
return messages
@ -4221,10 +4243,10 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
buf = _BytesIO()
img.save(buf, format = "PNG")
png_b64 = _b64.b64encode(buf.getvalue()).decode("ascii")
except Exception as e:
except Exception:
raise HTTPException(
status_code = 400,
detail = f"Failed to process image: {e}",
detail = "Failed to process image.",
)
data_url = f"data:image/png;base64,{png_b64}"

View file

@ -354,9 +354,14 @@ def run_server(
if getattr(self, "started", False) and not self.should_exit:
ready_event.set()
# Create the uvicorn server and expose it for signal handlers
# server_header=False suppresses uvicorn's "Server: uvicorn"; SecurityHeadersMiddleware sets its own.
config = uvicorn.Config(
app, host = host, port = port, log_level = "info", access_log = False
app,
host = host,
port = port,
log_level = "info",
access_log = False,
server_header = False,
)
_server = _ReadyServer(config)
_shutdown_event = Event()

View file

@ -227,6 +227,60 @@ def test_desktop_refresh_preserves_desktop_marker():
assert payload["desktop"] is True
def test_consume_refresh_token_second_call_returns_none():
"""Single-use rotation rejects the same token on a second consume."""
seed_user()
from datetime import datetime, timedelta, timezone
raw = secrets.token_urlsafe(48)
expires = (datetime.now(timezone.utc) + timedelta(days = 30)).isoformat()
storage.save_refresh_token(raw, storage.DEFAULT_ADMIN_USERNAME, expires)
first = storage.consume_refresh_token(raw)
assert first == (storage.DEFAULT_ADMIN_USERNAME, False)
second = storage.consume_refresh_token(raw)
assert second is None
def test_consume_refresh_token_concurrent_only_one_succeeds(tmp_path, monkeypatch):
"""64-thread pile-up against one token; DELETE RETURNING permits one winner."""
seed_user()
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
raw = secrets.token_urlsafe(48)
expires = (datetime.now(timezone.utc) + timedelta(days = 30)).isoformat()
storage.save_refresh_token(raw, storage.DEFAULT_ADMIN_USERNAME, expires)
workers = 64
def attempt(_idx: int):
try:
return storage.consume_refresh_token(raw)
except sqlite3.OperationalError:
# "database is locked" under heavy contention; treat as losing the race.
return None
with ThreadPoolExecutor(max_workers = workers) as pool:
results = list(pool.map(attempt, range(workers)))
successes = [r for r in results if r is not None]
assert (
len(successes) == 1
), f"expected exactly one consumer to win, got {len(successes)}"
assert successes[0] == (storage.DEFAULT_ADMIN_USERNAME, False)
def test_consume_refresh_token_expired_returns_none():
seed_user()
from datetime import datetime, timedelta, timezone
raw = secrets.token_urlsafe(48)
expires = (datetime.now(timezone.utc) - timedelta(hours = 1)).isoformat()
storage.save_refresh_token(raw, storage.DEFAULT_ADMIN_USERNAME, expires)
assert storage.consume_refresh_token(raw) is None
def test_desktop_session_uses_real_admin_identity_for_api_keys():
seed_user(must_change_password = True)
raw = storage.create_desktop_secret()
@ -392,7 +446,21 @@ def test_health_response_reports_desktop_capability_fields(monkeypatch):
monkeypatch.setattr(backend_main._hw_module, "CHAT_ONLY", False)
body = asyncio.run(backend_main.health_check())
seed_user()
from auth.authentication import create_access_token
token = create_access_token(storage.DEFAULT_ADMIN_USERNAME)
app = FastAPI()
app.add_api_route("/api/health", backend_main.health_check, methods = ["GET"])
client = TestClient(app)
response = client.get(
"/api/health",
headers = {"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
body = response.json()
assert body["desktop_protocol_version"] == 1
assert body["supports_desktop_auth"] is True

View file

@ -0,0 +1,269 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""Tests for MaxBodyMiddleware, SecurityHeadersMiddleware, and the /api/health auth gate."""
import asyncio
import importlib.util
import json
import os
import sys
from pathlib import Path
import pytest
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response
from fastapi.testclient import TestClient
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
if str(_BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(_BACKEND_ROOT))
@pytest.fixture(scope = "module")
def main_module():
import main as _main # noqa: F401
return _main
# =====================================================================
# MaxBodyMiddleware
# =====================================================================
def _make_protected_app(max_bytes: int, main_module):
app = FastAPI()
app.add_middleware(
main_module.MaxBodyMiddleware,
max_bytes = max_bytes,
protected_prefixes = ("/v1/chat/completions", "/api/train"),
)
@app.post("/v1/chat/completions")
async def chat(payload: dict):
return {"ok": True, "n": len(payload.get("text", ""))}
@app.post("/api/other")
async def other(payload: dict):
return {"ok": True, "unprotected": True}
@app.get("/api/train/status")
async def status_get():
return {"ok": True, "get": True}
return app
class TestMaxBodyMiddleware:
def test_small_protected_body_passes(self, main_module):
app = _make_protected_app(1024, main_module)
c = TestClient(app)
r = c.post("/v1/chat/completions", json = {"text": "x" * 100})
assert r.status_code == 200
assert r.json()["n"] == 100
def test_large_declared_content_length_rejected(self, main_module):
app = _make_protected_app(1024, main_module)
c = TestClient(app)
r = c.post("/v1/chat/completions", json = {"text": "x" * 5000})
assert r.status_code == 413
assert "too large" in r.json()["detail"].lower()
def test_unprotected_prefix_passes_large_body(self, main_module):
app = _make_protected_app(1024, main_module)
c = TestClient(app)
r = c.post("/api/other", json = {"text": "x" * 5000})
assert r.status_code == 200
assert r.json()["unprotected"] is True
def test_chunked_upload_over_cap_rejected(self, main_module):
# Regression: declared-Content-Length-only check could be bypassed
# by chunked transfer-encoding.
app = _make_protected_app(1024, main_module)
c = TestClient(app)
def gen():
yield b'{"text":"'
yield b"x" * 800
yield b'"}'
yield b"\n" + b"y" * 500
r = c.post(
"/v1/chat/completions",
content = gen(),
headers = {"content-type": "application/json"},
)
assert r.status_code == 413
assert "too large" in r.json()["detail"].lower()
def test_chunked_upload_under_cap_passes(self, main_module):
app = _make_protected_app(1024, main_module)
c = TestClient(app)
def gen():
yield b'{"text":"'
yield b"x" * 50
yield b'"}'
r = c.post(
"/v1/chat/completions",
content = gen(),
headers = {"content-type": "application/json"},
)
assert r.status_code == 200
assert r.json()["n"] == 50
def test_get_not_subject_to_cap(self, main_module):
app = _make_protected_app(1024, main_module)
c = TestClient(app)
r = c.get("/api/train/status")
assert r.status_code == 200
# =====================================================================
# SecurityHeadersMiddleware / CSP
# =====================================================================
def _make_csp_app(main_module, attach_nonce: str | None = None):
app = FastAPI()
app.add_middleware(main_module.SecurityHeadersMiddleware)
@app.get("/plain")
async def plain():
return {"ok": True}
@app.get("/with-nonce")
async def with_nonce():
headers = {}
if attach_nonce:
headers[main_module._CSP_SCRIPT_NONCE_HEADER] = attach_nonce
return Response(
content = b"<html></html>",
media_type = "text/html",
headers = headers,
)
return app
class TestSecurityHeadersMiddleware:
def test_csp_has_no_unsafe_inline_for_script_src(self, main_module):
app = _make_csp_app(main_module)
c = TestClient(app)
r = c.get("/plain")
assert r.status_code == 200
csp = r.headers["content-security-policy"]
# Parse per-directive so style-src unsafe-inline does not false-match.
directives = {
chunk.strip().split(" ", 1)[0]: chunk.strip()
for chunk in csp.split(";")
if chunk.strip()
}
assert "script-src" in directives
assert "'unsafe-inline'" not in directives["script-src"]
# style-src keeps unsafe-inline for Vite-injected styles.
assert "'unsafe-inline'" in directives["style-src"]
def test_default_security_headers_present(self, main_module):
app = _make_csp_app(main_module)
c = TestClient(app)
r = c.get("/plain")
assert r.headers["x-frame-options"] == "DENY"
assert r.headers["x-content-type-options"] == "nosniff"
assert r.headers["referrer-policy"] == "no-referrer"
assert "camera=()" in r.headers["permissions-policy"]
assert r.headers["server"] == "unsloth-studio"
def test_internal_nonce_header_is_spliced_into_csp_and_stripped(self, main_module):
nonce = "test-nonce-abc"
app = _make_csp_app(main_module, attach_nonce = nonce)
c = TestClient(app)
r = c.get("/with-nonce")
csp = r.headers["content-security-policy"]
assert f"'nonce-{nonce}'" in csp
# Internal handoff header must not leak to clients.
assert main_module._CSP_SCRIPT_NONCE_HEADER not in {
k.lower() for k in r.headers.keys()
}
def test_build_csp_helper_shape(self, main_module):
plain = main_module._build_csp()
assert "script-src 'self';" in plain
assert "'unsafe-inline'" not in plain.split("script-src", 1)[1].split(";", 1)[0]
nonced = main_module._build_csp("XYZ")
assert "script-src 'self' 'nonce-XYZ';" in nonced
# =====================================================================
# /api/health auth gate
# =====================================================================
@pytest.fixture
def health_app(tmp_path, monkeypatch):
"""Mount /api/health on a fresh app against an isolated auth db."""
from auth import storage
monkeypatch.setattr(storage, "DB_PATH", tmp_path / "auth.db")
monkeypatch.setattr(storage, "_BOOTSTRAP_PW_PATH", tmp_path / ".bootstrap_password")
monkeypatch.setattr(storage, "_bootstrap_password", None)
import main as _main
app = FastAPI()
app.add_api_route("/api/health", _main.health_check, methods = ["GET"])
import secrets as _secrets
storage.create_initial_user(
username = storage.DEFAULT_ADMIN_USERNAME,
password = "human-password-123",
jwt_secret = _secrets.token_urlsafe(64),
must_change_password = False,
)
return app
class TestHealthAuthGate:
def test_no_auth_returns_minimal_payload(self, health_app):
c = TestClient(health_app)
r = c.get("/api/health")
assert r.status_code == 200
body = r.json()
assert body["status"] == "healthy"
assert "timestamp" in body
for forbidden in ("version", "device_type", "studio_root_id"):
assert forbidden not in body
def test_invalid_bearer_returns_minimal_payload(self, health_app):
# Regression: calling the async dep without await made any Bearer header pass.
c = TestClient(health_app)
r = c.get(
"/api/health",
headers = {"Authorization": "Bearer not-a-real-token"},
)
assert r.status_code == 200
body = r.json()
assert body["status"] == "healthy"
for forbidden in ("version", "device_type", "studio_root_id"):
assert forbidden not in body
def test_valid_bearer_returns_full_payload(self, health_app):
from auth import storage
from auth.authentication import create_access_token
token = create_access_token(storage.DEFAULT_ADMIN_USERNAME)
c = TestClient(health_app)
r = c.get(
"/api/health",
headers = {"Authorization": f"Bearer {token}"},
)
assert r.status_code == 200
body = r.json()
assert body["status"] == "healthy"
assert "version" in body
assert "device_type" in body
assert "studio_root_id" in body

View file

@ -125,22 +125,21 @@ class TestChatMessageToolRoles:
)
assert msg.content is None
def test_tool_role_missing_tool_call_id_rejected(self):
# Per OpenAI spec, role="tool" messages must carry tool_call_id so
# upstream backends can associate the result with its prior call.
# Pin the boundary-level rejection so a malformed tool-result
# message never reaches the passthrough path.
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "tool", content = '{"temperature": 72}')
assert "tool_call_id" in str(exc_info.value)
def test_tool_role_missing_tool_call_id_synthesised(self):
# Frontend drops the id on second-round POST; validator synthesises one.
msg = ChatMessage(role = "tool", content = '{"temperature": 72}')
assert msg.tool_call_id is not None
assert msg.tool_call_id.startswith("call_")
assert len(msg.tool_call_id) >= len("call_") + 8
def test_tool_role_empty_tool_call_id_rejected(self):
with pytest.raises(ValidationError):
ChatMessage(
role = "tool",
tool_call_id = "",
content = '{"temperature": 72}',
)
def test_tool_role_empty_tool_call_id_synthesised(self):
msg = ChatMessage(
role = "tool",
tool_call_id = "",
content = '{"temperature": 72}',
)
assert msg.tool_call_id is not None
assert msg.tool_call_id.startswith("call_")
# ── Role-aware content requirements ────────────────────────────
@ -162,10 +161,19 @@ class TestChatMessageToolRoles:
ChatMessage(role = "tool", tool_call_id = "call_1", content = "")
assert "content" in str(exc_info.value)
def test_assistant_without_content_or_tool_calls_rejected(self):
with pytest.raises(ValidationError) as exc_info:
ChatMessage(role = "assistant")
assert "content" in str(exc_info.value) or "tool_calls" in str(exc_info.value)
def test_assistant_without_content_or_tool_calls_tolerated(self):
# Stop-button leaves an empty assistant turn; tolerate so replay round-trips.
msg = ChatMessage(role = "assistant")
assert msg.content is None
assert msg.tool_calls is None
def test_assistant_empty_string_content_normalised_to_none(self):
msg = ChatMessage(role = "assistant", content = "")
assert msg.content is None
def test_assistant_empty_list_content_normalised_to_none(self):
msg = ChatMessage(role = "assistant", content = [])
assert msg.content is None
# ── Role-constrained tool-call metadata ────────────────────────
@ -472,3 +480,91 @@ class TestFriendlyErrorHttpx:
assert (
_friendly_error(RuntimeError("unrelated")) == "An internal error occurred"
)
from routes.inference import ( # noqa: E402
_drop_empty_assistant_sentinels,
_openai_messages_for_passthrough,
)
class TestDropEmptyAssistantSentinels:
def test_drops_empty_assistant_between_real_turns(self):
msgs = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": ""},
{"role": "user", "content": "again"},
]
out = _drop_empty_assistant_sentinels(msgs)
assert out == [
{"role": "user", "content": "hi"},
{"role": "user", "content": "again"},
]
def test_drops_assistant_with_no_content_key(self):
# exclude_none=True strips the content key entirely; filter must catch this.
msgs = [
{"role": "user", "content": "hi"},
{"role": "assistant"},
{"role": "user", "content": "ok"},
]
out = _drop_empty_assistant_sentinels(msgs)
assert out == [
{"role": "user", "content": "hi"},
{"role": "user", "content": "ok"},
]
def test_preserves_assistant_with_text(self):
msgs = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello back"},
]
out = _drop_empty_assistant_sentinels(msgs)
assert out == msgs
def test_preserves_assistant_with_tool_calls_only(self):
msgs = [
{"role": "user", "content": "weather?"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "get_weather", "arguments": "{}"},
},
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": '{"t": 72}',
},
]
out = _drop_empty_assistant_sentinels(msgs)
assert out == msgs
def test_preserves_user_and_system_with_empty_content(self):
# Filter scoped to role="assistant" only.
msgs = [
{"role": "system", "content": ""},
{"role": "user", "content": ""},
]
out = _drop_empty_assistant_sentinels(msgs)
assert out == msgs
def test_openai_messages_for_passthrough_drops_sentinel(self):
"""End-to-end: Stop-sentinel must not reach the wire."""
req = ChatCompletionRequest(
model = "default",
messages = [
ChatMessage(role = "user", content = "hi"),
ChatMessage(role = "assistant", content = ""),
ChatMessage(role = "user", content = "again"),
],
)
out = _openai_messages_for_passthrough(req)
roles = [m["role"] for m in out]
assert roles == ["user", "user"]
for m in out:
assert m.get("content"), m

View file

@ -0,0 +1,241 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""Tests for the sandboxed-Python AST policy in core/inference/tools.py."""
import os
import sys
from pathlib import Path
import pytest
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
if str(_BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(_BACKEND_ROOT))
from core.inference.tools import _check_code_safety
def _ok(code: str):
assert _check_code_safety(code) is None, code
def _blocked(code: str, *, expect_phrase: str):
msg = _check_code_safety(code)
assert msg is not None, code
assert expect_phrase in msg, (expect_phrase, msg)
class TestMetadataHostDenylist:
def test_aws_imds_literal_blocked(self):
_blocked(
'import requests; requests.get("http://169.254.169.254/latest/meta-data/")',
expect_phrase = "Blocked: cloud-metadata host",
)
def test_gcp_metadata_dns_blocked(self):
_blocked(
'import requests; requests.get("http://metadata.google.internal/")',
expect_phrase = "Blocked: cloud-metadata host",
)
def test_alibaba_ecs_literal_blocked(self):
_blocked(
'import socket; s=socket.socket(); s.connect(("100.100.100.200", 80))',
expect_phrase = "Blocked: cloud-metadata host",
)
def test_ipv6_imds_literal_blocked(self):
_blocked(
'import urllib.request; urllib.request.urlopen("http://[fd00:ec2::254]/")',
expect_phrase = "Blocked: cloud-metadata host",
)
def test_metadata_link_local_prefix_blocked(self):
_blocked(
'import requests; requests.get("http://169.254.170.2/v3/")',
expect_phrase = "Blocked: cloud-metadata host",
)
class TestTrustedHostAllowlist:
@pytest.mark.parametrize(
"url",
[
"https://en.wikipedia.org/wiki/Python_(programming_language)",
"https://fr.wikipedia.org/wiki/Python_(langage)",
"https://www.google.com/search?q=foo",
"https://duckduckgo.com/?q=foo",
"https://huggingface.co/unsloth",
"https://cdn-lfs.huggingface.co/repos/abc/def/file.bin",
"https://raw.githubusercontent.com/foo/bar/main/README.md",
"https://api.github.com/repos/foo/bar",
"https://arxiv.org/abs/2401.12345",
"https://export.arxiv.org/abs/2401.12345",
"https://stackoverflow.com/questions/12345",
"https://math.stackexchange.com/questions/12345",
"https://developer.mozilla.org/en-US/docs/Web/JavaScript",
"https://docs.python.org/3/library/asyncio.html",
"https://pypi.org/project/requests/",
"https://files.pythonhosted.org/packages/foo/bar.whl",
"https://www.bbc.com/news",
"https://api.weather.gov/points/40,-90",
"https://numpy.org/doc/stable/",
"https://pytorch.org/docs/stable/index.html",
],
)
def test_trusted_host_passes(self, url):
_ok(f"import requests; requests.get({url!r})")
def test_wikipedia_subdomain_passes(self):
_ok(
'import urllib.request; urllib.request.urlopen("https://m.en.wikipedia.org/wiki/Foo")'
)
def test_hf_co_short_form_passes(self):
_ok('import requests; requests.get("https://hf.co/unsloth/Qwen3.5-4B-GGUF")')
def test_github_io_pages_pass(self):
_ok('import requests; requests.get("https://unslothai.github.io/")')
class TestUntrustedHostBlock:
def test_example_com_blocked(self):
_blocked(
'import requests; requests.get("https://example.com/")',
expect_phrase = "Blocked: host not in sandbox allowlist",
)
def test_random_blog_blocked(self):
_blocked(
'import urllib.request; urllib.request.urlopen("https://random-blog-host.example/")',
expect_phrase = "Blocked: host not in sandbox allowlist",
)
def test_socket_connect_random_host_blocked(self):
_blocked(
'import socket; s=socket.socket(); s.connect(("evil.example", 80))',
expect_phrase = "Blocked: host not in sandbox allowlist",
)
def test_dynamic_url_not_statically_blocked(self):
# Static AST cannot resolve runtime URLs; bash blocklist is the fallback.
_ok('import requests; url = "https://example.com/"; requests.get(url)')
class TestHostNormalization:
def test_trailing_dot_treated_same(self):
_ok('import requests; requests.get("https://wikipedia.org./")')
def test_explicit_port_does_not_unblock_or_misblock(self):
_ok('import requests; requests.get("https://en.wikipedia.org:443/wiki/Foo")')
_blocked(
'import requests; requests.get("https://example.com:8080/")',
expect_phrase = "Blocked: host not in sandbox allowlist",
)
def test_userinfo_at_does_not_smuggle_metadata_host(self):
_blocked(
'import requests; requests.get("https://wikipedia.org@169.254.169.254/latest/")',
expect_phrase = "Blocked: cloud-metadata host",
)
def test_uppercase_host_normalised(self):
_ok('import requests; requests.get("https://EN.WIKIPEDIA.ORG/wiki/Foo")')
class TestUploadDenylist:
def test_requests_post_files_blocked(self):
_blocked(
(
"import requests\n"
'requests.post("https://huggingface.co/api/repos/upload", '
'files={"f": open("x.bin", "rb")})'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_requests_put_data_bytes_blocked(self):
_blocked(
(
"import requests\n"
'requests.put("https://huggingface.co/api/repos/upload", '
'data=b"\\x00\\x01\\x02")'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_requests_post_data_open_handle_blocked(self):
_blocked(
(
"import requests\n"
'requests.post("https://huggingface.co/api/repos/upload", '
'data=open("x.bin", "rb"))'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_httpx_post_files_blocked(self):
_blocked(
(
"import httpx\n"
'httpx.post("https://huggingface.co/api/repos/upload", '
'files={"f": open("x.bin", "rb")})'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_hf_api_upload_file_blocked(self):
_blocked(
(
"from huggingface_hub import HfApi\n"
'HfApi().upload_file(path_or_fileobj="x.bin", '
'path_in_repo="x.bin", repo_id="foo/bar")'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_hf_module_upload_folder_blocked(self):
_blocked(
(
"import huggingface_hub\n"
'huggingface_hub.upload_folder(folder_path="./", repo_id="foo/bar")'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_hf_create_commit_method_blocked(self):
_blocked(
(
"import huggingface_hub\n"
"api = huggingface_hub.HfApi()\n"
'api.create_commit(repo_id="foo/bar", operations=[])'
),
expect_phrase = "Blocked: file upload disallowed in sandbox",
)
def test_plain_post_json_not_blocked(self):
_ok(
"import requests\n"
'requests.post("https://api.weather.gov/lookup", json={"k": "v"})'
)
class TestSandboxCpuRlimitDefault:
"""Pin the default so a regression below 600s without opt-in is caught."""
def test_default_cpu_s_is_600(self):
src = (_BACKEND_ROOT / "core" / "inference" / "tools.py").read_text()
assert 'UNSLOTH_STUDIO_SANDBOX_CPU_S", "600"' in src
def test_clone_newnet_removed(self):
src = (_BACKEND_ROOT / "core" / "inference" / "tools.py").read_text()
assert "_libc.unshare(0x40000000)" not in src
# Explanatory comment retained.
assert "CLONE_NEWNET" in src
class TestMaxBodyDefault:
def test_default_is_500_mb(self):
src = (_BACKEND_ROOT / "main.py").read_text()
assert 'UNSLOTH_STUDIO_MAX_BODY_MB", "500"' in src

View file

@ -0,0 +1,90 @@
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
"""Pin TrainingStartRequest hyperparameter caps at the at-cap / over-cap boundary."""
import sys
from pathlib import Path
import pytest
from pydantic import ValidationError
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
if str(_BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(_BACKEND_ROOT))
from models.training import (
_MAX_BATCH_SIZE,
_MAX_LORA_ALPHA,
_MAX_LORA_R,
_MAX_SEQ_LENGTH,
)
def _check_field(field_name: str, value):
"""Run the field validator without constructing a full TrainingStartRequest."""
from models.training import TrainingStartRequest
schema_field = TrainingStartRequest.model_fields[field_name]
return TrainingStartRequest.__pydantic_validator__.validate_assignment(
TrainingStartRequest.model_construct(),
field_name,
value,
)
class TestSeqLengthCap:
def test_at_cap_accepts(self):
_check_field("max_seq_length", _MAX_SEQ_LENGTH)
assert _MAX_SEQ_LENGTH == 2_000_000
def test_over_cap_rejects(self):
with pytest.raises(ValidationError) as exc:
_check_field("max_seq_length", _MAX_SEQ_LENGTH + 1)
assert "max_seq_length" in str(exc.value)
def test_below_min_rejects(self):
with pytest.raises(ValidationError):
_check_field("max_seq_length", 0)
class TestBatchSizeCap:
def test_at_cap_accepts(self):
_check_field("batch_size", _MAX_BATCH_SIZE)
assert _MAX_BATCH_SIZE == 4096
def test_over_cap_rejects(self):
with pytest.raises(ValidationError):
_check_field("batch_size", _MAX_BATCH_SIZE + 1)
def test_below_min_rejects(self):
with pytest.raises(ValidationError):
_check_field("batch_size", 0)
class TestLoraRCap:
def test_at_cap_accepts(self):
_check_field("lora_r", _MAX_LORA_R)
assert _MAX_LORA_R == 16_384
def test_over_cap_rejects(self):
with pytest.raises(ValidationError):
_check_field("lora_r", _MAX_LORA_R + 1)
def test_below_min_rejects(self):
with pytest.raises(ValidationError):
_check_field("lora_r", 0)
class TestLoraAlphaCap:
def test_at_cap_accepts(self):
_check_field("lora_alpha", _MAX_LORA_ALPHA)
assert _MAX_LORA_ALPHA == 32_768
def test_over_cap_rejects(self):
with pytest.raises(ValidationError):
_check_field("lora_alpha", _MAX_LORA_ALPHA + 1)
def test_below_min_rejects(self):
with pytest.raises(ValidationError):
_check_field("lora_alpha", 0)

View file

@ -28,7 +28,16 @@ from utils.models.model_config import (
)
def test_scan_trained_models_includes_lora_and_full_finetune_outputs(tmp_path: Path):
def test_scan_trained_models_includes_lora_and_full_finetune_outputs(
tmp_path: Path, monkeypatch
):
# resolve_output_dir refuses absolutes outside outputs_root; point it at tmp_path.
from utils.models import model_config as _mc
from utils.paths import storage_roots as _sr
monkeypatch.setattr(_sr, "outputs_root", lambda: tmp_path)
monkeypatch.setattr(_mc, "outputs_root", lambda: tmp_path)
lora_dir = tmp_path / "unsloth_SmolLM-135M_1775412608"
lora_dir.mkdir()
(lora_dir / "adapter_config.json").write_text(

View file

@ -276,21 +276,52 @@ def _clean_relative_path(
return Path(*parts) if parts else Path()
def _assert_contained(resolved: Path, root: Path) -> None:
"""Raise ValueError if ``resolved`` realpaths outside ``root``."""
try:
resolved_real = Path(os.path.realpath(resolved))
root_real = Path(os.path.realpath(root))
except OSError as exc:
raise ValueError(f"path resolution failed: {exc}") from exc
try:
resolved_real.relative_to(root_real)
except ValueError as exc:
raise ValueError(
f"path escapes root: {resolved!s} -> {resolved_real!s} "
f"is not under {root_real!s}"
) from exc
def resolve_under_root(
path_value: str | None,
*,
root: Path,
strip_prefixes: tuple[str, ...] = (),
) -> Path:
"""Resolve ``path_value`` and assert the result is under ``root``.
Absolutes are accepted only if already contained (so internal pre-resolved
paths re-enter idempotently); user-facing schemas reject absolutes upstream.
"""
if not path_value or not str(path_value).strip():
return root
path = Path(str(path_value).strip()).expanduser()
raw = str(path_value).strip()
if "\x00" in raw:
raise ValueError("path may not contain null bytes")
path = Path(raw).expanduser()
if ".." in path.parts:
raise ValueError(f"path may not contain '..' segments: {raw!r}")
if path.is_absolute():
_assert_contained(path, root)
return path
cleaned = _clean_relative_path(str(path), strip_prefixes = strip_prefixes)
return root / cleaned
cleaned = _clean_relative_path(raw, strip_prefixes = strip_prefixes)
candidate = root / cleaned
_assert_contained(candidate, root)
return candidate
def resolve_output_dir(path_value: str | None = None) -> Path:
@ -318,9 +349,22 @@ def resolve_tensorboard_dir(path_value: str | None = None) -> Path:
def resolve_dataset_path(path_value: str) -> Path:
path = Path(path_value).expanduser()
raw = str(path_value or "").strip()
if "\x00" in raw:
raise ValueError("dataset path may not contain null bytes")
path = Path(raw).expanduser()
if ".." in path.parts:
raise ValueError(f"dataset path may not contain '..' segments: {raw!r}")
if path.is_absolute():
return path
for root_fn in (datasets_root, dataset_uploads_root, recipe_datasets_root):
try:
_assert_contained(path, root_fn())
return path
except ValueError:
continue
raise ValueError(
f"dataset path must be relative or under a dataset root: {raw!r}"
)
parts = [part for part in Path(path_value).parts if part not in ("", ".")]
if parts[:2] == ["assets", "datasets"]:

View file

@ -316,18 +316,47 @@ if code in (400, 422):
else:
fail(f"/api/auth/refresh without body returned {code} (expected 400/422)")
# Login burst with wrong password must keep returning 401, NOT 429.
# Documents that no rate-limit / brute-force lockout exists today.
# When/if we add one, this assertion updates in the same PR.
all_401 = True
for i in range(5):
code, _ = login("definitely-wrong-password")
if code != 401:
all_401 = False
fail(f"login burst attempt {i+1} returned {code} (expected 401)")
# Wrong-password burst: expect 401 until the per-IP bucket fills, then
# 429 with Retry-After. Bucket cannot be reset between tests, so we
# assert the observable invariant rather than a fixed transition index.
def _login_with_headers(password: str) -> tuple[int, str | None]:
"""Like ``login`` but returns ``(status, retry_after_header)``."""
url = f"{BASE}/api/auth/login"
data = json.dumps({"username": "unsloth", "password": password}).encode()
req = urllib.request.Request(
url,
data = data,
method = "POST",
headers = {"Content-Type": "application/json"},
)
try:
with urllib.request.urlopen(req, timeout = 10) as r:
return r.status, r.headers.get("Retry-After")
except urllib.error.HTTPError as exc:
return exc.code, exc.headers.get("Retry-After") if exc.headers else None
codes = []
retry_after = None
for i in range(8):
code, ra = _login_with_headers("definitely-wrong-password")
codes.append(code)
if code == 429:
retry_after = ra
break
if all_401:
ok("login burst (5x wrong pw) -> 401 each (no rate-limit, documented)")
if code != 401:
fail(f"login burst attempt {i+1} returned {code} (expected 401 or 429)")
break
if 401 not in codes:
fail(f"login burst never returned 401 before rate-limit (codes={codes})")
elif 429 not in codes:
fail(f"login burst never rate-limited after {len(codes)} wrongs (codes={codes})")
elif retry_after is None:
fail("429 response missing Retry-After header")
else:
ok(f"login burst -> 401x{codes.count(401)} then 429 with Retry-After={retry_after}")
# ─────────────────────────────────────────────────────────────────────────

View file

@ -593,12 +593,16 @@ def test_install_ps1_bakes_studio_root_id_into_launcher():
def test_health_endpoint_exposes_studio_root_id_not_raw_path():
"""studio/backend/main.py /api/health must expose studio_root_id (a
hex digest) and NOT the raw studio_root path. Studio supports
`-H 0.0.0.0`; an unauthenticated /api/health that returns the raw
install path leaks username, home dir, workspace name, etc."""
`-H 0.0.0.0`; a /api/health that returns the raw install path
leaks username, home dir, workspace name, etc."""
main_py = REPO_ROOT / "studio" / "backend" / "main.py"
src = main_py.read_text()
health_idx = src.index('@app.get("/api/health")')
health_block = src[health_idx : health_idx + 1500]
# Slice up to the next top-level @app. so a growing body stays in scope.
next_app_idx = src.find("\n@app.", health_idx + 1)
if next_app_idx == -1:
next_app_idx = len(src)
health_block = src[health_idx:next_app_idx]
assert (
'"studio_root_id"' in health_block
), "/api/health must expose studio_root_id (hex digest)"