mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-20 00:51:36 +00:00
studio: security and hardening pass (auth rate-limit, sandbox, path containment, schema validation, headers) (#5375)
Some checks are pending
Core / Core (HF=default + TRL=default) (push) Waiting to run
Core / Core (HF=4.57.6 + TRL<1) (push) Waiting to run
Core / Core (HF=latest + TRL=latest) (push) Waiting to run
Core / llama.cpp build + smoke (push) Waiting to run
Lint CI / Source lint (Python + shell + YAML + JSON + safety nets) (push) Waiting to run
MLX CI on Mac M1 / dispatch (push) Waiting to run
Security audit / advisory audit (pip + npm + cargo) (push) Waiting to run
Security audit / pip scan-packages :: extras (push) Waiting to run
Security audit / pip scan-packages :: studio (push) Waiting to run
Security audit / pip scan-packages :: hf-stack (push) Waiting to run
Security audit / npm scan-packages (Studio frontend tarballs) (push) Waiting to run
Security audit / workflow-trigger lint (pull_request_target / cache-poisoning) (push) Waiting to run
Security audit / pytest tests/security (push) Waiting to run
Security audit / npm provenance + new install-script diff (push) Waiting to run
Studio API CI / Studio API & Auth Tests (push) Waiting to run
Backend CI / (Python 3.10) (push) Waiting to run
Backend CI / (Python 3.11) (push) Waiting to run
Backend CI / (Python 3.12) (push) Waiting to run
Backend CI / (Python 3.13) (push) Waiting to run
Backend CI / Repo tests (CPU) (push) Waiting to run
Frontend CI / Frontend build + bundle sanity (push) Waiting to run
Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Studio GGUF CI / Tool calling Tests (push) Waiting to run
Studio GGUF CI / JSON, images (push) Waiting to run
Mac Studio API CI / Studio API & Auth Tests (push) Waiting to run
Mac Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Mac Studio GGUF CI / Tool calling Tests (push) Waiting to run
Mac Studio GGUF CI / JSON, images (push) Waiting to run
Mac Studio UI CI / Chat UI Tests (push) Waiting to run
Mac Studio Update CI / Studio Updating Tests (push) Waiting to run
Studio Tauri CI / Tauri Linux debug build (no codesign) (push) Waiting to run
Studio UI CI / Chat UI Tests (push) Waiting to run
Studio Update CI / Studio Updating Tests (push) Waiting to run
Windows Studio API CI / Studio API & Auth Tests (push) Waiting to run
Windows Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Windows Studio GGUF CI / Tool calling Tests (push) Waiting to run
Windows Studio GGUF CI / JSON, images (push) Waiting to run
Windows Studio UI CI / Chat UI Tests (push) Waiting to run
Windows Studio Update CI / Studio Updating Tests (push) Waiting to run
Wheel CI / Wheel build + content sanity + import smoke (push) Waiting to run
Some checks are pending
Core / Core (HF=default + TRL=default) (push) Waiting to run
Core / Core (HF=4.57.6 + TRL<1) (push) Waiting to run
Core / Core (HF=latest + TRL=latest) (push) Waiting to run
Core / llama.cpp build + smoke (push) Waiting to run
Lint CI / Source lint (Python + shell + YAML + JSON + safety nets) (push) Waiting to run
MLX CI on Mac M1 / dispatch (push) Waiting to run
Security audit / advisory audit (pip + npm + cargo) (push) Waiting to run
Security audit / pip scan-packages :: extras (push) Waiting to run
Security audit / pip scan-packages :: studio (push) Waiting to run
Security audit / pip scan-packages :: hf-stack (push) Waiting to run
Security audit / npm scan-packages (Studio frontend tarballs) (push) Waiting to run
Security audit / workflow-trigger lint (pull_request_target / cache-poisoning) (push) Waiting to run
Security audit / pytest tests/security (push) Waiting to run
Security audit / npm provenance + new install-script diff (push) Waiting to run
Studio API CI / Studio API & Auth Tests (push) Waiting to run
Backend CI / (Python 3.10) (push) Waiting to run
Backend CI / (Python 3.11) (push) Waiting to run
Backend CI / (Python 3.12) (push) Waiting to run
Backend CI / (Python 3.13) (push) Waiting to run
Backend CI / Repo tests (CPU) (push) Waiting to run
Frontend CI / Frontend build + bundle sanity (push) Waiting to run
Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Studio GGUF CI / Tool calling Tests (push) Waiting to run
Studio GGUF CI / JSON, images (push) Waiting to run
Mac Studio API CI / Studio API & Auth Tests (push) Waiting to run
Mac Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Mac Studio GGUF CI / Tool calling Tests (push) Waiting to run
Mac Studio GGUF CI / JSON, images (push) Waiting to run
Mac Studio UI CI / Chat UI Tests (push) Waiting to run
Mac Studio Update CI / Studio Updating Tests (push) Waiting to run
Studio Tauri CI / Tauri Linux debug build (no codesign) (push) Waiting to run
Studio UI CI / Chat UI Tests (push) Waiting to run
Studio Update CI / Studio Updating Tests (push) Waiting to run
Windows Studio API CI / Studio API & Auth Tests (push) Waiting to run
Windows Studio GGUF CI / OpenAI, Anthropic API tests (push) Waiting to run
Windows Studio GGUF CI / Tool calling Tests (push) Waiting to run
Windows Studio GGUF CI / JSON, images (push) Waiting to run
Windows Studio UI CI / Chat UI Tests (push) Waiting to run
Windows Studio Update CI / Studio Updating Tests (push) Waiting to run
Wheel CI / Wheel build + content sanity + import smoke (push) Waiting to run
* studio: contain export and dataset paths under their configured roots
resolve_under_root and resolve_dataset_path previously returned absolute
paths unchanged, so an authenticated client could supply
save_directory="/tmp/escape" (or any other absolute path) and have the
exporter drop adapter files anywhere the server user could write. This
turned up during a recent audit pass where an authenticated POST to
/api/export/export/lora with save_directory="/tmp/lora_escape_test"
returned 200 and wrote adapter_model.safetensors, adapter_config.json,
and tokenizer files under /tmp.
The fix is two-layered:
storage_roots.py adds an _assert_contained(resolved, root) helper that
runs after path resolution and rejects any result whose realpath does
not sit under realpath(root). resolve_under_root now rejects '..'
segments and null bytes outright, and only accepts absolute inputs when
they are already inside the configured root (internal call sites that
re-resolve a stored absolute path stay idempotent;
worker.py:resolve_output_dir(output_dir) etc. continue to work).
resolve_dataset_path picks up the same containment rule, scoped to the
three dataset roots.
models/export.py adds field_validator("save_directory", mode="before")
to ExportCommonOptions and ExportGGUFRequest so bad input fails fast at
422 with a clear message rather than a 500 deep inside the resolver.
The validator rejects empty/whitespace, null bytes, control chars,
strings longer than 255 chars, absolute paths, and '..' segments.
routes/export.py:_export_details now returns os.path.relpath(output_path,
exports_root()) so the Export Complete dialog and /api/models/loras no
longer leak the absolute install prefix to the UI; the basename is
used as a last-resort fallback.
Verified end to end:
- POST /api/export/export/lora {"save_directory":"/tmp/foo"} -> 422
"save_directory must be a name or relative path under the export
root; absolute paths are rejected". /tmp/foo is not created.
- "../../etc/escape" -> 422 "may not contain '..' segments".
- save_directory="my_subdir" -> still accepted (400 only because the
test had no checkpoint loaded yet, not because of validation).
- Internal idempotent re-resolve via resolve_export_dir(absolute path
that is already under exports_root) returns the same path unchanged.
* studio/sandbox: harden bash + python tool execution
The sandboxed Bash and Python tool channels in Chat ran with a thin
preexec hook (PR_SET_NO_NEW_PRIVS + RLIMIT_FSIZE only). Bash had a
small word blocklist; Python had an AST safety pass aimed at
signal-tampering and shell-escape primitives. An audit pass showed
several gaps that a tool-calling model could trigger inadvertently:
- bash curl/wget/nc reached AWS IMDSv2 and returned live STS
credentials for the instance role.
- python "import socket; s.connect((169.254.169.254, 80))"
reached the same endpoint regardless of the bash blocklist.
- "cat /etc/passwd" was blocked at the bash side (because "passwd"
is in the blocklist), but "open('/etc/passwd').read()" in Python
happily returned its contents.
- "chr(115)+chr(117)+chr(100)+chr(111)" style dynamic-arg
construction slipped through the AST shell-escape check.
- The supervisor used proc.kill() on timeout, which only signals
the immediate pid; bash-backgrounded children survived. A fork
bomb could spawn for the full 300s timeout window.
- Session work directories under ~/studio_sandbox/<id>/ were
created with default umask (0o755), so any other UID on the host
could enumerate them.
- session_id sanitisation used a one-shot str.replace("..",""),
which is non-iterative and a small footgun.
This commit takes a conservative middle path: the sandbox still
runs as the Studio UID with no namespace tricks where the kernel
disallows them, but every chokepoint is tightened.
_sandbox_preexec now:
- calls os.setsid() so children share a process group; the
supervisor uses os.killpg(SIGKILL) on timeout/cancel so
backgrounded children die with the parent (new _kill_process_tree
helper, wired into _cancel_watcher and both _bash_exec /
_python_exec timeout branches).
- calls os.umask(0o077) so files the child writes default to 0o600.
- applies PR_SET_PDEATHSIG=SIGKILL so an orphaned child dies if
Studio exits.
- best-effort unshare(CLONE_NEWNET) for a private network namespace
(failure is logged and swallowed; defense-in-depth is still in
place via the bash blocklist and the AST checker below).
- sets RLIMIT_NPROC=10000 (tunable via UNSLOTH_STUDIO_SANDBOX_NPROC),
RLIMIT_AS=8GB, RLIMIT_CPU=300, RLIMIT_NOFILE=1024. The 10k NPROC
figure is chosen to sit well above the ~500 LWPs a healthy Studio
+ llama-server combination already uses while still capping a
runaway fork bomb. NPROC counts LWPs per real UID, so a lower
figure (e.g. 256) starves legitimate bash forks
("bash: fork: retry: Resource temporarily unavailable").
_get_workdir:
- rejects session_id that doesn't match [A-Za-z0-9_-]{1,64};
non-matching values bucket into a shared "_invalid" dir.
- chmod 0o700 on both the workdir and on ~/studio_sandbox/ so
other UIDs cannot read another session's contents.
_BLOCKED_COMMANDS_COMMON gains: doas, pkexec, halt, poweroff, curl,
wget, nc, ncat, netcat, socat, ssh, scp, sftp, rsync, eval, source.
The intent is to keep general bash usage working (echo, ls, pipes,
loops, for, head, etc.) while denying the obvious egress and
escalation paths.
The AST checker (_check_signal_escape_patterns) is split into the
existing shell/signal/loop checks plus a new narrow IO denylist:
- Always flag non-literal args to anything in _SHELL_EXEC_FUNCS,
not just _STRING_SHELL_FUNCS. Closes the dynamic-arg bypass.
- Reject calls to socket.create_connection, socket.socket().connect,
urllib.request.urlopen, http.client.HTTP*Connection, requests.*,
httpx.* whose literal host argument is in a cloud-metadata
denylist (169.254.169.254 + 169.254.* + 100.64.*, plus the
GCP/Alibaba/ECS metadata hostnames and IPv6 link-local). Public
hosts (example.com, huggingface.co, ...) still work. Dynamic
hosts cannot be statically blocked; mitigated by the bash
blocklist + the netns where the kernel allows it.
- Reject literal open("/etc/passwd"), /etc/shadow, /etc/sudoers,
/etc/ssh/*, and /proc/<pid>/environ. Other files
(/etc/os-release, /etc/hostname, /tmp/*, user dirs) still work.
The _check_code_safety summariser is updated to include the new
network_calls and sensitive_file_reads buckets in its error string.
Regression-checked: echo, sleep, ls /tmp, for loops, piped helpers
(echo a | tr a A), urllib.request.urlopen("http://example.com"),
socket.getaddrinfo("example.com",80), open("/etc/os-release"),
open("/tmp/...","w") all still succeed. curl, wget, nc, ssh, rm,
socket.create_connection(("169.254.169.254",80)),
open("/etc/passwd"), open("/proc/self/environ") all correctly
blocked.
* studio: rate-limit login, rotate refresh tokens, add logout, security headers, gate bootstrap injection
A pass over the auth surface found a cluster of related issues that this
commit closes together.
Login (routes/auth.py):
- Add an in-memory per-IP login rate limiter. Five failed POSTs to
/api/auth/login inside a 60s window produce 429 with Retry-After.
A successful login clears the bucket. Previously 30 wrong passwords
in under one second was accepted as 30x 401, which combined with
the (now fixed) admin-username leak from /api/auth/status made
brute-force trivial against a small password.
Logout (routes/auth.py):
- New POST /api/auth/logout returns 204 and calls
storage.revoke_user_refresh_tokens(subject) so the refresh token
is no longer valid. Previously POST /api/auth/logout returned 405
and there was no way to invalidate refresh tokens short of
changing the password. Frontend session.ts already calls
clearAuthTokens() to drop localStorage; the new endpoint lets the
client also tell the server to revoke server-side state.
Refresh-token rotation (routes/auth.py + auth/storage.py):
- New storage.consume_refresh_token(token) atomically validates +
deletes a refresh token, returning (username, is_desktop). The
/api/auth/refresh handler now mints both a new access AND a new
refresh token; the supplied token becomes invalid. Replaying a
consumed refresh returns 401 "Invalid or expired refresh token".
The previous refresh_access_token helper is left in place for
callers that intentionally want the non-rotating shape; nothing
in the route layer uses it now.
/api/auth/status no longer leaks default_username (models/auth.py +
routes/auth.py):
- AuthStatusResponse.default_username becomes Optional[str] with a
None default; the handler always returns None. The frontend already
hardcodes HIDDEN_LOGIN_USERNAME = "unsloth" (auth-form.tsx:82), so
no UI change is required.
window.__UNSLOTH_BOOTSTRAP__ no longer auto-injects (main.py):
- _inject_bootstrap is now opt-in via the
UNSLOTH_STUDIO_INJECT_BOOTSTRAP env var. The previous default
(inject whenever requires_password_change is true) embedded the
plaintext bootstrap password into the first-boot HTML for any
caller that hit /, /change-password, or any unknown SPA path.
Browser extensions and any XSS payload on the page could read it
trivially. With the new gate the bootstrap password lives only in
the auth/.bootstrap_password file (mode 0o600) where it has always
been; users typing it into a current-password field is the right
UX. routes/auth.py:change_password also clears
app.state.bootstrap_password defensively.
Security headers + server fingerprint (main.py + run.py):
- New SecurityHeadersMiddleware adds Content-Security-Policy,
X-Frame-Options: DENY, X-Content-Type-Options: nosniff,
Referrer-Policy: no-referrer,
Permissions-Policy: camera=(), microphone=(), geolocation=(),
interest-cohort=(), and stamps server: unsloth-studio so the
generic uvicorn banner no longer fingerprints the stack. The
uvicorn.Config gains server_header=False so it stops emitting its
own Server header.
/api/health minimisation (main.py):
- Unauthenticated GET /api/health returns just
{"status":"healthy","timestamp":...} so load-balancer liveness
probes keep working without leaking version, device_type,
chat_only, desktop_protocol_version, or studio_root_id to
arbitrary callers. A request that presents a valid Bearer token
still gets the full diagnostic payload so internal launchers and
sibling-Studio detection (which compares studio_root_id) keep
working.
Verification:
- 30 wrong-password POSTs to /api/auth/login -> first 5 = 401, 6th
through 30th = 429.
- POST /api/auth/logout with a fresh token -> 204. The matching
refresh token then fails 401.
- Login -> R1; /api/auth/refresh with R1 -> new access + R2 (R2 !=
R1); /api/auth/refresh with R1 again -> 401; /api/auth/refresh
with R2 -> still succeeds once and rotates again.
- curl /api/auth/status -> default_username: null.
- curl http://127.0.0.1/ does not contain __UNSLOTH_BOOTSTRAP__.
- curl -I / shows CSP, X-Frame-Options: DENY,
X-Content-Type-Options: nosniff, Referrer-Policy: no-referrer,
Permissions-Policy, and server: unsloth-studio.
- curl /api/health unauthenticated -> {status, timestamp} only.
curl with Authorization: Bearer <valid> -> full payload.
- Existing /api/system, /api/models/list, /api/train/status,
/api/inference/status, /api/auth/api-keys, login flow, SPA root
all still return 200 after the changes (regression smoke).
* studio: add SecurityHeadersMiddleware, MaxBodyMiddleware, /recipes redirect, gate _inject_bootstrap, minimise /api/health
This commit lands the main.py-side changes that share a single
middleware-registration spot. They are kept together because every
change here is either (a) a top-level middleware definition that has
to be added next to LoggingMiddleware, or (b) a route handler at the
same file-level.
SecurityHeadersMiddleware (Content-Security-Policy, X-Frame-Options:
DENY, X-Content-Type-Options: nosniff, Referrer-Policy: no-referrer,
Permissions-Policy, server: unsloth-studio). The previous responses
emitted no CSP, no XFO, no Referrer-Policy and were stamped
server: uvicorn.
MaxBodyMiddleware rejects POST/PUT/PATCH on the inference / dataset /
data-recipe / train / export prefixes when Content-Length exceeds
UNSLOTH_STUDIO_MAX_BODY_MB (default 100). The audit hit this by
attaching a 50 MB plain-text file to a chat message and watching
Studio base64-encode it into the JSON body; uvicorn has no enforced
cap so the only previous guard was the per-file 50 MB ceiling that
data-recipe upload routes already enforce. The new middleware extends
that ceiling to the OpenAI-compat path that the Chat attachments
flow through. Verified: a 200 MB JSON POST to /v1/chat/completions
returns HTTP 413 "Request body too large (209,715,264 bytes; max
104,857,600)". A small valid request continues to reach the handler.
_inject_bootstrap is gated behind UNSLOTH_STUDIO_INJECT_BOOTSTRAP.
The previous default was to inline window.__UNSLOTH_BOOTSTRAP__ =
{username, password} into the first-boot HTML whenever
requires_password_change was true, which exposed the plaintext
bootstrap password to any browser extension, page script, or LAN
caller on -H 0.0.0.0. The bootstrap password remains in the on-disk
.bootstrap_password file (mode 0o600) where it has always lived;
users typing it into a current-password field is the right UX.
/api/health unauthenticated returns {"status":"healthy","timestamp":
...} only; the previous payload (version, device_type, chat_only,
desktop_protocol_version, supports_desktop_auth, studio_root_id,
native_path_leases_supported) is preserved for callers that present
a valid Bearer token, so internal launchers and sibling-Studio
detection (which compares studio_root_id) keep working.
/recipes -> /data-recipes 308 redirect. The Data Recipes page lives
at /data-recipes; users typing /recipes hit the SPA catch-all and
saw "Not Found". The redirect also preserves any tail path, so
/recipes/<rest> -> /data-recipes/<rest>.
Verified end to end with curl: CSP / XFO / X-Content-Type-Options /
Referrer-Policy / Permissions-Policy all present on /, server header
is now unsloth-studio (uvicorn's own banner is suppressed via
server_header=False in run.py from the auth-batch commit). Followed
the /recipes redirect lands on the SPA HTML.
* studio: bound TrainingStartRequest hyperparameters at the schema level
POST /api/train/start accepted any value for learning_rate, batch_size,
max_steps, max_seq_length, warmup_steps, warmup_ratio, num_epochs,
save_steps, weight_decay, gradient_accumulation_steps, lora_r,
lora_alpha and lora_dropout, including -1, 0, 1e9, and non-numeric
strings like 'abc' or 'two' (which silently coerce to 0 in the
trainer). Probing showed the API returning 200 to learning_rate=-1
and batch_size=0; only max_steps had any partial clamping.
This commit adds field_validator on every numeric hyperparameter.
Bounds are chosen wide enough to span realistic single-host
configurations (B200 with 180 GB of memory comfortably fits the
upper end) while rejecting the values that always produce broken
training:
- learning_rate: parses str/float, requires 0 < lr < 1.0. Non-numeric
input raises with "learning_rate must be parseable as float (got
'abc')" instead of silently coercing to 0.
- batch_size: [1, 1024].
- gradient_accumulation_steps: [1, 4096].
- num_epochs: [1, 1000].
- max_steps: [1, 1_000_000].
- max_seq_length: [1, 131072].
- warmup_steps: [0, max_steps].
- warmup_ratio: [0.0, 1.0].
- save_steps: [0, 1_000_000].
- weight_decay: [0, 10] (typical 0..0.1).
- lora_r: [1, 512].
- lora_alpha: [1, 1024].
- lora_dropout: [0.0, 1.0).
Each validator names the offending field in its ValueError message
so the 422 response body identifies which input is bad. The
learning_rate validator returns its result as str (the schema field
type is str("2e-4") for backwards compatibility) so existing call
sites that float() the value continue to work.
Verified:
- learning_rate=-1 -> 422 "learning_rate must be > 0 (got -1.0);
typical range is 1e-6 .. 1e-3".
- learning_rate='abc' -> 422 "must be parseable as float".
- batch_size=-1 / 0 / 999999 -> 422 "batch_size must be in [1, 1024]".
- batch_size='two' -> 422 (pydantic int parser).
- max_steps=0 / -5 -> 422 "must be a positive int".
- max_seq_length=200000 -> 422 "must be in [1, 131072]".
- warmup_ratio=2.5 -> 422 "must be in [0.0, 1.0]".
- lora_dropout=1.5 -> 422 "must be in [0.0, 1.0)".
- Valid request with learning_rate='2e-4', batch_size=1, max_steps=5
passes validation and the training run starts as normal.
* studio: redact image-decode errors, clean checkpoint dirs on cancel, tolerate Stop-button + tool-result message shapes
Three small fixes that fall under "do not let the audit findings
become user-visible papercuts".
routes/inference.py - image-decode error redaction (the audit hit
this with a 0-byte / malformed / wrong-extension image upload). The
three image-normalise sites previously raised HTTPException(400,
detail=f"Failed to process image: {e}"). When PIL raised
UnidentifiedImageError(io.BytesIO(raw)) the message string included
"<_io.BytesIO object at 0x7e40a5d7bf60>", leaking both the Python
class name (confirming the PIL/io stack) and a heap address (mildly
useful for ASLR-bypass chaining if another memory-corruption bug is
ever found). Each site now catches UnidentifiedImageError and
returns the generic "Unsupported or corrupt image format"; the
fall-through generic except returns "Failed to process image". No
exception-repr is interpolated into a response body anywhere along
these paths.
core/training/training.py - checkpoint cleanup on cancel. When a
user clicks Cancel Training, the trainer flips _cancel_requested=True
and the supervisor force-terminates the subprocess. The trainer
writes checkpoint-<step> directories under output_dir every
save_steps; previously these survived the cancel and accumulated on
disk (the audit recorded ~67 MB stuck after a 200-step cancel with
save_steps=20). New helper _cleanup_cancelled_checkpoints(output_dir)
globs checkpoint-<int> entries and removes them. It is gated by a
realpath containment check against outputs_root() so it cannot
accidentally rmtree anything outside the configured outputs root.
force_terminate() invokes the helper after the subprocess join when
_cancel_requested is true. Stop-and-Save runs are unaffected because
that path keeps _cancel_requested=False.
models/inference.py - chat message shape tolerance. Two related
frontend interactions used to crash the request validator:
- After the Stop button truncates a generation, the frontend
retained {role:"assistant", content:""} in the conversation
history and replayed it on the next send. ChatMessage previously
required role="assistant" to have non-empty content or tool_calls,
so the next message returned 422 and the thread was permanently
broken. The validator now normalises empty assistant content to
None so the request round-trips and the trailing empty turn can
be ignored downstream.
- The frontend's second-round tool POST drops the streamed
tool_call_id, hitting the strict-spec check "role=tool requires
tool_call_id". The validator now synthesises an opaque id
(call_<8 hex>) when missing, so the request reaches the handler
and the model's final summarising response gets generated. The
proper fix lives in the frontend (carry the streamed id through
the second POST) and will follow.
Verified end to end with curl: HTTP 400 (model not loaded) on both
the empty-assistant history shape and the tool-result-without-id
shape, instead of HTTP 422 from the schema validator.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* studio: tighten code comments from security-hardening pass
Trim verbose docstrings and inline finding references added in the
previous commits in this branch. Functionality unchanged.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* studio: await get_current_subject in /api/health and make refresh-token consumption atomic
The /api/health auth probe called get_current_subject(creds) without
awaiting it. The coroutine object is truthy, so any caller presenting a
Bearer header (valid or not) received the full diagnostic payload
including version, device_type, studio_root_id, etc. Await the coroutine
and treat HTTPException as 'fall back to the minimal liveness payload'.
consume_refresh_token did SELECT then DELETE WHERE id under default
autocommit isolation. Two concurrent POST /api/auth/refresh requests
could both win the SELECT before either DELETE ran, defeating
single-use refresh-token rotation. Replace with a single
DELETE ... WHERE token_hash = ? AND expires_at >= ? RETURNING ...
statement so the validate-and-delete lands as one atomic op under
SQLite's write lock (3.45.1 supports RETURNING; min was 3.35).
* studio: enforce body cap on chunked uploads and drop unsafe-inline from script-src
MaxBodyMiddleware previously only inspected the declared Content-Length
header; clients omitting it or sending Transfer-Encoding: chunked
bypassed the cap and could still drive an OOM via the downstream
JSON / file readers on /v1/chat/completions, /api/inference, /api/data-recipe,
/api/datasets, /api/train, /api/export. Rewrite as a raw ASGI middleware
that drains and counts http.request frames, replies 413 once the running
total exceeds UNSLOTH_STUDIO_MAX_BODY_MB before invoking the FastAPI
handler, and replays the buffered body to downstream so route code that
calls request.json() / await request.body() works unchanged.
CSP previously included 'unsafe-inline' on script-src, which defeats the
main XSS protection. The frontend bundle does not need inline scripts;
the only inline <script> the backend ever emits is _inject_bootstrap,
which is opt-in via UNSLOTH_STUDIO_INJECT_BOOTSTRAP. Drop 'unsafe-inline'
from script-src by default; when _inject_bootstrap fires, generate a
per-response nonce, embed it on the inlined <script>, and have
SecurityHeadersMiddleware splice 'nonce-XXX' into the CSP for that one
response (the internal x-internal-script-nonce header is popped before
the response leaves the server). 'unsafe-inline' stays on style-src for
Vite-injected styles.
* studio: drop empty assistant sentinel before passthrough
ChatMessage._validate_role_shape normalises role="assistant", content=""
(the post-Stop sentinel emitted by the frontend) to content=None so the
in-process path can drop it via _extract_content_parts. The passthrough
path then ran m.model_dump(exclude_none=True), which strips the now-None
content key entirely, sending {"role":"assistant"} to llama-server / the
OpenAI-compat backend. That fails upstream and leaves the user without a
recoverable Stop->resume.
Add _drop_empty_assistant_sentinels and call it at both passthrough
message origins: _openai_messages_for_passthrough (covers
/v1/chat/completions and the Responses API which routes through it) and
the anthropic_messages_to_openai output before
_anthropic_passthrough_*. Assistant messages that carry only tool_calls
(no content) are preserved.
* studio/tests: cover audit-fix surfaces and rebase pre-existing tests
Adds and updates pytest coverage for the four bot-flagged audit fixes
landed earlier in this branch and rebases two pre-existing tests that
were broken by the relaxed-validator and /api/health auth-gate changes.
studio/backend/tests/test_middleware.py (new)
MaxBodyMiddleware: small protected, large declared, unprotected
passthrough, chunked-upload-over-cap rejection (the regression for
the original Content-Length-only gap), and chunked-under-cap replay.
SecurityHeadersMiddleware: script-src no longer carries
'unsafe-inline', style-src still does, default headers
(XFO/XCTO/Referrer-Policy/Permissions-Policy/server), and the
internal x-internal-script-nonce header is consumed by the
middleware and converted to 'nonce-XXX' in the CSP.
/api/health: no auth -> minimal, invalid Bearer -> minimal
(the await regression), valid Bearer -> full diagnostic payload.
studio/backend/tests/test_desktop_auth.py
consume_refresh_token: second-call returns None, expired returns
None, and a 64-thread concurrent pile-up against the same hash
produces exactly one successful consumer (regression for the
SELECT-then-DELETE race).
test_health_response_reports_desktop_capability_fields: rebase
against the new health_check(request) signature by going through
TestClient with a real bearer instead of asyncio.run-ing the
handler directly.
studio/backend/tests/test_openai_tool_passthrough.py
Pin the new ChatMessage tolerance: assistant without content or
tool_calls is tolerated (normalises content -> None), empty-string
and empty-list assistant content normalise to None, and a missing
/ empty tool_call_id on role='tool' is synthesised as call_<hex>
rather than raising. Tests for _drop_empty_assistant_sentinels
cover the three drop shapes (empty string, empty list, missing
content key), preservation of assistant text and tool_calls-only
messages, and end-to-end through
_openai_messages_for_passthrough.
studio/backend/main.py
SecurityHeadersMiddleware.dispatch used response.headers.pop(...)
for the nonce-header handoff; Starlette's MutableHeaders has no
pop. Read-then-del so the internal handoff header is still
stripped before the response leaves the server.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* studio/tests: rebase three more pre-existing CI tests against this branch
CI on PR #5375 was red on three tests that were tuned for behaviour
predating this branch. Updates each so the assertions match what the
audit fixes intentionally changed; no production code touched.
studio/backend/tests/test_trained_model_scan.py
test_scan_trained_models_includes_lora_and_full_finetune_outputs
passed an absolute tmp_path through scan_trained_models, which now
runs resolve_output_dir / _assert_contained against outputs_root().
Repoint outputs_root() at tmp_path via monkeypatch so the fixture
dirs land under the configured root and the realpath containment
check passes.
tests/test_studio_install_workspace_guard.py
test_health_endpoint_exposes_studio_root_id_not_raw_path read
the first 1500 bytes after @app.get("/api/health") and asserted on
the studio_root_id literal. The handler grew (unauth short-circuit
+ await dependency gate) and the literal slid past the byte window.
Replace the fixed window with a slice up to the next top-level
@app.* decorator so the test surveys the whole handler regardless
of size.
tests/studio/studio_api_smoke.py
The "login burst (5x wrong pw) -> 401 each" assertion was tagged
"When/if we add one, this assertion updates in the same PR." We
added the per-IP rate-limit in routes/auth.py
(_LOGIN_MAX_FAILS=5/60s) but missed the assertion update. Rewrite
the burst probe to observe the new invariant: at least one 401,
eventual transition to 429, and Retry-After present on the 429.
Adds a small _login_with_headers helper since the existing login()
helper drops response headers.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* ci(studio-ui): set UNSLOTH_STUDIO_INJECT_BOOTSTRAP=1 for Playwright Studios
The Chat UI Playwright test drives the first-boot change-password
form, which (per playwright_chat_ui.py step "1. Change-password
through the UI") pre-seeds the hidden current_password field from
window.__UNSLOTH_BOOTSTRAP__. That global is only emitted when the
backend's _inject_bootstrap path fires, which since the security
pass on this branch is gated behind UNSLOTH_STUDIO_INJECT_BOOTSTRAP
and defaults to off. Without the global, the React form's
current_password validator never satisfies, the submit button stays
disabled, and the composer.wait_for() probe times out on
/change-password.
Re-enable injection only for the CI Studios that drive the chat UI
across linux/mac/windows. Production deployments are unaffected: the
env var has to be explicitly opted into, and the on-disk
auth/.bootstrap_password remains the source of truth for human users
typing the password in by hand.
Covers all eight Studio launch sites: the primary chat-ui boot and
the "extra UI tests" boot for each of the three OSes, plus the
pipeTransport JSON-crash retry relaunches in the macOS workflow that
re-spawn Studio mid-job.
A follow-up frontend PR will add a visible current_password input so
the form satisfies its own validator without needing the bootstrap
auto-fill at all; once that lands this CI knob can come back out.
* studio/sandbox: drop unshare(CLONE_NEWNET); add trusted-host allowlist; block sandbox file uploads; raise CPU rlimit default to 600 s
CLONE_NEWNET inside _sandbox_preexec silently killed every outbound
HTTP request from sandboxed Python whenever the kernel allowed
unprivileged user namespaces. requests.get('https://huggingface.co'),
urllib.request.urlopen('https://en.wikipedia.org/wiki/...'),
socket.connect(('arxiv.org', 443)) all failed despite the AST visitor
intending to allow them. The bash blocklist (curl / wget / nc / ssh /
scp / sftp / rsync / socat / eval / source) plus the AST-level
metadata-host denylist still carry the network policy after this
change; CLONE_NEWNET was redundant with both.
Add _TRUSTED_PUBLIC_HOST_LITERALS + _TRUSTED_PUBLIC_HOST_SUFFIXES
(~100 informational hosts: Wikipedia language subdomains, Wikimedia,
Wikidata, Google search, Bing, DuckDuckGo, HuggingFace, GitHub,
raw.githubusercontent.com, arXiv, StackOverflow / Stack Exchange,
MDN, docs.python.org, PyTorch / TensorFlow / NumPy / pandas docs,
pypi / files.pythonhosted.org / npmjs / crates.io, ReadTheDocs,
arXiv, Britannica, BBC / Reuters / Nature / Science, NASA / CDC /
NIH / WHO open data, api.weather.gov). The visitor now blocks
literal hosts that are neither metadata nor trusted with a short
LLM-readable string so the model can retry with an allowed source
instead of choking on a multi-line error.
Block upload-shape calls regardless of host: requests.post / put /
patch / delete / request with files= or data=open(...) /
data=bytes_literal; httpx equivalents; urllib.request.urlopen /
Request with data=...; HuggingFace upload_file / upload_folder /
upload_large_folder / create_commit (module-level FQ paths AND
method-name match on any receiver). Message: "Blocked: file upload
disallowed in sandbox".
Bump UNSLOTH_STUDIO_SANDBOX_CPU_S default 300 -> 600 s so long
agentic chains that span multiple tool calls don't get SIGXCPU'd
mid-stride. Env-var override path is unchanged.
Host normalisation now strips trailing dot, userinfo @, and explicit
port before allowlist / denylist comparison so trailing-DNS-dot,
userinfo-smuggling, and explicit-:443 URLs are decided correctly.
* studio: raise default request-body cap from 100 MB to 500 MB
UNSLOTH_STUDIO_MAX_BODY_MB default goes 100 -> 500 to comfortably
cover vision + audio + multi-recipe-batch JSON payloads. The
MaxBodyMiddleware stream-counting logic from this branch's earlier
06ec088 already handles chunked bodies up to the new cap; env-var
override path is unchanged for callers that want a tighter limit.
* studio/auth: restore /api/auth/status.default_username to 'unsloth'
This branch's earlier b39e9a4 changed default_username to None on the
public /api/auth/status endpoint so the username field didn't leak to
unauthenticated callers. In practice this regressed third-party
clients (and the in-tree React login form's pre-fill UX) without
adding meaningful security: the bootstrap password is the actual
secret, and the username 'unsloth' is the documented default.
Pin default_username to storage.DEFAULT_ADMIN_USERNAME ('unsloth')
and tighten the response model so the field is required rather than
Optional. Anyone who needs anonymisation can still reach for an
allow-list deployment with auth disabled.
* studio/training: raise max_seq_length / batch_size / lora_r / lora_alpha caps
This branch's 7102815 introduced field validators with conservative
caps. The follow-up loosens them so long-context experiments and
high-rank LoRA exploration aren't gated at the schema layer:
_MAX_BATCH_SIZE 1024 -> 4096
_MAX_SEQ_LENGTH 131_072 -> 2_000_000 (2M tokens)
lora_r cap 512 -> 16_384 (_MAX_LORA_R)
lora_alpha cap 1024 -> 32_768 (_MAX_LORA_ALPHA)
_MAX_GRAD_ACCUM / _MAX_STEPS / _MAX_EPOCHS / lora_dropout /
warmup_ratio / weight_decay are unchanged. Hardware (VRAM, host
RAM, kernel launch latency) is now the binding constraint at the
new caps, which is the correct ordering -- the validator stays a
sanity check on -1 / 0 / 'abc' style garbage, not a usability gate.
* studio/tests: cover sandbox allowlist + upload block + raised training caps
studio/backend/tests/test_sandbox_tools.py (new):
TestMetadataHostDenylist -- short "Blocked: cloud-metadata host"
message on AWS IMDS, GCP metadata,
Alibaba ECS, AWS IPv6 IMDS, 169.254/16.
TestTrustedHostAllowlist -- Wikipedia (any language subdomain),
Google, DuckDuckGo, HF, raw GitHub,
arXiv, StackOverflow / family,
MDN, docs.python.org, pypi, BBC,
api.weather.gov, NumPy / PyTorch docs.
TestUntrustedHostBlock -- example.com / random unlisted host
rejected with the short "Blocked: host
not in sandbox allowlist; use an
allowed informational source" message.
Dynamic URLs (computed var) still pass
-- documented limit of static analysis.
TestHostNormalization -- trailing dot, explicit :443, uppercase,
userinfo-@-smuggle all decided
correctly without false-block /
false-pass.
TestUploadDenylist -- requests / httpx / urllib.urlopen with
files= / data=open / data=bytes,
HfApi().upload_file / upload_folder /
create_commit, module-level
huggingface_hub.upload_folder. POST
json= to trusted host still passes.
TestSandboxCpuRlimitDefault -- pin UNSLOTH_STUDIO_SANDBOX_CPU_S=600
default and confirm CLONE_NEWNET
source line is gone.
TestMaxBodyDefault -- pin UNSLOTH_STUDIO_MAX_BODY_MB=500
default.
studio/backend/tests/test_studio_train_validation.py (new):
Pin at-cap-accepts / over-cap-rejects boundaries for
max_seq_length=2_000_000, batch_size=4_096, lora_r=16_384,
lora_alpha=32_768 so a future regression that tightens them back
without explicit user opt-in is caught.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* studio: tighten code comments across the security-hardening pass
* studio: always inject bootstrap credentials on first boot
The UNSLOTH_STUDIO_INJECT_BOOTSTRAP gate added an extra
terminal-to-browser copy-paste on every fresh install. In practice
the LAN credential leak it guarded against is narrow: the password
is one-time, the user rotates it on the very next click, the
default Studio bind is 127.0.0.1, and -H 0.0.0.0 already exposes
the entire API surface. Drop the gate so the inject fires whenever
a bootstrap password is still pending. The CSP nonce wiring stays
in place; the inline script remains the only inline script the
backend ever emits.
The three Playwright UI smoke workflows lose their
UNSLOTH_STUDIO_INJECT_BOOTSTRAP=1 lines along with the explanatory
comment blocks since the inject now happens by default.
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wasim Yousef Said <wasimysdev@gmail.com>
This commit is contained in:
parent
ef9f672fe8
commit
0881a7a5d7
21 changed files with 2121 additions and 166 deletions
|
|
@ -480,6 +480,37 @@ def save_refresh_token(
|
|||
conn.close()
|
||||
|
||||
|
||||
def consume_refresh_token(token: str) -> Optional[Tuple[str, bool]]:
|
||||
"""Atomically validate-and-delete a refresh token for single-use rotation.
|
||||
|
||||
DELETE RETURNING fuses validate and delete into one statement so two
|
||||
concurrent refresh requests cannot both consume the same token.
|
||||
"""
|
||||
token_hash = _hash_token(token)
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn = get_connection()
|
||||
try:
|
||||
conn.execute(
|
||||
"DELETE FROM refresh_tokens WHERE expires_at < ?",
|
||||
(now,),
|
||||
)
|
||||
cur = conn.execute(
|
||||
"""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE token_hash = ? AND expires_at >= ?
|
||||
RETURNING username, is_desktop
|
||||
""",
|
||||
(token_hash, now),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
if row is None:
|
||||
return None
|
||||
return row["username"], bool(row["is_desktop"])
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def verify_refresh_token(token: str) -> Optional[Tuple[str, bool]]:
|
||||
"""
|
||||
Verify a refresh token and return the username plus desktop marker.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ Supports web search (DuckDuckGo), Python code execution, and terminal commands.
|
|||
import ast
|
||||
import http.client
|
||||
import os
|
||||
import signal
|
||||
|
||||
os.environ["UNSLOTH_IS_PRESENT"] = "1"
|
||||
|
||||
|
|
@ -58,21 +59,37 @@ _MAX_OUTPUT_CHARS = 8000 # truncate long output
|
|||
_BLOCKED_COMMANDS_COMMON = frozenset(
|
||||
{
|
||||
"rm",
|
||||
"sudo",
|
||||
"su",
|
||||
"dd",
|
||||
"chmod",
|
||||
"chown",
|
||||
"mkfs",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"passwd",
|
||||
"mount",
|
||||
"umount",
|
||||
"fdisk",
|
||||
"sudo",
|
||||
"su",
|
||||
"doas",
|
||||
"pkexec",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"halt",
|
||||
"poweroff",
|
||||
"kill",
|
||||
"killall",
|
||||
"pkill",
|
||||
"passwd",
|
||||
"curl",
|
||||
"wget",
|
||||
"nc",
|
||||
"ncat",
|
||||
"netcat",
|
||||
"socat",
|
||||
"ssh",
|
||||
"scp",
|
||||
"sftp",
|
||||
"rsync",
|
||||
"eval",
|
||||
"source",
|
||||
}
|
||||
)
|
||||
_BLOCKED_COMMANDS_WIN = frozenset(
|
||||
|
|
@ -221,35 +238,67 @@ def _build_safe_env(workdir: str) -> dict[str, str]:
|
|||
|
||||
|
||||
def _sandbox_preexec():
|
||||
"""Pre-exec hook: drop privilege escalation ability and set resource limits.
|
||||
"""Best-effort sandbox setup for sandboxed subprocesses.
|
||||
|
||||
On Linux, applies PR_SET_NO_NEW_PRIVS so sudo/su/pkexec fail at the
|
||||
kernel level. On Linux and macOS, sets RLIMIT_FSIZE.
|
||||
No-op on Windows (use creationflags instead).
|
||||
|
||||
Note: RLIMIT_NPROC is intentionally NOT set because Linux enforces it
|
||||
per real UID, not per process tree, so it would starve the Studio
|
||||
server and other sessions sharing the same user account.
|
||||
|
||||
All modules and handles are resolved at import time (module level) so
|
||||
this function does not trigger Python imports in the forked child,
|
||||
avoiding potential deadlocks in multi-threaded servers.
|
||||
Modules are resolved at import time so the forked child runs no imports.
|
||||
"""
|
||||
try:
|
||||
os.setsid()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
try:
|
||||
os.umask(0o077)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if _libc is not None:
|
||||
try:
|
||||
# PR_SET_NO_NEW_PRIVS = 38, arg2 = 1 (enable)
|
||||
_libc.prctl(38, 1, 0, 0, 0)
|
||||
_libc.prctl(38, 1, 0, 0, 0) # PR_SET_NO_NEW_PRIVS
|
||||
except (OSError, AttributeError):
|
||||
pass # Not available (container, old kernel, etc.)
|
||||
pass
|
||||
|
||||
try:
|
||||
_libc.prctl(1, 9, 0, 0, 0) # PR_SET_PDEATHSIG = SIGKILL
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
# CLONE_NEWNET intentionally not applied: where userns is enabled it
|
||||
# blocks all egress, including allowlisted hosts. Network policy is
|
||||
# enforced by the AST host check and the bash blocklist.
|
||||
|
||||
if _resource is not None:
|
||||
# RLIMIT_NPROC is per-real-UID, so the cap is well above normal usage.
|
||||
try:
|
||||
nproc = int(os.environ.get("UNSLOTH_STUDIO_SANDBOX_NPROC", "10000"))
|
||||
_resource.setrlimit(_resource.RLIMIT_NPROC, (nproc, nproc))
|
||||
except (ValueError, OSError, AttributeError):
|
||||
pass
|
||||
try:
|
||||
# Limit file size to 100MB (prevents disk filling)
|
||||
_resource.setrlimit(
|
||||
_resource.RLIMIT_FSIZE, (100 * 1024 * 1024, 100 * 1024 * 1024)
|
||||
)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
try:
|
||||
as_bytes = (
|
||||
int(os.environ.get("UNSLOTH_STUDIO_SANDBOX_AS_GB", "8"))
|
||||
* 1024
|
||||
* 1024
|
||||
* 1024
|
||||
)
|
||||
_resource.setrlimit(_resource.RLIMIT_AS, (as_bytes, as_bytes))
|
||||
except (ValueError, OSError, AttributeError):
|
||||
pass
|
||||
try:
|
||||
cpu_s = int(os.environ.get("UNSLOTH_STUDIO_SANDBOX_CPU_S", "600"))
|
||||
_resource.setrlimit(_resource.RLIMIT_CPU, (cpu_s, cpu_s))
|
||||
except (ValueError, OSError, AttributeError):
|
||||
pass
|
||||
try:
|
||||
_resource.setrlimit(_resource.RLIMIT_NOFILE, (1024, 1024))
|
||||
except (ValueError, OSError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
def _get_shell_cmd(command: str) -> list[str]:
|
||||
|
|
@ -265,25 +314,36 @@ def _get_shell_cmd(command: str) -> list[str]:
|
|||
_workdirs: dict[str, str] = {}
|
||||
|
||||
|
||||
# Non-matching session_ids collapse to ``_invalid`` to block cross-session escapes.
|
||||
_SESSION_ID_RE = re.compile(r"\A[A-Za-z0-9_\-]{1,64}\Z")
|
||||
|
||||
|
||||
def _get_workdir(session_id: str | None = None) -> str:
|
||||
"""Return (and lazily create) a persistent working directory for tool execution."""
|
||||
"""Return a per-session sandbox dir at mode 0o700."""
|
||||
global _workdirs
|
||||
key = session_id or "_default"
|
||||
if key not in _workdirs or not os.path.isdir(_workdirs[key]):
|
||||
home = os.path.expanduser("~")
|
||||
sandbox_root = os.path.join(home, "studio_sandbox")
|
||||
if session_id:
|
||||
# Sanitize: strip path separators and parent-dir references
|
||||
safe_id = os.path.basename(session_id.replace("..", ""))
|
||||
if not safe_id:
|
||||
safe_id = "_invalid"
|
||||
workdir = os.path.join(sandbox_root, safe_id)
|
||||
# Verify resolved path stays under sandbox root
|
||||
if not os.path.realpath(workdir).startswith(os.path.realpath(sandbox_root)):
|
||||
if session_id and _SESSION_ID_RE.match(session_id):
|
||||
workdir = os.path.join(sandbox_root, session_id)
|
||||
if not os.path.realpath(workdir).startswith(
|
||||
os.path.realpath(sandbox_root) + os.sep
|
||||
):
|
||||
workdir = os.path.join(sandbox_root, "_invalid")
|
||||
elif session_id:
|
||||
workdir = os.path.join(sandbox_root, "_invalid")
|
||||
else:
|
||||
workdir = os.path.join(sandbox_root, "_default")
|
||||
os.makedirs(workdir, exist_ok = True)
|
||||
try:
|
||||
os.chmod(sandbox_root, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
os.chmod(workdir, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
_workdirs[key] = workdir
|
||||
return _workdirs[key]
|
||||
|
||||
|
|
@ -932,7 +992,12 @@ def _check_signal_escape_patterns(code: str):
|
|||
isinstance(shell_node, ast.Constant)
|
||||
and shell_node.value is False
|
||||
)
|
||||
if shell_func in _STRING_SHELL_FUNCS or not shell_safe:
|
||||
# Dynamic shell-exec args (chr/format/concat bypasses).
|
||||
if (
|
||||
shell_func in _STRING_SHELL_FUNCS
|
||||
or shell_func in _SHELL_EXEC_FUNCS
|
||||
or not shell_safe
|
||||
):
|
||||
|
||||
def _is_safe_literal(n):
|
||||
if _extract_string_from_node(n) is not None:
|
||||
|
|
@ -1006,15 +1071,418 @@ def _check_signal_escape_patterns(code: str):
|
|||
if visitor.imports_signal and not signal_tampering:
|
||||
warnings.append("Code imports 'signal' module - review manually for safety")
|
||||
|
||||
# Static host policy: block metadata hosts and any literal host outside
|
||||
# the trusted allowlist; uploads blocked regardless of host. Dynamic hosts
|
||||
# are caught by the bash blocklist instead.
|
||||
network_calls: list[dict] = []
|
||||
sensitive_file_reads: list[dict] = []
|
||||
_NETWORK_FQ_PREFIXES = (
|
||||
"socket.socket",
|
||||
"socket.create_connection",
|
||||
"socket.getaddrinfo",
|
||||
"urllib.request.urlopen",
|
||||
"urllib.request.urlretrieve",
|
||||
"urllib3.",
|
||||
"requests.get",
|
||||
"requests.post",
|
||||
"requests.put",
|
||||
"requests.delete",
|
||||
"requests.patch",
|
||||
"requests.head",
|
||||
"requests.request",
|
||||
"requests.Session",
|
||||
"http.client.HTTPConnection",
|
||||
"http.client.HTTPSConnection",
|
||||
"httpx.get",
|
||||
"httpx.post",
|
||||
"httpx.put",
|
||||
"httpx.patch",
|
||||
"httpx.delete",
|
||||
"httpx.request",
|
||||
"httpx.Client",
|
||||
"httpx.AsyncClient",
|
||||
"aiohttp.ClientSession",
|
||||
)
|
||||
_UPLOAD_HTTP_METHODS = (
|
||||
"requests.post",
|
||||
"requests.put",
|
||||
"requests.patch",
|
||||
"requests.delete",
|
||||
"requests.request",
|
||||
"httpx.post",
|
||||
"httpx.put",
|
||||
"httpx.patch",
|
||||
"httpx.delete",
|
||||
"httpx.request",
|
||||
"urllib.request.urlopen",
|
||||
"urllib.request.Request",
|
||||
)
|
||||
_UPLOAD_HF_FQ = (
|
||||
"huggingface_hub.upload_file",
|
||||
"huggingface_hub.upload_folder",
|
||||
"huggingface_hub.upload_large_folder",
|
||||
"huggingface_hub.create_commit",
|
||||
)
|
||||
_UPLOAD_HF_METHODS = frozenset(
|
||||
{
|
||||
"upload_file",
|
||||
"upload_folder",
|
||||
"upload_large_folder",
|
||||
"create_commit",
|
||||
}
|
||||
)
|
||||
# Cloud-metadata / link-local hosts.
|
||||
_METADATA_HOST_LITERALS = {
|
||||
"169.254.169.254",
|
||||
"fd00:ec2::254",
|
||||
"metadata.google.internal",
|
||||
"metadata",
|
||||
"metadata.tencentyun.com",
|
||||
"100.100.100.200",
|
||||
"100.100.100.110",
|
||||
"169.254.170.2",
|
||||
"169.254.170.23",
|
||||
}
|
||||
_METADATA_HOST_PREFIXES = (
|
||||
"169.254.",
|
||||
"100.64.",
|
||||
)
|
||||
# Allowlist kept explicit so each entry is auditable.
|
||||
_TRUSTED_PUBLIC_HOST_LITERALS = frozenset(
|
||||
{
|
||||
# search
|
||||
"www.google.com",
|
||||
"google.com",
|
||||
"www.bing.com",
|
||||
"bing.com",
|
||||
"duckduckgo.com",
|
||||
"html.duckduckgo.com",
|
||||
# encyclopedic / reference
|
||||
"wikipedia.org",
|
||||
"www.wikipedia.org",
|
||||
"wikimedia.org",
|
||||
"www.wikimedia.org",
|
||||
"wikidata.org",
|
||||
"www.wikidata.org",
|
||||
"commons.wikimedia.org",
|
||||
"www.britannica.com",
|
||||
"openlibrary.org",
|
||||
"www.openstreetmap.org",
|
||||
# ML / dev / data
|
||||
"huggingface.co",
|
||||
"hf.co",
|
||||
"github.com",
|
||||
"api.github.com",
|
||||
"raw.githubusercontent.com",
|
||||
"gist.github.com",
|
||||
"docs.github.com",
|
||||
"pypi.org",
|
||||
"files.pythonhosted.org",
|
||||
"www.npmjs.com",
|
||||
"registry.npmjs.org",
|
||||
"crates.io",
|
||||
"static.crates.io",
|
||||
# docs
|
||||
"docs.python.org",
|
||||
"python.org",
|
||||
"www.python.org",
|
||||
"developer.mozilla.org",
|
||||
"developer.apple.com",
|
||||
"learn.microsoft.com",
|
||||
"docs.docker.com",
|
||||
"pytorch.org",
|
||||
"docs.pytorch.org",
|
||||
"tensorflow.org",
|
||||
"www.tensorflow.org",
|
||||
"numpy.org",
|
||||
"pandas.pydata.org",
|
||||
"scipy.org",
|
||||
"scikit-learn.org",
|
||||
"matplotlib.org",
|
||||
"fastapi.tiangolo.com",
|
||||
"starlette.io",
|
||||
# academic
|
||||
"arxiv.org",
|
||||
"export.arxiv.org",
|
||||
"scholar.google.com",
|
||||
"openreview.net",
|
||||
"semanticscholar.org",
|
||||
"www.semanticscholar.org",
|
||||
"biorxiv.org",
|
||||
"www.biorxiv.org",
|
||||
"medrxiv.org",
|
||||
"www.medrxiv.org",
|
||||
"pubmed.ncbi.nlm.nih.gov",
|
||||
"www.ncbi.nlm.nih.gov",
|
||||
# Q&A / community
|
||||
"stackoverflow.com",
|
||||
"stackexchange.com",
|
||||
"askubuntu.com",
|
||||
"superuser.com",
|
||||
"serverfault.com",
|
||||
# standards
|
||||
"www.w3.org",
|
||||
"tools.ietf.org",
|
||||
"datatracker.ietf.org",
|
||||
"www.rfc-editor.org",
|
||||
# reputable news
|
||||
"www.bbc.com",
|
||||
"www.bbc.co.uk",
|
||||
"www.reuters.com",
|
||||
"apnews.com",
|
||||
"www.nature.com",
|
||||
"www.science.org",
|
||||
# government / open data
|
||||
"data.gov",
|
||||
"catalog.data.gov",
|
||||
"www.census.gov",
|
||||
"www.nasa.gov",
|
||||
"data.nasa.gov",
|
||||
"www.cdc.gov",
|
||||
"www.nih.gov",
|
||||
"www.who.int",
|
||||
# weather / time
|
||||
"api.weather.gov",
|
||||
"worldtimeapi.org",
|
||||
}
|
||||
)
|
||||
_TRUSTED_PUBLIC_HOST_SUFFIXES = (
|
||||
".wikipedia.org",
|
||||
".wikimedia.org",
|
||||
".wiktionary.org",
|
||||
".wikibooks.org",
|
||||
".wikiquote.org",
|
||||
".wikisource.org",
|
||||
".wikiversity.org",
|
||||
".wikivoyage.org",
|
||||
".stackexchange.com",
|
||||
".hf.co",
|
||||
".huggingface.co",
|
||||
".githubusercontent.com",
|
||||
".github.io",
|
||||
".arxiv.org",
|
||||
".readthedocs.io",
|
||||
".readthedocs.org",
|
||||
)
|
||||
_SENSITIVE_FILE_PREFIXES = (
|
||||
"/etc/passwd",
|
||||
"/etc/shadow",
|
||||
"/etc/sudoers",
|
||||
"/etc/ssh/",
|
||||
)
|
||||
_SENSITIVE_FILE_RE = re.compile(
|
||||
r"^/proc/(?:self|\d+)/(?:environ|cmdline|task/\d+/environ)$"
|
||||
)
|
||||
|
||||
def _normalize_host(host: str) -> str:
|
||||
if not host:
|
||||
return ""
|
||||
h = host.strip().lower().rstrip(".")
|
||||
if "@" in h:
|
||||
h = h.split("@", 1)[1]
|
||||
if h.startswith("[") and "]" in h:
|
||||
h = h[1 : h.index("]")]
|
||||
elif h.count(":") == 1:
|
||||
h = h.split(":", 1)[0]
|
||||
return h
|
||||
|
||||
def _is_metadata_host(host: str) -> bool:
|
||||
h = _normalize_host(host)
|
||||
if not h:
|
||||
return False
|
||||
if h in _METADATA_HOST_LITERALS:
|
||||
return True
|
||||
if any(h.startswith(p) for p in _METADATA_HOST_PREFIXES):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_trusted_host(host: str) -> bool:
|
||||
h = _normalize_host(host)
|
||||
if not h:
|
||||
return False
|
||||
if h in _TRUSTED_PUBLIC_HOST_LITERALS:
|
||||
return True
|
||||
return any(h.endswith(s) for s in _TRUSTED_PUBLIC_HOST_SUFFIXES)
|
||||
|
||||
def _call_is_upload_shape(node: ast.Call, fq: str) -> bool:
|
||||
"""True for statically obvious upload shapes (files=, data=open(), bytes literal)."""
|
||||
if fq in _UPLOAD_HF_FQ:
|
||||
return True
|
||||
if fq not in _UPLOAD_HTTP_METHODS:
|
||||
return False
|
||||
for kw in node.keywords or []:
|
||||
if kw.arg == "files":
|
||||
return True
|
||||
if kw.arg == "data":
|
||||
v = kw.value
|
||||
if (
|
||||
isinstance(v, ast.Call)
|
||||
and isinstance(v.func, ast.Name)
|
||||
and v.func.id == "open"
|
||||
):
|
||||
return True
|
||||
if isinstance(v, ast.Constant) and isinstance(
|
||||
v.value, (bytes, bytearray)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _method_call_is_hf_upload(node: ast.Call) -> bool:
|
||||
"""True for HfApi upload method names on any receiver."""
|
||||
return (
|
||||
isinstance(node.func, ast.Attribute)
|
||||
and node.func.attr in _UPLOAD_HF_METHODS
|
||||
)
|
||||
|
||||
class NetworkAndIoVisitor(ast.NodeVisitor):
|
||||
def visit_Call(self, node):
|
||||
parts: list[str] = []
|
||||
cur = node.func
|
||||
while isinstance(cur, ast.Attribute):
|
||||
parts.insert(0, cur.attr)
|
||||
cur = cur.value
|
||||
if isinstance(cur, ast.Name):
|
||||
parts.insert(0, cur.id)
|
||||
fq = ".".join(parts) if parts else ""
|
||||
|
||||
if _method_call_is_hf_upload(node):
|
||||
network_calls.append(
|
||||
{
|
||||
"type": "upload_blocked",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": ("Blocked: file upload disallowed in sandbox"),
|
||||
}
|
||||
)
|
||||
|
||||
# Direct sock.connect((host, port)) bypasses the FQ-prefix branch below.
|
||||
if (
|
||||
isinstance(node.func, ast.Attribute)
|
||||
and node.func.attr == "connect"
|
||||
and node.args
|
||||
):
|
||||
a0 = node.args[0]
|
||||
host_lit = None
|
||||
if isinstance(a0, ast.Tuple) and a0.elts:
|
||||
e0 = a0.elts[0]
|
||||
if isinstance(e0, ast.Constant) and isinstance(e0.value, str):
|
||||
host_lit = e0.value
|
||||
elif isinstance(a0, ast.Constant) and isinstance(a0.value, str):
|
||||
host_lit = a0.value
|
||||
if host_lit:
|
||||
if _is_metadata_host(host_lit):
|
||||
network_calls.append(
|
||||
{
|
||||
"type": "metadata_host_blocked",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": "Blocked: cloud-metadata host",
|
||||
}
|
||||
)
|
||||
elif not _is_trusted_host(host_lit):
|
||||
network_calls.append(
|
||||
{
|
||||
"type": "untrusted_host_blocked",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": (
|
||||
"Blocked: host not in sandbox allowlist; "
|
||||
"use an allowed informational source"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if fq and any(fq.startswith(p) for p in _NETWORK_FQ_PREFIXES):
|
||||
# 1) Upload-shape check (host-independent).
|
||||
if _call_is_upload_shape(node, fq):
|
||||
network_calls.append(
|
||||
{
|
||||
"type": "upload_blocked",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": (
|
||||
"Blocked: file upload disallowed in sandbox"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# 2) Extract literal host (URL string or (host, port) tuple).
|
||||
host_arg = None
|
||||
url_arg = None
|
||||
if node.args:
|
||||
a0 = node.args[0]
|
||||
if isinstance(a0, ast.Constant) and isinstance(a0.value, str):
|
||||
url_arg = a0.value
|
||||
elif isinstance(a0, ast.Tuple) and a0.elts:
|
||||
e0 = a0.elts[0]
|
||||
if isinstance(e0, ast.Constant) and isinstance(e0.value, str):
|
||||
host_arg = e0.value
|
||||
if url_arg and host_arg is None:
|
||||
m = re.match(r"^\w+://([^/?#]+)", url_arg)
|
||||
if m:
|
||||
host_arg = m.group(1)
|
||||
|
||||
if host_arg:
|
||||
if _is_metadata_host(host_arg):
|
||||
network_calls.append(
|
||||
{
|
||||
"type": "metadata_host_blocked",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": "Blocked: cloud-metadata host",
|
||||
}
|
||||
)
|
||||
elif not _is_trusted_host(host_arg):
|
||||
network_calls.append(
|
||||
{
|
||||
"type": "untrusted_host_blocked",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": (
|
||||
"Blocked: host not in sandbox allowlist; "
|
||||
"use an allowed informational source"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
is_open_call = (
|
||||
(isinstance(node.func, ast.Name) and node.func.id == "open")
|
||||
or fq in ("io.open", "pathlib.Path.open")
|
||||
or fq.endswith(".open")
|
||||
)
|
||||
if is_open_call and node.args:
|
||||
a0 = node.args[0]
|
||||
path_lit = None
|
||||
if isinstance(a0, ast.Constant) and isinstance(a0.value, str):
|
||||
path_lit = a0.value
|
||||
if path_lit:
|
||||
flagged = False
|
||||
if any(path_lit.startswith(p) for p in _SENSITIVE_FILE_PREFIXES):
|
||||
flagged = True
|
||||
elif _SENSITIVE_FILE_RE.match(path_lit):
|
||||
flagged = True
|
||||
if flagged:
|
||||
sensitive_file_reads.append(
|
||||
{
|
||||
"type": "sensitive_file_read",
|
||||
"line": getattr(node, "lineno", -1),
|
||||
"description": (
|
||||
f"open({path_lit!r}) targets a host identity / "
|
||||
"credential file; sandboxed code may not read it"
|
||||
),
|
||||
}
|
||||
)
|
||||
self.generic_visit(node)
|
||||
|
||||
NetworkAndIoVisitor().visit(tree)
|
||||
|
||||
is_safe = (
|
||||
len(signal_tampering) == 0
|
||||
and len(exception_catching) == 0
|
||||
and len(shell_escapes) == 0
|
||||
and len(network_calls) == 0
|
||||
and len(sensitive_file_reads) == 0
|
||||
)
|
||||
return is_safe, {
|
||||
"signal_tampering": signal_tampering,
|
||||
"exception_catching": exception_catching,
|
||||
"shell_escapes": shell_escapes,
|
||||
"network_calls": network_calls,
|
||||
"sensitive_file_reads": sensitive_file_reads,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
|
|
@ -1041,7 +1509,21 @@ def _check_code_safety(code: str) -> str | None:
|
|||
exception_reasons = [
|
||||
item.get("description", "") for item in info.get("exception_catching", [])
|
||||
]
|
||||
all_reasons = [r for r in reasons + shell_reasons + exception_reasons if r]
|
||||
network_reasons = [
|
||||
item.get("description", "") for item in info.get("network_calls", [])
|
||||
]
|
||||
file_reasons = [
|
||||
item.get("description", "") for item in info.get("sensitive_file_reads", [])
|
||||
]
|
||||
all_reasons = [
|
||||
r
|
||||
for r in reasons
|
||||
+ shell_reasons
|
||||
+ exception_reasons
|
||||
+ network_reasons
|
||||
+ file_reasons
|
||||
if r
|
||||
]
|
||||
if all_reasons:
|
||||
return (
|
||||
f"Error: unsafe code detected ({'; '.join(all_reasons)}). "
|
||||
|
|
@ -1051,11 +1533,31 @@ def _check_code_safety(code: str) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _kill_process_tree(proc) -> None:
|
||||
"""SIGKILL the setsid process group; fall back to single-pid kill."""
|
||||
if proc.poll() is not None:
|
||||
return
|
||||
try:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pgid = None
|
||||
if pgid is not None:
|
||||
try:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
return
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
try:
|
||||
proc.kill()
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
|
||||
|
||||
def _cancel_watcher(proc, cancel_event, poll_interval = 0.2):
|
||||
"""Daemon thread that kills a process when cancel_event is set."""
|
||||
while proc.poll() is None:
|
||||
if cancel_event is not None and cancel_event.is_set():
|
||||
proc.kill()
|
||||
_kill_process_tree(proc)
|
||||
return
|
||||
cancel_event.wait(poll_interval) if cancel_event else None
|
||||
|
||||
|
|
@ -1126,8 +1628,11 @@ def _python_exec(
|
|||
try:
|
||||
output, _ = proc.communicate(timeout = timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.communicate()
|
||||
_kill_process_tree(proc)
|
||||
try:
|
||||
proc.communicate(timeout = 5)
|
||||
except subprocess.TimeoutExpired:
|
||||
pass
|
||||
return _truncate(f"Execution timed out after {timeout} seconds.")
|
||||
|
||||
if cancel_event is not None and cancel_event.is_set():
|
||||
|
|
@ -1211,8 +1716,11 @@ def _bash_exec(
|
|||
try:
|
||||
output, _ = proc.communicate(timeout = timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.communicate()
|
||||
_kill_process_tree(proc)
|
||||
try:
|
||||
proc.communicate(timeout = 5)
|
||||
except subprocess.TimeoutExpired:
|
||||
pass
|
||||
return _truncate(f"Execution timed out after {timeout} seconds.")
|
||||
|
||||
if cancel_event is not None and cancel_event.is_set():
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ Pattern follows core/data_recipe/jobs/manager.py.
|
|||
import json as _json
|
||||
import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
import structlog
|
||||
|
|
@ -33,9 +35,54 @@ from utils.native_path_leases import (
|
|||
native_path_secret_removed_for_child_start,
|
||||
run_without_native_path_secret,
|
||||
)
|
||||
from utils.paths import outputs_root
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _cleanup_cancelled_checkpoints(output_dir: str | os.PathLike) -> None:
|
||||
"""Remove ``checkpoint-<int>`` subdirs after a cancelled run.
|
||||
Only paths whose realpath is under outputs_root are touched."""
|
||||
out = Path(output_dir)
|
||||
if not out.exists():
|
||||
return
|
||||
try:
|
||||
out_real = out.resolve()
|
||||
out_root_real = Path(outputs_root()).resolve()
|
||||
except OSError:
|
||||
return
|
||||
try:
|
||||
out_real.relative_to(out_root_real)
|
||||
except ValueError:
|
||||
# Refuse to delete anything outside the configured outputs root.
|
||||
logger.warning(
|
||||
"Skipping checkpoint cleanup - %s is not under outputs_root %s",
|
||||
out_real,
|
||||
out_root_real,
|
||||
)
|
||||
return
|
||||
removed = 0
|
||||
for entry in out.iterdir() if out.is_dir() else []:
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
name = entry.name
|
||||
if not name.startswith("checkpoint-"):
|
||||
continue
|
||||
tail = name[len("checkpoint-") :]
|
||||
if not tail.isdigit():
|
||||
continue
|
||||
try:
|
||||
shutil.rmtree(entry, ignore_errors = False)
|
||||
removed += 1
|
||||
except OSError as exc:
|
||||
logger.warning("Could not remove %s: %s", entry, exc)
|
||||
logger.info(
|
||||
"Cancelled-run cleanup removed %d checkpoint dir(s) under %s",
|
||||
removed,
|
||||
out,
|
||||
)
|
||||
|
||||
|
||||
_CTX = mp.get_context("spawn")
|
||||
|
||||
# Plot styling constants
|
||||
|
|
@ -316,6 +363,8 @@ class TrainingBackend:
|
|||
)
|
||||
self._proc.terminate()
|
||||
proc = self._proc
|
||||
cancelled = self._cancel_requested
|
||||
output_dir = self._output_dir
|
||||
|
||||
if proc is not None:
|
||||
proc.join(timeout = 5.0)
|
||||
|
|
@ -328,6 +377,17 @@ class TrainingBackend:
|
|||
if self._pump_thread is not None and self._pump_thread.is_alive():
|
||||
self._pump_thread.join(timeout = 8.0)
|
||||
|
||||
# Drop checkpoint-* dirs on explicit cancel only; stop-and-save
|
||||
# keeps its artifacts.
|
||||
if cancelled and output_dir:
|
||||
try:
|
||||
_cleanup_cancelled_checkpoints(output_dir)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean up cancelled-run checkpoints under %s",
|
||||
output_dir,
|
||||
)
|
||||
|
||||
def is_training_active(self) -> bool:
|
||||
"""Check if training is currently active."""
|
||||
with self._lock:
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ if os.getenv("ENVIRONMENT_TYPE", "production") == "production":
|
|||
# warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
# warnings.filterwarnings("ignore", module="triton.*")
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, HTMLResponse, Response
|
||||
|
|
@ -260,6 +260,181 @@ logger = LogConfig.setup_logging(
|
|||
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
|
||||
# Web-search favicons load from *.gstatic.com; everything else is same-origin.
|
||||
from starlette.middleware.base import BaseHTTPMiddleware # noqa: E402
|
||||
from starlette.requests import Request as _StarletteRequest # noqa: E402
|
||||
|
||||
|
||||
_CSP_SCRIPT_NONCE_HEADER = "x-internal-script-nonce"
|
||||
|
||||
|
||||
def _build_csp(script_nonce: "str | None" = None) -> str:
|
||||
script_src = "script-src 'self'"
|
||||
if script_nonce:
|
||||
script_src += f" 'nonce-{script_nonce}'"
|
||||
return (
|
||||
"default-src 'self'; "
|
||||
"img-src 'self' data: blob: https://t0.gstatic.com "
|
||||
"https://t1.gstatic.com https://t2.gstatic.com "
|
||||
"https://t3.gstatic.com; "
|
||||
"connect-src 'self'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
f"{script_src}; "
|
||||
"font-src 'self' data:; "
|
||||
"frame-ancestors 'none'; "
|
||||
"form-action 'self'; "
|
||||
"base-uri 'self'"
|
||||
)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Set baseline security headers; splice per-response inline-script nonces into CSP."""
|
||||
|
||||
async def dispatch(self, request: _StarletteRequest, call_next):
|
||||
response = await call_next(request)
|
||||
# Strip the internal nonce hand-off header so it never reaches the client.
|
||||
nonce = response.headers.get(_CSP_SCRIPT_NONCE_HEADER)
|
||||
if nonce is not None:
|
||||
del response.headers[_CSP_SCRIPT_NONCE_HEADER]
|
||||
response.headers.setdefault("Content-Security-Policy", _build_csp(nonce))
|
||||
response.headers.setdefault("X-Frame-Options", "DENY")
|
||||
response.headers.setdefault("X-Content-Type-Options", "nosniff")
|
||||
response.headers.setdefault("Referrer-Policy", "no-referrer")
|
||||
response.headers.setdefault(
|
||||
"Permissions-Policy",
|
||||
"camera=(), microphone=(), geolocation=(), interest-cohort=()",
|
||||
)
|
||||
response.headers["server"] = "unsloth-studio"
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
# Cap upload body on protected POSTs; default 500 MB, env-tunable.
|
||||
import json as _json_for_413 # noqa: E402
|
||||
|
||||
|
||||
_MAX_BODY_BYTES = int(os.environ.get("UNSLOTH_STUDIO_MAX_BODY_MB", "500")) * 1024 * 1024
|
||||
_BODY_PROTECTED_PREFIXES = (
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/api/inference",
|
||||
"/api/data-recipe",
|
||||
"/api/datasets",
|
||||
"/api/train",
|
||||
"/api/export",
|
||||
)
|
||||
|
||||
|
||||
async def _send_413(send, total_bytes: int) -> None:
|
||||
payload = _json_for_413.dumps(
|
||||
{
|
||||
"detail": (
|
||||
f"Request body too large "
|
||||
f"({total_bytes:,} bytes; max {_MAX_BODY_BYTES:,})."
|
||||
)
|
||||
},
|
||||
).encode("utf-8")
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 413,
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"content-length", str(len(payload)).encode("ascii")),
|
||||
],
|
||||
}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": payload, "more_body": False})
|
||||
|
||||
|
||||
class MaxBodyMiddleware:
|
||||
"""Reject oversized bodies on protected POST/PUT/PATCH; raw ASGI so chunked uploads cannot bypass the cap."""
|
||||
|
||||
def __init__(self, app, max_bytes: int, protected_prefixes: tuple):
|
||||
self.app = app
|
||||
self.max_bytes = max_bytes
|
||||
self.protected_prefixes = protected_prefixes
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
method = scope.get("method", "").upper()
|
||||
path = scope.get("path", "")
|
||||
if method not in ("POST", "PUT", "PATCH") or not any(
|
||||
path.startswith(p) for p in self.protected_prefixes
|
||||
):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
declared = None
|
||||
for name, value in scope.get("headers", []):
|
||||
if name == b"content-length":
|
||||
try:
|
||||
declared = int(value.decode("latin-1"))
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
declared = None
|
||||
break
|
||||
if declared is not None and declared > self.max_bytes:
|
||||
await _send_413(send, declared)
|
||||
return
|
||||
|
||||
chunks: list = []
|
||||
total = 0
|
||||
while True:
|
||||
msg = await receive()
|
||||
mtype = msg.get("type")
|
||||
if mtype == "http.disconnect":
|
||||
return
|
||||
if mtype != "http.request":
|
||||
# Mid-stream unexpected frame: forwarding would corrupt downstream.
|
||||
return
|
||||
body = msg.get("body", b"") or b""
|
||||
if body:
|
||||
total += len(body)
|
||||
if total > self.max_bytes:
|
||||
await _send_413(send, total)
|
||||
return
|
||||
chunks.append(body)
|
||||
if not msg.get("more_body", False):
|
||||
break
|
||||
|
||||
replayed = {"sent": False}
|
||||
|
||||
async def replay_receive():
|
||||
if not replayed["sent"]:
|
||||
replayed["sent"] = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"".join(chunks),
|
||||
"more_body": False,
|
||||
}
|
||||
# After replay, fall through so http.disconnect still propagates.
|
||||
return await receive()
|
||||
|
||||
await self.app(scope, replay_receive, send)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
MaxBodyMiddleware,
|
||||
max_bytes = _MAX_BODY_BYTES,
|
||||
protected_prefixes = _BODY_PROTECTED_PREFIXES,
|
||||
)
|
||||
|
||||
|
||||
from starlette.responses import RedirectResponse as _RedirectResponse # noqa: E402
|
||||
|
||||
|
||||
@app.get("/recipes", include_in_schema = False)
|
||||
@app.get("/recipes/{rest:path}", include_in_schema = False)
|
||||
async def _recipes_redirect(rest: str = ""):
|
||||
target = "/data-recipes" + (("/" + rest) if rest else "")
|
||||
return _RedirectResponse(url = target, status_code = 308)
|
||||
|
||||
|
||||
# CORS middleware
|
||||
_api_only = os.environ.get("UNSLOTH_API_ONLY") == "1"
|
||||
_cors_origins = ["*"]
|
||||
|
|
@ -311,14 +486,35 @@ app.include_router(
|
|||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
platform_map = {"darwin": "mac", "win32": "windows", "linux": "linux"}
|
||||
device_type = platform_map.get(sys.platform, sys.platform)
|
||||
|
||||
return {
|
||||
async def health_check(request: Request):
|
||||
"""Liveness only; full diagnostic dict gated on a valid bearer."""
|
||||
minimal = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
auth = request.headers.get("authorization", "")
|
||||
if not auth.lower().startswith("bearer "):
|
||||
return minimal
|
||||
try:
|
||||
from auth.authentication import get_current_subject as _gcs
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
creds = HTTPAuthorizationCredentials(
|
||||
scheme = "Bearer", credentials = auth.split(" ", 1)[1]
|
||||
)
|
||||
# Must await: a bare coroutine is truthy and would skip the auth check.
|
||||
subject = await _gcs(creds)
|
||||
except HTTPException:
|
||||
return minimal
|
||||
except Exception:
|
||||
return minimal
|
||||
if not subject:
|
||||
return minimal
|
||||
|
||||
platform_map = {"darwin": "mac", "win32": "windows", "linux": "linux"}
|
||||
device_type = platform_map.get(sys.platform, sys.platform)
|
||||
return {
|
||||
**minimal,
|
||||
"service": "Unsloth UI Backend",
|
||||
"version": UNSLOTH_VERSION,
|
||||
"studio_version": STUDIO_VERSION,
|
||||
|
|
@ -328,9 +524,7 @@ async def health_check():
|
|||
"desktop_manageability_version": 1,
|
||||
"supports_desktop_auth": True,
|
||||
"supports_desktop_backend_ownership": True,
|
||||
# why: launchers compare against an install-time hash so a sibling
|
||||
# Studio on the same port is rejected; hex digest avoids leaking the
|
||||
# raw install path on -H 0.0.0.0.
|
||||
# Hex digest of the install path; launchers reject sibling Studios on the same port.
|
||||
"studio_root_id": _studio_root_id(),
|
||||
"native_path_leases_supported": native_path_leases_supported(),
|
||||
**({"desktop_owner": owner} if (owner := _desktop_owner()) else {}),
|
||||
|
|
@ -463,21 +657,22 @@ def _strip_crossorigin(html_bytes: bytes) -> bytes:
|
|||
return html.encode("utf-8")
|
||||
|
||||
|
||||
def _inject_bootstrap(html_bytes: bytes, app: FastAPI) -> bytes:
|
||||
"""Inject bootstrap credentials into HTML when password change is required.
|
||||
def _inject_bootstrap(html_bytes: bytes, app: FastAPI):
|
||||
"""Inject bootstrap credentials when password change is pending.
|
||||
|
||||
The script tag is only injected while the default admin account still
|
||||
has ``must_change_password=True``. Once the user changes the password
|
||||
the HTML is served clean — no credentials leak.
|
||||
Returns ``(html_bytes, script_nonce_or_None)``. Callers must forward
|
||||
the nonce via ``_CSP_SCRIPT_NONCE_HEADER`` so the inline script is
|
||||
not blocked by CSP.
|
||||
"""
|
||||
import json as _json
|
||||
import secrets as _secrets
|
||||
|
||||
if not storage.requires_password_change(storage.DEFAULT_ADMIN_USERNAME):
|
||||
return html_bytes
|
||||
return html_bytes, None
|
||||
|
||||
bootstrap_pw = getattr(app.state, "bootstrap_password", None)
|
||||
if not bootstrap_pw:
|
||||
return html_bytes
|
||||
return html_bytes, None
|
||||
|
||||
payload = _json.dumps(
|
||||
{
|
||||
|
|
@ -485,10 +680,11 @@ def _inject_bootstrap(html_bytes: bytes, app: FastAPI) -> bytes:
|
|||
"password": bootstrap_pw,
|
||||
}
|
||||
)
|
||||
tag = f"<script>window.__UNSLOTH_BOOTSTRAP__={payload}</script>"
|
||||
nonce = _secrets.token_urlsafe(16)
|
||||
tag = f'<script nonce="{nonce}">window.__UNSLOTH_BOOTSTRAP__={payload}</script>'
|
||||
html = html_bytes.decode("utf-8")
|
||||
html = html.replace("</head>", f"{tag}</head>", 1)
|
||||
return html.encode("utf-8")
|
||||
return html.encode("utf-8"), nonce
|
||||
|
||||
|
||||
def setup_frontend(app: FastAPI, build_path: Path):
|
||||
|
|
@ -501,17 +697,23 @@ def setup_frontend(app: FastAPI, build_path: Path):
|
|||
if assets_dir.exists():
|
||||
app.mount("/assets", StaticFiles(directory = assets_dir), name = "assets")
|
||||
|
||||
@app.get("/")
|
||||
async def serve_root():
|
||||
def _build_index_response() -> Response:
|
||||
content = (build_path / "index.html").read_bytes()
|
||||
content = _strip_crossorigin(content)
|
||||
content = _inject_bootstrap(content, app)
|
||||
content, nonce = _inject_bootstrap(content, app)
|
||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate"}
|
||||
if nonce:
|
||||
headers[_CSP_SCRIPT_NONCE_HEADER] = nonce
|
||||
return Response(
|
||||
content = content,
|
||||
media_type = "text/html",
|
||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate"},
|
||||
headers = headers,
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
async def serve_root():
|
||||
return _build_index_response()
|
||||
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_frontend(full_path: str):
|
||||
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
|
||||
|
|
@ -527,13 +729,6 @@ def setup_frontend(app: FastAPI, build_path: Path):
|
|||
return FileResponse(file_path)
|
||||
|
||||
# Serve index.html as bytes — avoids Content-Length mismatch
|
||||
content = (build_path / "index.html").read_bytes()
|
||||
content = _strip_crossorigin(content)
|
||||
content = _inject_bootstrap(content, app)
|
||||
return Response(
|
||||
content = content,
|
||||
media_type = "text/html",
|
||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate"},
|
||||
)
|
||||
return _build_index_response()
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -37,7 +37,10 @@ class AuthStatusResponse(BaseModel):
|
|||
initialized: bool = Field(
|
||||
..., description = "True if the auth database contains a login user"
|
||||
)
|
||||
default_username: str = Field(..., description = "Default seeded admin username")
|
||||
default_username: str = Field(
|
||||
"unsloth",
|
||||
description = "Default admin username for first-boot UI prefill.",
|
||||
)
|
||||
requires_password_change: bool = Field(
|
||||
...,
|
||||
description = "True if the seeded admin must still change the default password",
|
||||
|
|
|
|||
|
|
@ -5,10 +5,36 @@
|
|||
Pydantic schemas for Export API.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import List, Optional, Literal, Dict, Any
|
||||
|
||||
|
||||
def _validate_save_directory(value: str) -> str:
|
||||
"""Reject save_directory values that escape the export root."""
|
||||
if value is None:
|
||||
raise ValueError("save_directory is required")
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
raise ValueError("save_directory must not be empty")
|
||||
if "\x00" in raw:
|
||||
raise ValueError("save_directory may not contain null bytes")
|
||||
if any(ch in raw for ch in ("\r", "\n")):
|
||||
raise ValueError("save_directory may not contain control characters")
|
||||
if len(raw) > 255:
|
||||
raise ValueError("save_directory must be <= 255 characters")
|
||||
path = Path(raw).expanduser()
|
||||
if path.is_absolute():
|
||||
raise ValueError(
|
||||
"save_directory must be a name or relative path under the "
|
||||
"export root; absolute paths are rejected"
|
||||
)
|
||||
if ".." in path.parts:
|
||||
raise ValueError("save_directory may not contain '..' segments")
|
||||
return raw
|
||||
|
||||
|
||||
class LoadCheckpointRequest(BaseModel):
|
||||
"""Request for loading a checkpoint into the export backend."""
|
||||
|
||||
|
|
@ -64,6 +90,12 @@ class ExportCommonOptions(BaseModel):
|
|||
...,
|
||||
description = "Local directory where the exported artifacts will be written",
|
||||
)
|
||||
|
||||
@field_validator("save_directory", mode = "before")
|
||||
@classmethod
|
||||
def _check_save_directory(cls, v):
|
||||
return _validate_save_directory(v)
|
||||
|
||||
push_to_hub: bool = Field(
|
||||
False,
|
||||
description = "If True, also push the exported model to the Hugging Face Hub",
|
||||
|
|
@ -108,6 +140,12 @@ class ExportGGUFRequest(BaseModel):
|
|||
...,
|
||||
description = "Directory where GGUF files will be saved",
|
||||
)
|
||||
|
||||
@field_validator("save_directory", mode = "before")
|
||||
@classmethod
|
||||
def _check_save_directory(cls, v):
|
||||
return _validate_save_directory(v)
|
||||
|
||||
quantization_method: str = Field(
|
||||
"Q4_K_M",
|
||||
description = 'GGUF quantization method (e.g. "Q4_K_M")',
|
||||
|
|
|
|||
|
|
@ -425,14 +425,6 @@ class ChatMessage(BaseModel):
|
|||
|
||||
@model_validator(mode = "after")
|
||||
def _validate_role_shape(self) -> "ChatMessage":
|
||||
# Enforce the per-role OpenAI spec shape at the request boundary.
|
||||
# Without this, malformed messages (e.g. user entries with no
|
||||
# content, tool_calls on a user/system role, role="tool" without
|
||||
# tool_call_id) would be silently forwarded to llama-server via
|
||||
# the passthrough path, surfacing as opaque upstream errors or
|
||||
# broken tool-call reconciliation downstream.
|
||||
|
||||
# Tool-call metadata must appear only on the appropriate role.
|
||||
if self.tool_calls is not None and self.role != "assistant":
|
||||
raise ValueError('"tool_calls" is only valid on role="assistant" messages.')
|
||||
if self.tool_call_id is not None and self.role != "tool":
|
||||
|
|
@ -440,23 +432,20 @@ class ChatMessage(BaseModel):
|
|||
if self.name is not None and self.role != "tool":
|
||||
raise ValueError('"name" is only valid on role="tool" messages.')
|
||||
|
||||
# Per-role content requirements. OpenAI-compatible clients may send
|
||||
# ``content=""`` for image-only turns when the image travels in a
|
||||
# companion field such as Studio's ``image_base64`` extension, so treat
|
||||
# empty strings as present content for user/system messages.
|
||||
if self.role == "tool":
|
||||
if not self.tool_call_id:
|
||||
raise ValueError(
|
||||
'role="tool" messages require "tool_call_id" per the OpenAI spec.'
|
||||
)
|
||||
# Frontend's second-round POST drops the streamed id;
|
||||
# synthesise one so the request round-trips.
|
||||
import secrets as _secrets
|
||||
|
||||
self.tool_call_id = f"call_{_secrets.token_hex(8)}"
|
||||
if not self.content:
|
||||
raise ValueError('role="tool" messages require non-empty "content".')
|
||||
elif self.role == "assistant":
|
||||
# Assistant messages may omit content when tool_calls is set.
|
||||
if not self.content and not self.tool_calls:
|
||||
raise ValueError(
|
||||
'role="assistant" messages require either "content" or "tool_calls".'
|
||||
)
|
||||
# Tolerate the post-Stop empty-assistant sentinel by
|
||||
# collapsing content="" to None.
|
||||
if (self.content == "" or self.content == []) and not self.tool_calls:
|
||||
self.content = None
|
||||
else: # "user" | "system"
|
||||
if self.content is None or self.content == []:
|
||||
raise ValueError(f'role="{self.role}" messages require "content".')
|
||||
|
|
|
|||
|
|
@ -5,10 +5,43 @@
|
|||
Pydantic schemas for Training API
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from typing import Any, Optional, List, Dict, Literal
|
||||
|
||||
|
||||
_MAX_BATCH_SIZE = 4096
|
||||
_MAX_GRAD_ACCUM = 4096
|
||||
_MAX_STEPS = 1_000_000
|
||||
_MAX_EPOCHS = 1000
|
||||
# 2M is a sanity cap; host RAM runs out long before this.
|
||||
_MAX_SEQ_LENGTH = 2_000_000
|
||||
_MAX_LR_VALUE = 1.0
|
||||
_MAX_LORA_R = 16_384
|
||||
_MAX_LORA_ALPHA = 32_768
|
||||
|
||||
|
||||
def _parse_lr(v: Any) -> float:
|
||||
"""Parse learning_rate as a positive float strictly below _MAX_LR_VALUE."""
|
||||
if v is None:
|
||||
raise ValueError("learning_rate is required")
|
||||
if isinstance(v, bool):
|
||||
raise ValueError("learning_rate must be a number, not a bool")
|
||||
try:
|
||||
lr = float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(f"learning_rate must be parseable as float (got {v!r})")
|
||||
if not (lr > 0.0):
|
||||
raise ValueError(
|
||||
f"learning_rate must be > 0 (got {lr!r}); " "typical range is 1e-6 .. 1e-3"
|
||||
)
|
||||
if lr >= _MAX_LR_VALUE:
|
||||
raise ValueError(
|
||||
f"learning_rate must be < 1.0 (got {lr!r}); "
|
||||
"values that large always diverge training"
|
||||
)
|
||||
return lr
|
||||
|
||||
|
||||
class TrainingStartRequest(BaseModel):
|
||||
"""Request schema for starting training"""
|
||||
|
||||
|
|
@ -64,6 +97,147 @@ class TrainingStartRequest(BaseModel):
|
|||
values.setdefault("train_split", values.pop("split"))
|
||||
return values
|
||||
|
||||
@field_validator("learning_rate", mode = "before")
|
||||
@classmethod
|
||||
def _check_learning_rate(cls, v):
|
||||
# Stringify because downstream call sites float() it themselves.
|
||||
lr = _parse_lr(v)
|
||||
return str(lr)
|
||||
|
||||
@field_validator("batch_size")
|
||||
@classmethod
|
||||
def _check_batch_size(cls, v: int) -> int:
|
||||
if v is None:
|
||||
raise ValueError("batch_size is required")
|
||||
if v < 1 or v > _MAX_BATCH_SIZE:
|
||||
raise ValueError(
|
||||
f"batch_size must be in [1, {_MAX_BATCH_SIZE}] (got {v!r})"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("gradient_accumulation_steps")
|
||||
@classmethod
|
||||
def _check_grad_accum(cls, v: int) -> int:
|
||||
if v is None:
|
||||
return 1
|
||||
if v < 1 or v > _MAX_GRAD_ACCUM:
|
||||
raise ValueError(
|
||||
f"gradient_accumulation_steps must be in [1, {_MAX_GRAD_ACCUM}] "
|
||||
f"(got {v!r})"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("num_epochs")
|
||||
@classmethod
|
||||
def _check_num_epochs(cls, v: int) -> int:
|
||||
if v is None:
|
||||
return 1
|
||||
if v < 1 or v > _MAX_EPOCHS:
|
||||
raise ValueError(f"num_epochs must be in [1, {_MAX_EPOCHS}] (got {v!r})")
|
||||
return v
|
||||
|
||||
@field_validator("max_steps")
|
||||
@classmethod
|
||||
def _check_max_steps(cls, v):
|
||||
if v is None:
|
||||
return v
|
||||
if not isinstance(v, int) or v < 1 or v > _MAX_STEPS:
|
||||
raise ValueError(
|
||||
f"max_steps must be a positive int <= {_MAX_STEPS} (got {v!r})"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("max_seq_length")
|
||||
@classmethod
|
||||
def _check_max_seq_length(cls, v: int) -> int:
|
||||
if v is None or v < 1 or v > _MAX_SEQ_LENGTH:
|
||||
raise ValueError(
|
||||
f"max_seq_length must be in [1, {_MAX_SEQ_LENGTH}] (got {v!r})"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("warmup_steps")
|
||||
@classmethod
|
||||
def _check_warmup_steps(cls, v):
|
||||
if v is None:
|
||||
return v
|
||||
if not isinstance(v, int) or v < 0 or v > _MAX_STEPS:
|
||||
raise ValueError(
|
||||
f"warmup_steps must be a non-negative int <= {_MAX_STEPS} "
|
||||
f"(got {v!r})"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("warmup_ratio")
|
||||
@classmethod
|
||||
def _check_warmup_ratio(cls, v):
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
r = float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(f"warmup_ratio must be a number (got {v!r})")
|
||||
if not (0.0 <= r <= 1.0):
|
||||
raise ValueError(f"warmup_ratio must be in [0.0, 1.0] (got {r!r})")
|
||||
return r
|
||||
|
||||
@field_validator("save_steps")
|
||||
@classmethod
|
||||
def _check_save_steps(cls, v: int) -> int:
|
||||
if v is None:
|
||||
return 100
|
||||
if v < 0 or v > _MAX_STEPS:
|
||||
raise ValueError(f"save_steps must be in [0, {_MAX_STEPS}] (got {v!r})")
|
||||
return v
|
||||
|
||||
@field_validator("weight_decay")
|
||||
@classmethod
|
||||
def _check_weight_decay(cls, v: float) -> float:
|
||||
if v is None:
|
||||
return 0.0
|
||||
try:
|
||||
wd = float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(f"weight_decay must be a number (got {v!r})")
|
||||
if wd < 0 or wd > 10.0:
|
||||
raise ValueError(
|
||||
f"weight_decay must be in [0, 10] (got {wd!r}); typical 0..0.1"
|
||||
)
|
||||
return wd
|
||||
|
||||
@field_validator("lora_r")
|
||||
@classmethod
|
||||
def _check_lora_r(cls, v: int) -> int:
|
||||
if v is None:
|
||||
return 16
|
||||
if v < 1 or v > _MAX_LORA_R:
|
||||
raise ValueError(f"lora_r must be in [1, {_MAX_LORA_R}] (got {v!r})")
|
||||
return v
|
||||
|
||||
@field_validator("lora_alpha")
|
||||
@classmethod
|
||||
def _check_lora_alpha(cls, v: int) -> int:
|
||||
if v is None:
|
||||
return 16
|
||||
if v < 1 or v > _MAX_LORA_ALPHA:
|
||||
raise ValueError(
|
||||
f"lora_alpha must be in [1, {_MAX_LORA_ALPHA}] (got {v!r})"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("lora_dropout")
|
||||
@classmethod
|
||||
def _check_lora_dropout(cls, v: float) -> float:
|
||||
if v is None:
|
||||
return 0.0
|
||||
try:
|
||||
d = float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError(f"lora_dropout must be a number (got {v!r})")
|
||||
if not (0.0 <= d < 1.0):
|
||||
raise ValueError(f"lora_dropout must be in [0.0, 1.0) (got {d!r})")
|
||||
return d
|
||||
|
||||
custom_format_mapping: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description = (
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@
|
|||
Authentication API routes
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from models.auth import (
|
||||
|
|
@ -33,14 +36,52 @@ from auth.authentication import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
# In-memory per-IP login rate limiter; multi-process deployment needs a shared store.
|
||||
_LOGIN_BUCKETS: dict[str, deque] = {}
|
||||
_LOGIN_BUCKETS_LOCK = threading.Lock()
|
||||
_LOGIN_WINDOW_SECONDS = 60.0
|
||||
_LOGIN_MAX_FAILS = 5
|
||||
_LOGIN_LOCKOUT_SECONDS = 60
|
||||
|
||||
|
||||
def _client_key(request: Request | None) -> str:
|
||||
if request is None or request.client is None:
|
||||
return "_unknown"
|
||||
return request.client.host or "_unknown"
|
||||
|
||||
|
||||
def _record_login_failure(ip: str) -> int:
|
||||
now = time.monotonic()
|
||||
with _LOGIN_BUCKETS_LOCK:
|
||||
bucket = _LOGIN_BUCKETS.setdefault(ip, deque())
|
||||
while bucket and now - bucket[0] > _LOGIN_WINDOW_SECONDS:
|
||||
bucket.popleft()
|
||||
bucket.append(now)
|
||||
return len(bucket)
|
||||
|
||||
|
||||
def _login_blocked(ip: str) -> int:
|
||||
"""Return seconds until the next attempt is allowed, or 0."""
|
||||
now = time.monotonic()
|
||||
with _LOGIN_BUCKETS_LOCK:
|
||||
bucket = _LOGIN_BUCKETS.get(ip)
|
||||
if not bucket:
|
||||
return 0
|
||||
while bucket and now - bucket[0] > _LOGIN_WINDOW_SECONDS:
|
||||
bucket.popleft()
|
||||
if len(bucket) >= _LOGIN_MAX_FAILS:
|
||||
return max(1, int(_LOGIN_WINDOW_SECONDS - (now - bucket[0])))
|
||||
return 0
|
||||
|
||||
|
||||
def _clear_login_bucket(ip: str) -> None:
|
||||
with _LOGIN_BUCKETS_LOCK:
|
||||
_LOGIN_BUCKETS.pop(ip, None)
|
||||
|
||||
|
||||
@router.get("/status", response_model = AuthStatusResponse)
|
||||
async def auth_status() -> AuthStatusResponse:
|
||||
"""
|
||||
Check whether auth has already been initialized.
|
||||
|
||||
- initialized = False -> frontend should wait for the seeded admin bootstrap.
|
||||
- initialized = True -> frontend should show login or force the first password change.
|
||||
"""
|
||||
"""Auth initialization state; ``default_username`` is exposed for first-boot UI prefill only."""
|
||||
return AuthStatusResponse(
|
||||
initialized = storage.is_initialized(),
|
||||
default_username = storage.DEFAULT_ADMIN_USERNAME,
|
||||
|
|
@ -53,12 +94,23 @@ async def auth_status() -> AuthStatusResponse:
|
|||
|
||||
|
||||
@router.post("/login", response_model = Token)
|
||||
async def login(payload: AuthLoginRequest) -> Token:
|
||||
"""
|
||||
Login with username/password and receive access + refresh tokens.
|
||||
"""
|
||||
async def login(payload: AuthLoginRequest, request: Request) -> Token:
|
||||
"""Login with username/password. Rate-limited per source IP."""
|
||||
ip = _client_key(request)
|
||||
blocked_for = _login_blocked(ip)
|
||||
if blocked_for > 0:
|
||||
raise HTTPException(
|
||||
status_code = status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail = (
|
||||
f"Too many failed login attempts from {ip}. "
|
||||
f"Try again in {blocked_for} seconds."
|
||||
),
|
||||
headers = {"Retry-After": str(blocked_for)},
|
||||
)
|
||||
|
||||
record = storage.get_user_and_secret(payload.username)
|
||||
if record is None:
|
||||
_record_login_failure(ip)
|
||||
raise HTTPException(
|
||||
status_code = status.HTTP_401_UNAUTHORIZED,
|
||||
detail = "Incorrect password. Run 'unsloth studio reset-password' in your terminal to reset it.",
|
||||
|
|
@ -66,11 +118,13 @@ async def login(payload: AuthLoginRequest) -> Token:
|
|||
|
||||
salt, pwd_hash, _jwt_secret, must_change_password = record
|
||||
if not hashing.verify_password(payload.password, salt, pwd_hash):
|
||||
_record_login_failure(ip)
|
||||
raise HTTPException(
|
||||
status_code = status.HTTP_401_UNAUTHORIZED,
|
||||
detail = "Incorrect password. Run 'unsloth studio reset-password' in your terminal to reset it.",
|
||||
)
|
||||
|
||||
_clear_login_bucket(ip)
|
||||
access_token = create_access_token(subject = payload.username)
|
||||
refresh_token = create_refresh_token(subject = payload.username)
|
||||
return Token(
|
||||
|
|
@ -81,6 +135,23 @@ async def login(payload: AuthLoginRequest) -> Token:
|
|||
)
|
||||
|
||||
|
||||
@router.post("/logout", status_code = status.HTTP_204_NO_CONTENT)
|
||||
async def logout(
|
||||
request: Request,
|
||||
current_subject: str = Depends(get_current_subject_allow_password_change),
|
||||
) -> Response:
|
||||
"""Revoke refresh tokens for the subject; the access token is stateless and expires on its own."""
|
||||
try:
|
||||
storage.revoke_user_refresh_tokens(current_subject)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
request.app.state.bootstrap_password = None
|
||||
except AttributeError:
|
||||
pass
|
||||
return Response(status_code = status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@router.post("/desktop-login", response_model = Token)
|
||||
async def desktop_login(payload: DesktopLoginRequest) -> Token:
|
||||
"""Exchange a local desktop secret for normal admin-subject tokens."""
|
||||
|
|
@ -101,21 +172,20 @@ async def desktop_login(payload: DesktopLoginRequest) -> Token:
|
|||
|
||||
@router.post("/refresh", response_model = Token)
|
||||
async def refresh(payload: RefreshTokenRequest) -> Token:
|
||||
"""
|
||||
Exchange a valid refresh token for a new access token.
|
||||
|
||||
The refresh token itself is reusable until it expires (7 days).
|
||||
"""
|
||||
new_access_token, username, is_desktop = refresh_access_token(payload.refresh_token)
|
||||
if new_access_token is None or username is None:
|
||||
"""Exchange a refresh token for a new access+refresh pair (single-use)."""
|
||||
consumed = storage.consume_refresh_token(payload.refresh_token)
|
||||
if consumed is None:
|
||||
raise HTTPException(
|
||||
status_code = status.HTTP_401_UNAUTHORIZED,
|
||||
detail = "Invalid or expired refresh token",
|
||||
)
|
||||
username, is_desktop = consumed
|
||||
new_access_token = create_access_token(subject = username, desktop = is_desktop)
|
||||
new_refresh_token = create_refresh_token(subject = username, desktop = is_desktop)
|
||||
|
||||
return Token(
|
||||
access_token = new_access_token,
|
||||
refresh_token = payload.refresh_token,
|
||||
refresh_token = new_refresh_token,
|
||||
token_type = "bearer",
|
||||
must_change_password = False
|
||||
if is_desktop
|
||||
|
|
@ -126,6 +196,7 @@ async def refresh(payload: RefreshTokenRequest) -> Token:
|
|||
@router.post("/change-password", response_model = Token)
|
||||
async def change_password(
|
||||
payload: ChangePasswordRequest,
|
||||
request: Request,
|
||||
current_subject: str = Depends(get_current_subject_allow_password_change),
|
||||
) -> Token:
|
||||
"""Allow the authenticated user to replace the default password."""
|
||||
|
|
@ -150,6 +221,10 @@ async def change_password(
|
|||
|
||||
storage.update_password(current_subject, payload.new_password)
|
||||
storage.revoke_user_refresh_tokens(current_subject)
|
||||
try:
|
||||
request.app.state.bootstrap_password = None
|
||||
except AttributeError:
|
||||
pass
|
||||
access_token = create_access_token(subject = current_subject)
|
||||
refresh_token = create_refresh_token(subject = current_subject)
|
||||
return Token(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ Export API routes: checkpoint discovery and model export operations.
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
|
@ -184,14 +185,18 @@ async def get_export_status(
|
|||
|
||||
|
||||
def _export_details(output_path: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""Wrap the resolved on-disk export path into the details dict the
|
||||
frontend reads to populate the Export Complete screen. Returns None
|
||||
when the export had no local component (Hub-only push) so the
|
||||
Pydantic field stays absent rather than ``{"output_path": null}``.
|
||||
"""
|
||||
"""Return the export path relative to exports_root so the install path is not leaked."""
|
||||
if not output_path:
|
||||
return None
|
||||
return {"output_path": output_path}
|
||||
try:
|
||||
from utils.paths.storage_roots import exports_root
|
||||
|
||||
rel = os.path.relpath(output_path, exports_root())
|
||||
if rel.startswith(".."):
|
||||
rel = os.path.basename(output_path)
|
||||
return {"output_path": rel}
|
||||
except Exception:
|
||||
return {"output_path": os.path.basename(output_path)}
|
||||
|
||||
|
||||
@router.post("/export/merged", response_model = ExportOperationResponse)
|
||||
|
|
|
|||
|
|
@ -1743,7 +1743,7 @@ async def openai_chat_completions(
|
|||
try:
|
||||
import base64 as _b64
|
||||
from io import BytesIO as _BytesIO
|
||||
from PIL import Image as _Image
|
||||
from PIL import Image as _Image, UnidentifiedImageError as _UIE
|
||||
|
||||
raw = _b64.b64decode(image_b64)
|
||||
# Normalize to RGB so PNG encoding succeeds regardless of
|
||||
|
|
@ -1754,9 +1754,15 @@ async def openai_chat_completions(
|
|||
buf = _BytesIO()
|
||||
img.save(buf, format = "PNG")
|
||||
image_b64 = _b64.b64encode(buf.getvalue()).decode("ascii")
|
||||
except Exception as e:
|
||||
except _UIE:
|
||||
raise HTTPException(
|
||||
status_code = 400, detail = f"Failed to process image: {e}"
|
||||
status_code = 400,
|
||||
detail = "Unsupported or corrupt image format.",
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code = 400,
|
||||
detail = "Failed to process image.",
|
||||
)
|
||||
|
||||
# Build message list with system prompt prepended
|
||||
|
|
@ -3426,10 +3432,10 @@ def _normalize_anthropic_openai_images(
|
|||
buf = io.BytesIO()
|
||||
img.save(buf, format = "PNG")
|
||||
png_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code = 400,
|
||||
detail = f"Failed to process image: {e}",
|
||||
detail = "Failed to process image.",
|
||||
)
|
||||
part["image_url"] = {"url": f"data:image/png;base64,{png_b64}"}
|
||||
|
||||
|
|
@ -3465,6 +3471,7 @@ async def anthropic_messages(
|
|||
[m.model_dump() for m in payload.messages],
|
||||
payload.system,
|
||||
)
|
||||
openai_messages = _drop_empty_assistant_sentinels(openai_messages)
|
||||
|
||||
# Enforce vision guard + re-encode embedded images to PNG so the
|
||||
# Anthropic endpoint matches the behavior of /v1/chat/completions.
|
||||
|
|
@ -4190,6 +4197,19 @@ async def _anthropic_passthrough_non_streaming(
|
|||
# =====================================================================
|
||||
|
||||
|
||||
def _drop_empty_assistant_sentinels(messages: list[dict]) -> list[dict]:
|
||||
"""Drop bare ``{"role":"assistant"}`` Stop-button sentinels; passthrough backends reject them."""
|
||||
out: list[dict] = []
|
||||
for m in messages:
|
||||
if m.get("role") == "assistant":
|
||||
has_content = bool(m.get("content"))
|
||||
has_tool_calls = bool(m.get("tool_calls"))
|
||||
if not has_content and not has_tool_calls:
|
||||
continue
|
||||
out.append(m)
|
||||
return out
|
||||
|
||||
|
||||
def _openai_messages_for_passthrough(payload) -> list[dict]:
|
||||
"""Build OpenAI-format message dicts for the /v1/chat/completions
|
||||
passthrough path.
|
||||
|
|
@ -4206,7 +4226,9 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
|
|||
``image_url`` content part so vision + function-calling requests work
|
||||
transparently.
|
||||
"""
|
||||
messages = [m.model_dump(exclude_none = True) for m in payload.messages]
|
||||
messages = _drop_empty_assistant_sentinels(
|
||||
[m.model_dump(exclude_none = True) for m in payload.messages]
|
||||
)
|
||||
|
||||
if not payload.image_base64:
|
||||
return messages
|
||||
|
|
@ -4221,10 +4243,10 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
|
|||
buf = _BytesIO()
|
||||
img.save(buf, format = "PNG")
|
||||
png_b64 = _b64.b64encode(buf.getvalue()).decode("ascii")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code = 400,
|
||||
detail = f"Failed to process image: {e}",
|
||||
detail = "Failed to process image.",
|
||||
)
|
||||
|
||||
data_url = f"data:image/png;base64,{png_b64}"
|
||||
|
|
|
|||
|
|
@ -354,9 +354,14 @@ def run_server(
|
|||
if getattr(self, "started", False) and not self.should_exit:
|
||||
ready_event.set()
|
||||
|
||||
# Create the uvicorn server and expose it for signal handlers
|
||||
# server_header=False suppresses uvicorn's "Server: uvicorn"; SecurityHeadersMiddleware sets its own.
|
||||
config = uvicorn.Config(
|
||||
app, host = host, port = port, log_level = "info", access_log = False
|
||||
app,
|
||||
host = host,
|
||||
port = port,
|
||||
log_level = "info",
|
||||
access_log = False,
|
||||
server_header = False,
|
||||
)
|
||||
_server = _ReadyServer(config)
|
||||
_shutdown_event = Event()
|
||||
|
|
|
|||
|
|
@ -227,6 +227,60 @@ def test_desktop_refresh_preserves_desktop_marker():
|
|||
assert payload["desktop"] is True
|
||||
|
||||
|
||||
def test_consume_refresh_token_second_call_returns_none():
|
||||
"""Single-use rotation rejects the same token on a second consume."""
|
||||
seed_user()
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
raw = secrets.token_urlsafe(48)
|
||||
expires = (datetime.now(timezone.utc) + timedelta(days = 30)).isoformat()
|
||||
storage.save_refresh_token(raw, storage.DEFAULT_ADMIN_USERNAME, expires)
|
||||
|
||||
first = storage.consume_refresh_token(raw)
|
||||
assert first == (storage.DEFAULT_ADMIN_USERNAME, False)
|
||||
second = storage.consume_refresh_token(raw)
|
||||
assert second is None
|
||||
|
||||
|
||||
def test_consume_refresh_token_concurrent_only_one_succeeds(tmp_path, monkeypatch):
|
||||
"""64-thread pile-up against one token; DELETE RETURNING permits one winner."""
|
||||
seed_user()
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
raw = secrets.token_urlsafe(48)
|
||||
expires = (datetime.now(timezone.utc) + timedelta(days = 30)).isoformat()
|
||||
storage.save_refresh_token(raw, storage.DEFAULT_ADMIN_USERNAME, expires)
|
||||
|
||||
workers = 64
|
||||
|
||||
def attempt(_idx: int):
|
||||
try:
|
||||
return storage.consume_refresh_token(raw)
|
||||
except sqlite3.OperationalError:
|
||||
# "database is locked" under heavy contention; treat as losing the race.
|
||||
return None
|
||||
|
||||
with ThreadPoolExecutor(max_workers = workers) as pool:
|
||||
results = list(pool.map(attempt, range(workers)))
|
||||
|
||||
successes = [r for r in results if r is not None]
|
||||
assert (
|
||||
len(successes) == 1
|
||||
), f"expected exactly one consumer to win, got {len(successes)}"
|
||||
assert successes[0] == (storage.DEFAULT_ADMIN_USERNAME, False)
|
||||
|
||||
|
||||
def test_consume_refresh_token_expired_returns_none():
|
||||
seed_user()
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
raw = secrets.token_urlsafe(48)
|
||||
expires = (datetime.now(timezone.utc) - timedelta(hours = 1)).isoformat()
|
||||
storage.save_refresh_token(raw, storage.DEFAULT_ADMIN_USERNAME, expires)
|
||||
assert storage.consume_refresh_token(raw) is None
|
||||
|
||||
|
||||
def test_desktop_session_uses_real_admin_identity_for_api_keys():
|
||||
seed_user(must_change_password = True)
|
||||
raw = storage.create_desktop_secret()
|
||||
|
|
@ -392,7 +446,21 @@ def test_health_response_reports_desktop_capability_fields(monkeypatch):
|
|||
|
||||
monkeypatch.setattr(backend_main._hw_module, "CHAT_ONLY", False)
|
||||
|
||||
body = asyncio.run(backend_main.health_check())
|
||||
seed_user()
|
||||
from auth.authentication import create_access_token
|
||||
|
||||
token = create_access_token(storage.DEFAULT_ADMIN_USERNAME)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_api_route("/api/health", backend_main.health_check, methods = ["GET"])
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get(
|
||||
"/api/health",
|
||||
headers = {"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
|
||||
assert body["desktop_protocol_version"] == 1
|
||||
assert body["supports_desktop_auth"] is True
|
||||
|
|
|
|||
269
studio/backend/tests/test_middleware.py
Normal file
269
studio/backend/tests/test_middleware.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
|
||||
|
||||
"""Tests for MaxBodyMiddleware, SecurityHeadersMiddleware, and the /api/health auth gate."""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(_BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_BACKEND_ROOT))
|
||||
|
||||
|
||||
@pytest.fixture(scope = "module")
|
||||
def main_module():
|
||||
import main as _main # noqa: F401
|
||||
|
||||
return _main
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# MaxBodyMiddleware
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def _make_protected_app(max_bytes: int, main_module):
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
main_module.MaxBodyMiddleware,
|
||||
max_bytes = max_bytes,
|
||||
protected_prefixes = ("/v1/chat/completions", "/api/train"),
|
||||
)
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat(payload: dict):
|
||||
return {"ok": True, "n": len(payload.get("text", ""))}
|
||||
|
||||
@app.post("/api/other")
|
||||
async def other(payload: dict):
|
||||
return {"ok": True, "unprotected": True}
|
||||
|
||||
@app.get("/api/train/status")
|
||||
async def status_get():
|
||||
return {"ok": True, "get": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class TestMaxBodyMiddleware:
|
||||
def test_small_protected_body_passes(self, main_module):
|
||||
app = _make_protected_app(1024, main_module)
|
||||
c = TestClient(app)
|
||||
r = c.post("/v1/chat/completions", json = {"text": "x" * 100})
|
||||
assert r.status_code == 200
|
||||
assert r.json()["n"] == 100
|
||||
|
||||
def test_large_declared_content_length_rejected(self, main_module):
|
||||
app = _make_protected_app(1024, main_module)
|
||||
c = TestClient(app)
|
||||
r = c.post("/v1/chat/completions", json = {"text": "x" * 5000})
|
||||
assert r.status_code == 413
|
||||
assert "too large" in r.json()["detail"].lower()
|
||||
|
||||
def test_unprotected_prefix_passes_large_body(self, main_module):
|
||||
app = _make_protected_app(1024, main_module)
|
||||
c = TestClient(app)
|
||||
r = c.post("/api/other", json = {"text": "x" * 5000})
|
||||
assert r.status_code == 200
|
||||
assert r.json()["unprotected"] is True
|
||||
|
||||
def test_chunked_upload_over_cap_rejected(self, main_module):
|
||||
# Regression: declared-Content-Length-only check could be bypassed
|
||||
# by chunked transfer-encoding.
|
||||
app = _make_protected_app(1024, main_module)
|
||||
c = TestClient(app)
|
||||
|
||||
def gen():
|
||||
yield b'{"text":"'
|
||||
yield b"x" * 800
|
||||
yield b'"}'
|
||||
yield b"\n" + b"y" * 500
|
||||
|
||||
r = c.post(
|
||||
"/v1/chat/completions",
|
||||
content = gen(),
|
||||
headers = {"content-type": "application/json"},
|
||||
)
|
||||
assert r.status_code == 413
|
||||
assert "too large" in r.json()["detail"].lower()
|
||||
|
||||
def test_chunked_upload_under_cap_passes(self, main_module):
|
||||
app = _make_protected_app(1024, main_module)
|
||||
c = TestClient(app)
|
||||
|
||||
def gen():
|
||||
yield b'{"text":"'
|
||||
yield b"x" * 50
|
||||
yield b'"}'
|
||||
|
||||
r = c.post(
|
||||
"/v1/chat/completions",
|
||||
content = gen(),
|
||||
headers = {"content-type": "application/json"},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["n"] == 50
|
||||
|
||||
def test_get_not_subject_to_cap(self, main_module):
|
||||
app = _make_protected_app(1024, main_module)
|
||||
c = TestClient(app)
|
||||
r = c.get("/api/train/status")
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# SecurityHeadersMiddleware / CSP
|
||||
# =====================================================================
|
||||
|
||||
|
||||
def _make_csp_app(main_module, attach_nonce: str | None = None):
|
||||
app = FastAPI()
|
||||
app.add_middleware(main_module.SecurityHeadersMiddleware)
|
||||
|
||||
@app.get("/plain")
|
||||
async def plain():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/with-nonce")
|
||||
async def with_nonce():
|
||||
headers = {}
|
||||
if attach_nonce:
|
||||
headers[main_module._CSP_SCRIPT_NONCE_HEADER] = attach_nonce
|
||||
return Response(
|
||||
content = b"<html></html>",
|
||||
media_type = "text/html",
|
||||
headers = headers,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class TestSecurityHeadersMiddleware:
|
||||
def test_csp_has_no_unsafe_inline_for_script_src(self, main_module):
|
||||
app = _make_csp_app(main_module)
|
||||
c = TestClient(app)
|
||||
r = c.get("/plain")
|
||||
assert r.status_code == 200
|
||||
csp = r.headers["content-security-policy"]
|
||||
# Parse per-directive so style-src unsafe-inline does not false-match.
|
||||
directives = {
|
||||
chunk.strip().split(" ", 1)[0]: chunk.strip()
|
||||
for chunk in csp.split(";")
|
||||
if chunk.strip()
|
||||
}
|
||||
assert "script-src" in directives
|
||||
assert "'unsafe-inline'" not in directives["script-src"]
|
||||
# style-src keeps unsafe-inline for Vite-injected styles.
|
||||
assert "'unsafe-inline'" in directives["style-src"]
|
||||
|
||||
def test_default_security_headers_present(self, main_module):
|
||||
app = _make_csp_app(main_module)
|
||||
c = TestClient(app)
|
||||
r = c.get("/plain")
|
||||
assert r.headers["x-frame-options"] == "DENY"
|
||||
assert r.headers["x-content-type-options"] == "nosniff"
|
||||
assert r.headers["referrer-policy"] == "no-referrer"
|
||||
assert "camera=()" in r.headers["permissions-policy"]
|
||||
assert r.headers["server"] == "unsloth-studio"
|
||||
|
||||
def test_internal_nonce_header_is_spliced_into_csp_and_stripped(self, main_module):
|
||||
nonce = "test-nonce-abc"
|
||||
app = _make_csp_app(main_module, attach_nonce = nonce)
|
||||
c = TestClient(app)
|
||||
r = c.get("/with-nonce")
|
||||
csp = r.headers["content-security-policy"]
|
||||
assert f"'nonce-{nonce}'" in csp
|
||||
# Internal handoff header must not leak to clients.
|
||||
assert main_module._CSP_SCRIPT_NONCE_HEADER not in {
|
||||
k.lower() for k in r.headers.keys()
|
||||
}
|
||||
|
||||
def test_build_csp_helper_shape(self, main_module):
|
||||
plain = main_module._build_csp()
|
||||
assert "script-src 'self';" in plain
|
||||
assert "'unsafe-inline'" not in plain.split("script-src", 1)[1].split(";", 1)[0]
|
||||
nonced = main_module._build_csp("XYZ")
|
||||
assert "script-src 'self' 'nonce-XYZ';" in nonced
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# /api/health auth gate
|
||||
# =====================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def health_app(tmp_path, monkeypatch):
|
||||
"""Mount /api/health on a fresh app against an isolated auth db."""
|
||||
from auth import storage
|
||||
|
||||
monkeypatch.setattr(storage, "DB_PATH", tmp_path / "auth.db")
|
||||
monkeypatch.setattr(storage, "_BOOTSTRAP_PW_PATH", tmp_path / ".bootstrap_password")
|
||||
monkeypatch.setattr(storage, "_bootstrap_password", None)
|
||||
|
||||
import main as _main
|
||||
|
||||
app = FastAPI()
|
||||
app.add_api_route("/api/health", _main.health_check, methods = ["GET"])
|
||||
|
||||
import secrets as _secrets
|
||||
|
||||
storage.create_initial_user(
|
||||
username = storage.DEFAULT_ADMIN_USERNAME,
|
||||
password = "human-password-123",
|
||||
jwt_secret = _secrets.token_urlsafe(64),
|
||||
must_change_password = False,
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
class TestHealthAuthGate:
|
||||
def test_no_auth_returns_minimal_payload(self, health_app):
|
||||
c = TestClient(health_app)
|
||||
r = c.get("/api/health")
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["status"] == "healthy"
|
||||
assert "timestamp" in body
|
||||
for forbidden in ("version", "device_type", "studio_root_id"):
|
||||
assert forbidden not in body
|
||||
|
||||
def test_invalid_bearer_returns_minimal_payload(self, health_app):
|
||||
# Regression: calling the async dep without await made any Bearer header pass.
|
||||
c = TestClient(health_app)
|
||||
r = c.get(
|
||||
"/api/health",
|
||||
headers = {"Authorization": "Bearer not-a-real-token"},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["status"] == "healthy"
|
||||
for forbidden in ("version", "device_type", "studio_root_id"):
|
||||
assert forbidden not in body
|
||||
|
||||
def test_valid_bearer_returns_full_payload(self, health_app):
|
||||
from auth import storage
|
||||
from auth.authentication import create_access_token
|
||||
|
||||
token = create_access_token(storage.DEFAULT_ADMIN_USERNAME)
|
||||
c = TestClient(health_app)
|
||||
r = c.get(
|
||||
"/api/health",
|
||||
headers = {"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["status"] == "healthy"
|
||||
assert "version" in body
|
||||
assert "device_type" in body
|
||||
assert "studio_root_id" in body
|
||||
|
|
@ -125,22 +125,21 @@ class TestChatMessageToolRoles:
|
|||
)
|
||||
assert msg.content is None
|
||||
|
||||
def test_tool_role_missing_tool_call_id_rejected(self):
|
||||
# Per OpenAI spec, role="tool" messages must carry tool_call_id so
|
||||
# upstream backends can associate the result with its prior call.
|
||||
# Pin the boundary-level rejection so a malformed tool-result
|
||||
# message never reaches the passthrough path.
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ChatMessage(role = "tool", content = '{"temperature": 72}')
|
||||
assert "tool_call_id" in str(exc_info.value)
|
||||
def test_tool_role_missing_tool_call_id_synthesised(self):
|
||||
# Frontend drops the id on second-round POST; validator synthesises one.
|
||||
msg = ChatMessage(role = "tool", content = '{"temperature": 72}')
|
||||
assert msg.tool_call_id is not None
|
||||
assert msg.tool_call_id.startswith("call_")
|
||||
assert len(msg.tool_call_id) >= len("call_") + 8
|
||||
|
||||
def test_tool_role_empty_tool_call_id_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ChatMessage(
|
||||
role = "tool",
|
||||
tool_call_id = "",
|
||||
content = '{"temperature": 72}',
|
||||
)
|
||||
def test_tool_role_empty_tool_call_id_synthesised(self):
|
||||
msg = ChatMessage(
|
||||
role = "tool",
|
||||
tool_call_id = "",
|
||||
content = '{"temperature": 72}',
|
||||
)
|
||||
assert msg.tool_call_id is not None
|
||||
assert msg.tool_call_id.startswith("call_")
|
||||
|
||||
# ── Role-aware content requirements ────────────────────────────
|
||||
|
||||
|
|
@ -162,10 +161,19 @@ class TestChatMessageToolRoles:
|
|||
ChatMessage(role = "tool", tool_call_id = "call_1", content = "")
|
||||
assert "content" in str(exc_info.value)
|
||||
|
||||
def test_assistant_without_content_or_tool_calls_rejected(self):
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ChatMessage(role = "assistant")
|
||||
assert "content" in str(exc_info.value) or "tool_calls" in str(exc_info.value)
|
||||
def test_assistant_without_content_or_tool_calls_tolerated(self):
|
||||
# Stop-button leaves an empty assistant turn; tolerate so replay round-trips.
|
||||
msg = ChatMessage(role = "assistant")
|
||||
assert msg.content is None
|
||||
assert msg.tool_calls is None
|
||||
|
||||
def test_assistant_empty_string_content_normalised_to_none(self):
|
||||
msg = ChatMessage(role = "assistant", content = "")
|
||||
assert msg.content is None
|
||||
|
||||
def test_assistant_empty_list_content_normalised_to_none(self):
|
||||
msg = ChatMessage(role = "assistant", content = [])
|
||||
assert msg.content is None
|
||||
|
||||
# ── Role-constrained tool-call metadata ────────────────────────
|
||||
|
||||
|
|
@ -472,3 +480,91 @@ class TestFriendlyErrorHttpx:
|
|||
assert (
|
||||
_friendly_error(RuntimeError("unrelated")) == "An internal error occurred"
|
||||
)
|
||||
|
||||
|
||||
from routes.inference import ( # noqa: E402
|
||||
_drop_empty_assistant_sentinels,
|
||||
_openai_messages_for_passthrough,
|
||||
)
|
||||
|
||||
|
||||
class TestDropEmptyAssistantSentinels:
|
||||
def test_drops_empty_assistant_between_real_turns(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": ""},
|
||||
{"role": "user", "content": "again"},
|
||||
]
|
||||
out = _drop_empty_assistant_sentinels(msgs)
|
||||
assert out == [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "user", "content": "again"},
|
||||
]
|
||||
|
||||
def test_drops_assistant_with_no_content_key(self):
|
||||
# exclude_none=True strips the content key entirely; filter must catch this.
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant"},
|
||||
{"role": "user", "content": "ok"},
|
||||
]
|
||||
out = _drop_empty_assistant_sentinels(msgs)
|
||||
assert out == [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "user", "content": "ok"},
|
||||
]
|
||||
|
||||
def test_preserves_assistant_with_text(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello back"},
|
||||
]
|
||||
out = _drop_empty_assistant_sentinels(msgs)
|
||||
assert out == msgs
|
||||
|
||||
def test_preserves_assistant_with_tool_calls_only(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_1",
|
||||
"content": '{"t": 72}',
|
||||
},
|
||||
]
|
||||
out = _drop_empty_assistant_sentinels(msgs)
|
||||
assert out == msgs
|
||||
|
||||
def test_preserves_user_and_system_with_empty_content(self):
|
||||
# Filter scoped to role="assistant" only.
|
||||
msgs = [
|
||||
{"role": "system", "content": ""},
|
||||
{"role": "user", "content": ""},
|
||||
]
|
||||
out = _drop_empty_assistant_sentinels(msgs)
|
||||
assert out == msgs
|
||||
|
||||
def test_openai_messages_for_passthrough_drops_sentinel(self):
|
||||
"""End-to-end: Stop-sentinel must not reach the wire."""
|
||||
req = ChatCompletionRequest(
|
||||
model = "default",
|
||||
messages = [
|
||||
ChatMessage(role = "user", content = "hi"),
|
||||
ChatMessage(role = "assistant", content = ""),
|
||||
ChatMessage(role = "user", content = "again"),
|
||||
],
|
||||
)
|
||||
out = _openai_messages_for_passthrough(req)
|
||||
roles = [m["role"] for m in out]
|
||||
assert roles == ["user", "user"]
|
||||
for m in out:
|
||||
assert m.get("content"), m
|
||||
|
|
|
|||
241
studio/backend/tests/test_sandbox_tools.py
Normal file
241
studio/backend/tests/test_sandbox_tools.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
|
||||
|
||||
"""Tests for the sandboxed-Python AST policy in core/inference/tools.py."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(_BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_BACKEND_ROOT))
|
||||
|
||||
from core.inference.tools import _check_code_safety
|
||||
|
||||
|
||||
def _ok(code: str):
|
||||
assert _check_code_safety(code) is None, code
|
||||
|
||||
|
||||
def _blocked(code: str, *, expect_phrase: str):
|
||||
msg = _check_code_safety(code)
|
||||
assert msg is not None, code
|
||||
assert expect_phrase in msg, (expect_phrase, msg)
|
||||
|
||||
|
||||
class TestMetadataHostDenylist:
|
||||
def test_aws_imds_literal_blocked(self):
|
||||
_blocked(
|
||||
'import requests; requests.get("http://169.254.169.254/latest/meta-data/")',
|
||||
expect_phrase = "Blocked: cloud-metadata host",
|
||||
)
|
||||
|
||||
def test_gcp_metadata_dns_blocked(self):
|
||||
_blocked(
|
||||
'import requests; requests.get("http://metadata.google.internal/")',
|
||||
expect_phrase = "Blocked: cloud-metadata host",
|
||||
)
|
||||
|
||||
def test_alibaba_ecs_literal_blocked(self):
|
||||
_blocked(
|
||||
'import socket; s=socket.socket(); s.connect(("100.100.100.200", 80))',
|
||||
expect_phrase = "Blocked: cloud-metadata host",
|
||||
)
|
||||
|
||||
def test_ipv6_imds_literal_blocked(self):
|
||||
_blocked(
|
||||
'import urllib.request; urllib.request.urlopen("http://[fd00:ec2::254]/")',
|
||||
expect_phrase = "Blocked: cloud-metadata host",
|
||||
)
|
||||
|
||||
def test_metadata_link_local_prefix_blocked(self):
|
||||
_blocked(
|
||||
'import requests; requests.get("http://169.254.170.2/v3/")',
|
||||
expect_phrase = "Blocked: cloud-metadata host",
|
||||
)
|
||||
|
||||
|
||||
class TestTrustedHostAllowlist:
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
[
|
||||
"https://en.wikipedia.org/wiki/Python_(programming_language)",
|
||||
"https://fr.wikipedia.org/wiki/Python_(langage)",
|
||||
"https://www.google.com/search?q=foo",
|
||||
"https://duckduckgo.com/?q=foo",
|
||||
"https://huggingface.co/unsloth",
|
||||
"https://cdn-lfs.huggingface.co/repos/abc/def/file.bin",
|
||||
"https://raw.githubusercontent.com/foo/bar/main/README.md",
|
||||
"https://api.github.com/repos/foo/bar",
|
||||
"https://arxiv.org/abs/2401.12345",
|
||||
"https://export.arxiv.org/abs/2401.12345",
|
||||
"https://stackoverflow.com/questions/12345",
|
||||
"https://math.stackexchange.com/questions/12345",
|
||||
"https://developer.mozilla.org/en-US/docs/Web/JavaScript",
|
||||
"https://docs.python.org/3/library/asyncio.html",
|
||||
"https://pypi.org/project/requests/",
|
||||
"https://files.pythonhosted.org/packages/foo/bar.whl",
|
||||
"https://www.bbc.com/news",
|
||||
"https://api.weather.gov/points/40,-90",
|
||||
"https://numpy.org/doc/stable/",
|
||||
"https://pytorch.org/docs/stable/index.html",
|
||||
],
|
||||
)
|
||||
def test_trusted_host_passes(self, url):
|
||||
_ok(f"import requests; requests.get({url!r})")
|
||||
|
||||
def test_wikipedia_subdomain_passes(self):
|
||||
_ok(
|
||||
'import urllib.request; urllib.request.urlopen("https://m.en.wikipedia.org/wiki/Foo")'
|
||||
)
|
||||
|
||||
def test_hf_co_short_form_passes(self):
|
||||
_ok('import requests; requests.get("https://hf.co/unsloth/Qwen3.5-4B-GGUF")')
|
||||
|
||||
def test_github_io_pages_pass(self):
|
||||
_ok('import requests; requests.get("https://unslothai.github.io/")')
|
||||
|
||||
|
||||
class TestUntrustedHostBlock:
|
||||
def test_example_com_blocked(self):
|
||||
_blocked(
|
||||
'import requests; requests.get("https://example.com/")',
|
||||
expect_phrase = "Blocked: host not in sandbox allowlist",
|
||||
)
|
||||
|
||||
def test_random_blog_blocked(self):
|
||||
_blocked(
|
||||
'import urllib.request; urllib.request.urlopen("https://random-blog-host.example/")',
|
||||
expect_phrase = "Blocked: host not in sandbox allowlist",
|
||||
)
|
||||
|
||||
def test_socket_connect_random_host_blocked(self):
|
||||
_blocked(
|
||||
'import socket; s=socket.socket(); s.connect(("evil.example", 80))',
|
||||
expect_phrase = "Blocked: host not in sandbox allowlist",
|
||||
)
|
||||
|
||||
def test_dynamic_url_not_statically_blocked(self):
|
||||
# Static AST cannot resolve runtime URLs; bash blocklist is the fallback.
|
||||
_ok('import requests; url = "https://example.com/"; requests.get(url)')
|
||||
|
||||
|
||||
class TestHostNormalization:
|
||||
def test_trailing_dot_treated_same(self):
|
||||
_ok('import requests; requests.get("https://wikipedia.org./")')
|
||||
|
||||
def test_explicit_port_does_not_unblock_or_misblock(self):
|
||||
_ok('import requests; requests.get("https://en.wikipedia.org:443/wiki/Foo")')
|
||||
_blocked(
|
||||
'import requests; requests.get("https://example.com:8080/")',
|
||||
expect_phrase = "Blocked: host not in sandbox allowlist",
|
||||
)
|
||||
|
||||
def test_userinfo_at_does_not_smuggle_metadata_host(self):
|
||||
_blocked(
|
||||
'import requests; requests.get("https://wikipedia.org@169.254.169.254/latest/")',
|
||||
expect_phrase = "Blocked: cloud-metadata host",
|
||||
)
|
||||
|
||||
def test_uppercase_host_normalised(self):
|
||||
_ok('import requests; requests.get("https://EN.WIKIPEDIA.ORG/wiki/Foo")')
|
||||
|
||||
|
||||
class TestUploadDenylist:
|
||||
def test_requests_post_files_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"import requests\n"
|
||||
'requests.post("https://huggingface.co/api/repos/upload", '
|
||||
'files={"f": open("x.bin", "rb")})'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_requests_put_data_bytes_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"import requests\n"
|
||||
'requests.put("https://huggingface.co/api/repos/upload", '
|
||||
'data=b"\\x00\\x01\\x02")'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_requests_post_data_open_handle_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"import requests\n"
|
||||
'requests.post("https://huggingface.co/api/repos/upload", '
|
||||
'data=open("x.bin", "rb"))'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_httpx_post_files_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"import httpx\n"
|
||||
'httpx.post("https://huggingface.co/api/repos/upload", '
|
||||
'files={"f": open("x.bin", "rb")})'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_hf_api_upload_file_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"from huggingface_hub import HfApi\n"
|
||||
'HfApi().upload_file(path_or_fileobj="x.bin", '
|
||||
'path_in_repo="x.bin", repo_id="foo/bar")'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_hf_module_upload_folder_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"import huggingface_hub\n"
|
||||
'huggingface_hub.upload_folder(folder_path="./", repo_id="foo/bar")'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_hf_create_commit_method_blocked(self):
|
||||
_blocked(
|
||||
(
|
||||
"import huggingface_hub\n"
|
||||
"api = huggingface_hub.HfApi()\n"
|
||||
'api.create_commit(repo_id="foo/bar", operations=[])'
|
||||
),
|
||||
expect_phrase = "Blocked: file upload disallowed in sandbox",
|
||||
)
|
||||
|
||||
def test_plain_post_json_not_blocked(self):
|
||||
_ok(
|
||||
"import requests\n"
|
||||
'requests.post("https://api.weather.gov/lookup", json={"k": "v"})'
|
||||
)
|
||||
|
||||
|
||||
class TestSandboxCpuRlimitDefault:
|
||||
"""Pin the default so a regression below 600s without opt-in is caught."""
|
||||
|
||||
def test_default_cpu_s_is_600(self):
|
||||
src = (_BACKEND_ROOT / "core" / "inference" / "tools.py").read_text()
|
||||
assert 'UNSLOTH_STUDIO_SANDBOX_CPU_S", "600"' in src
|
||||
|
||||
def test_clone_newnet_removed(self):
|
||||
src = (_BACKEND_ROOT / "core" / "inference" / "tools.py").read_text()
|
||||
assert "_libc.unshare(0x40000000)" not in src
|
||||
# Explanatory comment retained.
|
||||
assert "CLONE_NEWNET" in src
|
||||
|
||||
|
||||
class TestMaxBodyDefault:
|
||||
def test_default_is_500_mb(self):
|
||||
src = (_BACKEND_ROOT / "main.py").read_text()
|
||||
assert 'UNSLOTH_STUDIO_MAX_BODY_MB", "500"' in src
|
||||
90
studio/backend/tests/test_studio_train_validation.py
Normal file
90
studio/backend/tests/test_studio_train_validation.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
# SPDX-License-Identifier: AGPL-3.0-only
|
||||
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
|
||||
|
||||
"""Pin TrainingStartRequest hyperparameter caps at the at-cap / over-cap boundary."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(_BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_BACKEND_ROOT))
|
||||
|
||||
from models.training import (
|
||||
_MAX_BATCH_SIZE,
|
||||
_MAX_LORA_ALPHA,
|
||||
_MAX_LORA_R,
|
||||
_MAX_SEQ_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
def _check_field(field_name: str, value):
|
||||
"""Run the field validator without constructing a full TrainingStartRequest."""
|
||||
from models.training import TrainingStartRequest
|
||||
|
||||
schema_field = TrainingStartRequest.model_fields[field_name]
|
||||
return TrainingStartRequest.__pydantic_validator__.validate_assignment(
|
||||
TrainingStartRequest.model_construct(),
|
||||
field_name,
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
class TestSeqLengthCap:
|
||||
def test_at_cap_accepts(self):
|
||||
_check_field("max_seq_length", _MAX_SEQ_LENGTH)
|
||||
assert _MAX_SEQ_LENGTH == 2_000_000
|
||||
|
||||
def test_over_cap_rejects(self):
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
_check_field("max_seq_length", _MAX_SEQ_LENGTH + 1)
|
||||
assert "max_seq_length" in str(exc.value)
|
||||
|
||||
def test_below_min_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("max_seq_length", 0)
|
||||
|
||||
|
||||
class TestBatchSizeCap:
|
||||
def test_at_cap_accepts(self):
|
||||
_check_field("batch_size", _MAX_BATCH_SIZE)
|
||||
assert _MAX_BATCH_SIZE == 4096
|
||||
|
||||
def test_over_cap_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("batch_size", _MAX_BATCH_SIZE + 1)
|
||||
|
||||
def test_below_min_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("batch_size", 0)
|
||||
|
||||
|
||||
class TestLoraRCap:
|
||||
def test_at_cap_accepts(self):
|
||||
_check_field("lora_r", _MAX_LORA_R)
|
||||
assert _MAX_LORA_R == 16_384
|
||||
|
||||
def test_over_cap_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("lora_r", _MAX_LORA_R + 1)
|
||||
|
||||
def test_below_min_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("lora_r", 0)
|
||||
|
||||
|
||||
class TestLoraAlphaCap:
|
||||
def test_at_cap_accepts(self):
|
||||
_check_field("lora_alpha", _MAX_LORA_ALPHA)
|
||||
assert _MAX_LORA_ALPHA == 32_768
|
||||
|
||||
def test_over_cap_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("lora_alpha", _MAX_LORA_ALPHA + 1)
|
||||
|
||||
def test_below_min_rejects(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_check_field("lora_alpha", 0)
|
||||
|
|
@ -28,7 +28,16 @@ from utils.models.model_config import (
|
|||
)
|
||||
|
||||
|
||||
def test_scan_trained_models_includes_lora_and_full_finetune_outputs(tmp_path: Path):
|
||||
def test_scan_trained_models_includes_lora_and_full_finetune_outputs(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
# resolve_output_dir refuses absolutes outside outputs_root; point it at tmp_path.
|
||||
from utils.models import model_config as _mc
|
||||
from utils.paths import storage_roots as _sr
|
||||
|
||||
monkeypatch.setattr(_sr, "outputs_root", lambda: tmp_path)
|
||||
monkeypatch.setattr(_mc, "outputs_root", lambda: tmp_path)
|
||||
|
||||
lora_dir = tmp_path / "unsloth_SmolLM-135M_1775412608"
|
||||
lora_dir.mkdir()
|
||||
(lora_dir / "adapter_config.json").write_text(
|
||||
|
|
|
|||
|
|
@ -276,21 +276,52 @@ def _clean_relative_path(
|
|||
return Path(*parts) if parts else Path()
|
||||
|
||||
|
||||
def _assert_contained(resolved: Path, root: Path) -> None:
|
||||
"""Raise ValueError if ``resolved`` realpaths outside ``root``."""
|
||||
try:
|
||||
resolved_real = Path(os.path.realpath(resolved))
|
||||
root_real = Path(os.path.realpath(root))
|
||||
except OSError as exc:
|
||||
raise ValueError(f"path resolution failed: {exc}") from exc
|
||||
try:
|
||||
resolved_real.relative_to(root_real)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"path escapes root: {resolved!s} -> {resolved_real!s} "
|
||||
f"is not under {root_real!s}"
|
||||
) from exc
|
||||
|
||||
|
||||
def resolve_under_root(
|
||||
path_value: str | None,
|
||||
*,
|
||||
root: Path,
|
||||
strip_prefixes: tuple[str, ...] = (),
|
||||
) -> Path:
|
||||
"""Resolve ``path_value`` and assert the result is under ``root``.
|
||||
|
||||
Absolutes are accepted only if already contained (so internal pre-resolved
|
||||
paths re-enter idempotently); user-facing schemas reject absolutes upstream.
|
||||
"""
|
||||
if not path_value or not str(path_value).strip():
|
||||
return root
|
||||
|
||||
path = Path(str(path_value).strip()).expanduser()
|
||||
raw = str(path_value).strip()
|
||||
if "\x00" in raw:
|
||||
raise ValueError("path may not contain null bytes")
|
||||
|
||||
path = Path(raw).expanduser()
|
||||
if ".." in path.parts:
|
||||
raise ValueError(f"path may not contain '..' segments: {raw!r}")
|
||||
|
||||
if path.is_absolute():
|
||||
_assert_contained(path, root)
|
||||
return path
|
||||
|
||||
cleaned = _clean_relative_path(str(path), strip_prefixes = strip_prefixes)
|
||||
return root / cleaned
|
||||
cleaned = _clean_relative_path(raw, strip_prefixes = strip_prefixes)
|
||||
candidate = root / cleaned
|
||||
_assert_contained(candidate, root)
|
||||
return candidate
|
||||
|
||||
|
||||
def resolve_output_dir(path_value: str | None = None) -> Path:
|
||||
|
|
@ -318,9 +349,22 @@ def resolve_tensorboard_dir(path_value: str | None = None) -> Path:
|
|||
|
||||
|
||||
def resolve_dataset_path(path_value: str) -> Path:
|
||||
path = Path(path_value).expanduser()
|
||||
raw = str(path_value or "").strip()
|
||||
if "\x00" in raw:
|
||||
raise ValueError("dataset path may not contain null bytes")
|
||||
path = Path(raw).expanduser()
|
||||
if ".." in path.parts:
|
||||
raise ValueError(f"dataset path may not contain '..' segments: {raw!r}")
|
||||
if path.is_absolute():
|
||||
return path
|
||||
for root_fn in (datasets_root, dataset_uploads_root, recipe_datasets_root):
|
||||
try:
|
||||
_assert_contained(path, root_fn())
|
||||
return path
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"dataset path must be relative or under a dataset root: {raw!r}"
|
||||
)
|
||||
|
||||
parts = [part for part in Path(path_value).parts if part not in ("", ".")]
|
||||
if parts[:2] == ["assets", "datasets"]:
|
||||
|
|
|
|||
|
|
@ -316,18 +316,47 @@ if code in (400, 422):
|
|||
else:
|
||||
fail(f"/api/auth/refresh without body returned {code} (expected 400/422)")
|
||||
|
||||
# Login burst with wrong password must keep returning 401, NOT 429.
|
||||
# Documents that no rate-limit / brute-force lockout exists today.
|
||||
# When/if we add one, this assertion updates in the same PR.
|
||||
all_401 = True
|
||||
for i in range(5):
|
||||
code, _ = login("definitely-wrong-password")
|
||||
if code != 401:
|
||||
all_401 = False
|
||||
fail(f"login burst attempt {i+1} returned {code} (expected 401)")
|
||||
|
||||
# Wrong-password burst: expect 401 until the per-IP bucket fills, then
|
||||
# 429 with Retry-After. Bucket cannot be reset between tests, so we
|
||||
# assert the observable invariant rather than a fixed transition index.
|
||||
def _login_with_headers(password: str) -> tuple[int, str | None]:
|
||||
"""Like ``login`` but returns ``(status, retry_after_header)``."""
|
||||
url = f"{BASE}/api/auth/login"
|
||||
data = json.dumps({"username": "unsloth", "password": password}).encode()
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data = data,
|
||||
method = "POST",
|
||||
headers = {"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout = 10) as r:
|
||||
return r.status, r.headers.get("Retry-After")
|
||||
except urllib.error.HTTPError as exc:
|
||||
return exc.code, exc.headers.get("Retry-After") if exc.headers else None
|
||||
|
||||
|
||||
codes = []
|
||||
retry_after = None
|
||||
for i in range(8):
|
||||
code, ra = _login_with_headers("definitely-wrong-password")
|
||||
codes.append(code)
|
||||
if code == 429:
|
||||
retry_after = ra
|
||||
break
|
||||
if all_401:
|
||||
ok("login burst (5x wrong pw) -> 401 each (no rate-limit, documented)")
|
||||
if code != 401:
|
||||
fail(f"login burst attempt {i+1} returned {code} (expected 401 or 429)")
|
||||
break
|
||||
|
||||
if 401 not in codes:
|
||||
fail(f"login burst never returned 401 before rate-limit (codes={codes})")
|
||||
elif 429 not in codes:
|
||||
fail(f"login burst never rate-limited after {len(codes)} wrongs (codes={codes})")
|
||||
elif retry_after is None:
|
||||
fail("429 response missing Retry-After header")
|
||||
else:
|
||||
ok(f"login burst -> 401x{codes.count(401)} then 429 with Retry-After={retry_after}")
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -593,12 +593,16 @@ def test_install_ps1_bakes_studio_root_id_into_launcher():
|
|||
def test_health_endpoint_exposes_studio_root_id_not_raw_path():
|
||||
"""studio/backend/main.py /api/health must expose studio_root_id (a
|
||||
hex digest) and NOT the raw studio_root path. Studio supports
|
||||
`-H 0.0.0.0`; an unauthenticated /api/health that returns the raw
|
||||
install path leaks username, home dir, workspace name, etc."""
|
||||
`-H 0.0.0.0`; a /api/health that returns the raw install path
|
||||
leaks username, home dir, workspace name, etc."""
|
||||
main_py = REPO_ROOT / "studio" / "backend" / "main.py"
|
||||
src = main_py.read_text()
|
||||
health_idx = src.index('@app.get("/api/health")')
|
||||
health_block = src[health_idx : health_idx + 1500]
|
||||
# Slice up to the next top-level @app. so a growing body stays in scope.
|
||||
next_app_idx = src.find("\n@app.", health_idx + 1)
|
||||
if next_app_idx == -1:
|
||||
next_app_idx = len(src)
|
||||
health_block = src[health_idx:next_app_idx]
|
||||
assert (
|
||||
'"studio_root_id"' in health_block
|
||||
), "/api/health must expose studio_root_id (hex digest)"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue