From eb8b0dee2e7b06727360add18e2759ab27849f93 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 24 Apr 2026 10:09:25 -0700 Subject: [PATCH] Studio: make stop button actually stop generation (#5069) * Studio: make stop button actually stop generation The UI stop button routes through assistant-ui's cancelRun, which aborts the frontend fetch. Four issues combined to let llama-server keep decoding long after the user clicked stop: 1. request.is_disconnected() does not fire reliably behind proxies (e.g. Colab) that don't propagate fetch aborts. 2. llama-server defaults n_predict to n_ctx when max_tokens is not sent, so a cancelled request keeps producing tokens up to 262144. 3. The httpx.Client pool keeps TCP keep-alive, so even a cleanly closed stream reuses the same connection and llama-server's liveness poll never sees a disconnect. 4. No explicit backend route to cancel - every cancel path relied on is_disconnected. Changes: - Add POST /api/inference/cancel keyed by session_id/completion_id, with a registry populated for the lifetime of each streaming response. - Have the frontend (chat-adapter.ts) POST /inference/cancel on AbortController abort, alongside the existing fetch teardown. - Send max_tokens=4096 + t_max_predict_ms=120000 as defaults on every outbound chat completion to llama-server; honoured by user overrides. - Disable httpx keep-alive on the streaming client so connection close reaches llama-server and its 1s liveness check fires. No behaviour changes for non-streaming paths or for existing callers that already pass max_tokens/session_id. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: harden stop-button cancel path and scope cancel route - Require at least one identifier for /api/inference/cancel so a missing thread id cannot silently cancel every in-flight generation. - Scope /cancel to a dedicated studio_router so it is not exposed under the /v1 OpenAI-compat prefix as a surprise endpoint. - Store a set of cancel events per key in _CANCEL_REGISTRY so concurrent requests on the same session_id do not overwrite each other, and deduplicate in _cancel_by_keys so the cancelled count reflects unique requests. - Always send session_id with chat completions (not only when tools are enabled) so non-tool GGUF streams register under it and are reachable from /cancel. - Register the non-GGUF stream_chunks path in the cancel registry too, so transformers-based stop-button works behind proxies that swallow fetch aborts. - Only apply the 2-minute t_max_predict_ms wall-clock cap when the caller did not pass max_tokens, so legitimate long generations on slow CPU/macOS/Windows supported installs are not silently truncated. - Remove the abort listener on normal stream completion so reused AbortSignals cannot fire a spurious cancel POST after the fact. * studio: close cancel-race and stale-cancel gaps in stop path - Register the cancel tracker before returning StreamingResponse so a stop POST that arrives during prefill / warmup / proxy buffering finds an entry in _CANCEL_REGISTRY. Cleanup now runs via a Starlette BackgroundTask instead of a finally inside the async generator body. - Add a per-run cancel_id on the frontend (crypto.randomUUID) and in ChatCompletionRequest so /api/inference/cancel matches one specific generation. Removes the stale-cancel bug where pressing stop then starting a new run in the same thread would cancel the retry. 
- Apply t_max_predict_ms unconditionally in all three llama-server payload builders (previously gated on max_tokens=None, which made it dead code for UI callers that always send params.maxTokens). Raise the default to 10 minutes so slow CPU / macOS / Windows installs are not cut off mid-generation. - Make _cancel_by_keys refuse empty input (return 0) so a future internal caller can not accidentally mass-cancel every in-flight request. - Accept cancel_id (primary), session_id, and completion_id on the /api/inference/cancel route. Unify the three streaming sites on the same _cancel_keys / _tracker variable names. - Annotate _CANCEL_REGISTRY as dict[str, set[threading.Event]]. * Add review tests for PR #5069 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: harden stop-button cancel semantics and wall-clock cap - Make /inference/cancel match cancel_id EXCLUSIVELY when supplied. Previously the handler iterated ('cancel_id','session_id','completion_id') and unioned matches, so a stale cancel POST carrying {cancel_id:old, session_id:thr} would still cancel a later run on the same thread via the shared session_id. cancel_id is now a per-run exclusive key; session_id / completion_id are only used as fallbacks when cancel_id is absent. - Close the early-cancel race. If /inference/cancel lands before the streaming handler reaches _TrackedCancel.__enter__() (stop clicked during prefill / warmup / proxy buffering), the cancel was silently dropped. Stash unmatched cancel_ids in _PENDING_CANCELS with a 30 s TTL; _TrackedCancel.__enter__() now replays any matching pending cancel by set()-ing the event immediately after registration. - Make t_max_predict_ms = _DEFAULT_T_MAX_PREDICT_MS conditional on max_tokens is None at all three llama-server payload sites. The cap is a safety net for callers who leave max_tokens unset (otherwise llama-server defaults n_predict to n_ctx, up to 262144). Callers who set an explicit max_tokens are already self-limiting and must not be silently truncated at 10 minutes on slow CPU / macOS / Windows legitimate long generations. - Guard each StreamingResponse return with try/except BaseException so _tracker.__exit__ runs even if StreamingResponse construction or any preceding statement raises between _tracker.__enter__() and the BackgroundTask attachment. Prevents a registry leak on that narrow window. * studio: close TOCTOU race and restore wall-clock backstop on UI path - Close TOCTOU race in the pending-cancel mechanism. The previous fix split cancel_inference's (cancel_by_keys + remember_pending_cancel) and _TrackedCancel.__enter__'s (register + consume_pending) into four separate lock acquisitions. Under contention a cancel POST could acquire-then-release the lock, find the registry empty, and stash ONLY AFTER __enter__ had already registered and consumed an empty pending map -- silently dropping the cancel. Both call sites now do their work inside a single _CANCEL_LOCK critical section, via the new atomic helper _cancel_by_cancel_id_or_stash() and an inlined consume-pending step in __enter__. Reproduced the race under forced interleaving pre-fix; 0/2000 drops post-fix under parallel stress. - Apply t_max_predict_ms UNCONDITIONALLY at all three llama-server payload sites. 
The previous iteration gated the cap on `max_tokens is None`, which turned out to be dead code on the primary Studio UI path: chat-adapter.ts sets maxTokens=loadResp.context_length after every model load, so every chat request carries an explicit max_tokens and the wall-clock safety net never fired. The cap's original purpose is to bound stuck decodes regardless of the token budget; it must always apply. - Raise _DEFAULT_T_MAX_PREDICT_MS from 10 minutes to 1 hour. 10 minutes was too aggressive for legitimate slow-CPU chat responses (a 4096-token reply at 2 tok/s takes ~34 min); 1 hour accommodates that and still catches genuine zombie decodes. - Prune _PENDING_CANCELS inside _cancel_by_keys as well, so stashed entries expire proportionally to overall cancel traffic rather than only to cancel_id-specific POSTs. * studio: trim verbose comments and docstrings in cancel path * studio/llama_cpp: drop upstream PR hashes from benchmark comment * Add review tests for Studio stop button * Consolidate review tests for Studio stop button * Align cancel-route test with exclusive cancel_id semantics * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: move cancel cleanup to generator finally; drop dead helper - Move _tracker.__exit__ from Starlette BackgroundTask into each streaming generator's finally block. Starlette skips the background callback when stream_response raises (OSError / ClientDisconnect), which leaked _CANCEL_REGISTRY entries on abrupt disconnect. - Check cancel_event.is_set() at the top of each GGUF while loop so a pending-replay cancel falls through to final_chunk + [DONE] instead of propagating GeneratorExit out of _stream_with_retry. - Remove unused _remember_pending_cancel; _cancel_by_cancel_id_or_stash superseded it. * Add review tests for Studio stop-button * studio: wire audio-input stream into cancel registry - Register cancel_event with _TrackedCancel on the audio-input streaming path so POST /api/inference/cancel can stop whisper / audio-input GGUF runs. Previously the registry stayed empty on this branch, so the stop button returned {"cancelled":0} and the decode ran to completion. - Apply the same finally-based cleanup and pre-iteration cancel-event check used on the other three streaming paths. - Update the _CANCEL_REGISTRY block comment to list cancel_id as the primary key (was stale "session_id preferred"). * Consolidate review tests for Studio stop-button cancel flow - Merge the 6 behavioral tests from test_stream_cleanup_on_disconnect.py (finally cleanup on normal/exception/aclose, pre-set cancel_event pattern, and its regressions) into test_stream_cancel_registration_timing.py, which is the PR's existing file covering the same area. - Extend structural invariants to include audio_input_stream alongside the three GGUF / Unsloth streaming generators: no _tracker.__enter__ inside the async gen body, cleanup via try/finally, no background= on StreamingResponse. - Delete test_stream_cleanup_on_disconnect.py (now empty). 
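For reviewers, a condensed sketch of the registry pattern the notes above describe -- illustrative only; the names mirror the patch, but the real code in studio/backend/routes/inference.py also prunes the pending map by TTL inside every cancel call, deduplicates events across fallback keys, and supports the session_id / completion_id fallbacks:

    import threading
    import time

    _CANCEL_REGISTRY: dict[str, set[threading.Event]] = {}
    _PENDING_CANCELS: dict[str, float] = {}  # cancel_id -> stash time
    _CANCEL_LOCK = threading.Lock()
    _PENDING_CANCEL_TTL_S = 30.0

    def cancel_or_stash(cancel_id: str) -> int:
        # Atomic lookup-or-stash: signal a registered run, or remember the
        # cancel so a stream that registers a moment later replays it.
        with _CANCEL_LOCK:
            bucket = _CANCEL_REGISTRY.get(cancel_id)
            if bucket:
                events = set(bucket)
            else:
                _PENDING_CANCELS[cancel_id] = time.monotonic()
                events = set()
        for ev in events:
            ev.set()
        return len(events)

    class TrackedCancel:
        # Holds cancel_event in the registry for the stream's lifetime.
        def __init__(self, event: threading.Event, *keys: str):
            self.event = event
            self.keys = tuple(k for k in keys if k)

        def __enter__(self) -> threading.Event:
            # Register + consume-pending in one critical section so it
            # cannot interleave with cancel_or_stash above (the TOCTOU fix).
            replay = False
            with _CANCEL_LOCK:
                now = time.monotonic()
                for k in self.keys:
                    _CANCEL_REGISTRY.setdefault(k, set()).add(self.event)
                for k in self.keys:
                    ts = _PENDING_CANCELS.pop(k, None)
                    if ts is not None and now - ts <= _PENDING_CANCEL_TTL_S:
                        replay = True
            if replay:
                self.event.set()
            return self.event

        def __exit__(self, *exc) -> bool:
            with _CANCEL_LOCK:
                for k in self.keys:
                    bucket = _CANCEL_REGISTRY.get(k)
                    if bucket:
                        bucket.discard(self.event)
                        if not bucket:
                            _CANCEL_REGISTRY.pop(k, None)
            return False

A streaming handler registers before returning its response and cleans up in the generator's finally, so an abrupt disconnect cannot leak the entry:

    def stream(cancel_id: str, session_id: str):
        cancel_event = threading.Event()
        tracker = TrackedCancel(cancel_event, cancel_id, session_id)
        tracker.__enter__()

        def chunks():
            try:
                for token in ("Hello", " ", "world"):
                    if cancel_event.is_set():
                        break
                    yield f"data: {token}\n\n"
            finally:
                tracker.__exit__(None, None, None)

        return chunks()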
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: make cancel-via-POST interrupt Unsloth and audio-input streams Close two remaining gaps in the stop-button cancellation wiring: - stream_chunks (Unsloth path): add a top-of-loop cancel_event check and call backend.reset_generation_state() so cancel POSTs flush GPU state and close the SSE cleanly instead of relying on request.is_disconnected (which does not fire through proxies like Colab's). - audio_input_stream: run the synchronous audio_input_generate() via asyncio.to_thread so blocking whisper chunks do not freeze the event loop, matching the pattern already used by the GGUF streaming paths. * Add review tests for Studio stop-button cancel flow * Consolidate review tests for Studio stop-button cancel flow - Delete standalone test_cancel_registry.py at repo root: tests duplicated test_cancel_atomicity.py / test_cancel_id_wiring.py and re-implemented registry primitives inline (scaffolding). - Extend tests/studio/test_stream_cancel_registration_timing.py with regression guards for the iter-1 cancel-loop fixes: structural: each streaming generator checks cancel_event in its loop; audio_input_stream offloads next() via asyncio.to_thread; stream_chunks cancel branch calls reset_generation_state(). runtime: Unsloth loop breaks on external cancel and resets state; audio loop stays responsive under blocking next(); both loops emit zero tokens on pre-set cancel (replay path). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: extend stop-path to passthrough streams; tighten wall-clock cap - Lower _DEFAULT_T_MAX_PREDICT_MS from 1 hour to 10 minutes so the wall-clock backstop actually bounds runaway decodes when cancel signaling fails. - Wire _TrackedCancel and cancel_event.is_set() into _openai_passthrough_stream and _anthropic_passthrough_stream and disable httpx keepalive so stop requests from /v1 and /v1/messages tool-calling clients reach llama-server. - Apply t_max_predict_ms to the tool-passthrough request body so the backstop covers passthrough paths as well. - Symmetric pre-registration stash for session_id/completion_id cancels (_cancel_by_keys_or_stash) so early cancels by those keys replay on later registration like cancel_id. - Drop dead except BaseException guards around StreamingResponse() at four streaming sites; cleanup lives in the generator's finally. * studio: harden cancel registry against ghost-cancel and leak paths - Revert the session_id/completion_id stash in the fallback cancel helper. session_id is thread-scoped and reused across runs, so stashing it on an unmatched POST would fire cancel_event for the user's next unrelated request via _TrackedCancel.__enter__. cancel_id remains the only per-run unique key that gets stashed. - Default max_tokens to _DEFAULT_MAX_TOKENS in the tool-passthrough body. Mirror the direct GGUF path so OpenAI/Anthropic passthrough callers who omit max_tokens get the same zombie-decode cap instead of relying on the wall-clock backstop alone. - Wrap _openai_passthrough_stream setup with an outer try/except BaseException. The inner except httpx.RequestError does not catch asyncio.CancelledError at await client.send, which would otherwise leave _tracker registered in _CANCEL_REGISTRY indefinitely. - Frontend stop POST uses plain fetch + manual Authorization header instead of authFetch. 
A 401 on the cancel POST no longer refreshes tokens or redirects the user to the login page mid-stop. * Add review tests for Studio stop-button cancel flow * studio: trim comments on stop-button review changes Collapse multi-paragraph rationale blocks on the cancel registry, _openai_passthrough_stream, and the frontend onAbortCancel handler into one-line explanations of why the non-obvious behaviour exists. Drop authFetch import that became unused when the cancel POST switched to plain fetch. * Consolidate review tests for Studio stop-button cancel flow Move review-added tests out of test_cancel_dispatch_edges.py into the existing PR test files that already cover the same areas: - backend registry fan-out / exclusivity / idempotency / falsy-keys edge cases moved into tests/studio/test_cancel_atomicity.py - frontend plain-fetch (not authFetch) + manual Authorization header moved into tests/studio/test_cancel_id_wiring.py Delete the now-empty test_cancel_dispatch_edges.py. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Studio: stop default-capping responses at 4096 tokens (follow-up to #5069) (#5174) * Studio: stop default-capping responses at 4096 tokens Follow-up to #5069. The 4096 default introduced for runaway-decode defense silently truncates any caller that omits max_tokens. The Studio chat UI sets params.maxTokens = loadResp.context_length after a GGUF load, so it's fine, but every other consumer is not: - OpenAI-API direct callers (/v1/chat/completions, /v1/responses, /v1/messages, /v1/completions) where the OpenAI default is effectively unlimited per response. langchain, llama-index, raw curl, and the openai SDK all rely on that. - Reasoning models. Qwen3 / gpt-oss reasoning traces routinely exceed 4096 tokens before the model emits a single visible content token. The user sees the trace cut off mid-thought. - Long-form generation ("write a chapter", "produce a full SVG"). Reproduced on this branch: gemma-4-E2B-it-GGUF Q8_0, prompt asking for a 10000-word story, no max_tokens in the request: finish_reason: stop (misleading -- should be 'length') content_chars: 19772 content_tail: ...'a comforting, yet immense, pressure.\n\n*"' Body ended mid-sentence on a stray opening quote, right at the 4096 token mark. After this patch the same request returns 38357 chars ending with '...held in a perfect, dynamic equilibrium.' -- a natural stop, not a truncation. Implementation: rename the constant to _DEFAULT_MAX_TOKENS_FLOOR and set it to 32768. Each call site now uses the model's effective context length when known, falling back to the floor: default_cap = self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR The 10-minute t_max_predict_ms wall-clock backstop from #5069 is preserved as the second line of defense. Plumbed _build_passthrough_payload + _build_openai_passthrough_body through the routes layer so the Anthropic and OpenAI passthrough paths also respect the model's context length. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Studio: cancel passthrough streams during llama-server prefill + route through apiUrl for Tauri Three reviewer-flagged correctness gaps in the stop-button mechanism. 1) `_openai_passthrough_stream` could not honor cancel during prefill. 
The cancel check ran inside the `async for raw_line in lines_iter` body, so a cancel POST that arrived before llama-server emitted the first SSE line was unobservable until prefill completed. With a long prompt under proxy/Colab conditions -- the exact target scenario for this PR -- that left the model decoding for a long time after the user clicked Stop. Add an asyncio watcher task that closes `resp` as soon as `cancel_event` is set, raising in `aiter_lines` so the generator can exit. The watcher polls a threading.Event because the cancel registry is keyed by threading.Event for the synchronous /cancel handler. 2) `_anthropic_passthrough_stream` had the same blocking-prefill pattern. Same fix. 3) The frontend's stop-button cancel POST used a bare relative `fetch("/api/inference/cancel", ...)`, which targets the webview origin in Tauri production builds (where the backend is at `http://127.0.0.1:8888`). Route through the existing `apiUrl()` helper from `lib/api-base.ts` to match every other Studio call. Browser/dev builds get the empty base, so behavior is unchanged there. Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: cancel during prefill terminates within ~250ms on both passthrough paths (was 145s+ on the Anthropic path before this change), and the standard non-passthrough chat path still cancels with no regression. * Studio: log cancel-body parse errors instead of silently swallowing Reviewer-flagged defensive logging gap. The bare `except Exception: pass` in `cancel_inference` would mask malformed payloads that hint at a buggy client or a transport issue. Log at debug so future investigation isn't left guessing whether `body={}` came from a missing body or a parse failure. Behavior is unchanged: an unparseable body still falls through to the empty-dict path and the cancel call returns `{"cancelled": 0}`. * Studio: Anthropic passthrough cancel parity with OpenAI passthrough Two reviewer-flagged consistency gaps in the cancel surface for /v1/messages. 1) Anthropic passthrough did not register cancel_id, so a per-run cancel POST (the cleanest Studio-style cancel path) silently missed when the route hit `_anthropic_passthrough_stream`. The OpenAI passthrough has registered (cancel_id, session_id, completion_id) since this PR was first opened; mirror that here. Also add `cancel_id` to `AnthropicMessagesRequest` so the route handler can plumb it through. 2) The cancel handler's fallback key list checked only completion_id and session_id, never message_id. Anthropic clients that send their native `id` (returned in the SSE message_start event) for cancel had no way to hit the registry. Add message_id to the fallback list. Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: P2 now cancels by cancel_id in 137ms (was hanging pre-fix), and the new P2b case cancels by message_id in 77ms. P1 (OpenAI) and P3 (standard chat) still pass with no regression. 
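For reviewers, a minimal sketch of that watcher pattern -- simplified from _openai_passthrough_stream / _anthropic_passthrough_stream in this patch; SSE error frames, upstream status handling, and client cleanup are omitted:

    import asyncio
    import threading

    import httpx

    async def await_cancel_then_close(cancel_event: threading.Event, resp) -> None:
        # Poll the threading.Event (the registry stores threading.Events so the
        # synchronous /cancel handler can call .set()) and close the response,
        # which unblocks aiter_lines() even while llama-server is still prefilling.
        try:
            while not cancel_event.is_set():
                await asyncio.sleep(0.05)
            try:
                await resp.aclose()
            except Exception:
                pass
        except asyncio.CancelledError:
            return

    async def passthrough_stream(client: httpx.AsyncClient, req, cancel_event):
        resp = await client.send(req, stream = True)
        watcher = asyncio.create_task(await_cancel_then_close(cancel_event, resp))
        try:
            async for raw_line in resp.aiter_lines():
                if cancel_event.is_set():
                    break
                yield raw_line + "\n\n"
        except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError):
            # The watcher closed resp on cancel; anything else is a real error.
            if not cancel_event.is_set():
                raise
        finally:
            watcher.cancel()
            try:
                await watcher
            except asyncio.CancelledError:
                pass
            try:
                await resp.aclose()
            except Exception:
                pass

The 50 ms poll bounds how stale a prefill cancel can be; once streaming has started, the in-loop is_set() check usually fires first.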
--------- Co-authored-by: danielhanchen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com> Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com> --- studio/backend/core/inference/llama_cpp.py | 57 +- studio/backend/main.py | 4 + studio/backend/models/inference.py | 5 + studio/backend/routes/__init__.py | 2 + studio/backend/routes/inference.py | 495 +++++++++--- .../src/features/chat/api/chat-adapter.ts | 41 +- .../frontend/src/features/chat/types/api.ts | 1 + tests/studio/test_cancel_atomicity.py | 289 +++++++ tests/studio/test_cancel_id_wiring.py | 169 +++++ tests/studio/test_llama_cpp_wall_clock_cap.py | 123 +++ .../test_stream_cancel_registration_timing.py | 718 ++++++++++++++++++ 11 files changed, 1785 insertions(+), 119 deletions(-) create mode 100644 tests/studio/test_cancel_atomicity.py create mode 100644 tests/studio/test_cancel_id_wiring.py create mode 100644 tests/studio/test_llama_cpp_wall_clock_cap.py create mode 100644 tests/studio/test_stream_cancel_registration_timing.py diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 35ef90ee3..81ff9ae4e 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -53,6 +53,21 @@ _INTENT_SIGNAL = re.compile( r")" ) _MAX_REPROMPTS = 3 + +# Without max_tokens, llama-server defaults to n_predict = n_ctx (up to +# 262144 for Qwen3.5), producing many-minute zombie decodes when cancel +# fails. t_max_predict_ms is a wall-clock backstop applied unconditionally, +# but the llama.cpp README notes it ONLY fires after a newline has been +# generated -- a model stuck in a long unbroken non-newline sequence is +# unbounded by it. So we still want a token cap as the front-line limiter. +# +# The cap is the model's effective context length when we know it, +# falling back to a generous floor when metadata is unavailable. 4096 was +# too low: Qwen3 / gpt-oss reasoning traces routinely exceed it, and any +# OpenAI-API caller that omits max_tokens (langchain, llama-index, raw +# curl) sees responses silently truncated mid-sentence. +_DEFAULT_MAX_TOKENS_FLOOR = 32768 +_DEFAULT_T_MAX_PREDICT_MS = 600_000 # 10 min _REPROMPT_MAX_CHARS = 2000 # ── Pre-compiled patterns for GGUF shard detection ─────────── @@ -1636,7 +1651,7 @@ class LlamaCppBackend: # existing text (code refactoring, summarization, reasoning). # For general chat with low repetition, overhead is ~5 ms. # - # Benchmarks from llama.cpp PRs #18471, #19164: + # Benchmarks from upstream llama.cpp speculative-decoding PRs: # Scenario | Without | With | Speedup # gpt-oss-120b code refactor | 181 t/s | 446 t/s | 2.5x # Qwen3-235B offloaded | 12 t/s | 21 t/s | 1.8x @@ -2549,8 +2564,15 @@ class LlamaCppBackend: ) if _reasoning_kw is not None: payload["chat_template_kwargs"] = _reasoning_kw - if max_tokens is not None: - payload["max_tokens"] = max_tokens + # Default cap to the model's effective context length when known, + # otherwise the conservative floor. The wall-clock backstop below + # keeps a stuck model from running indefinitely either way. 
+ payload["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) + ) + payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: payload["stop"] = stop payload["stream_options"] = {"include_usage": True} @@ -2570,7 +2592,9 @@ class LlamaCppBackend: _auth_headers = ( {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None ) - with httpx.Client(timeout = stream_timeout) as client: + with httpx.Client( + timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0) + ) as client: with self._stream_with_retry( client, url, @@ -2769,8 +2793,12 @@ class LlamaCppBackend: ) if _reasoning_kw is not None: payload["chat_template_kwargs"] = _reasoning_kw - if max_tokens is not None: - payload["max_tokens"] = max_tokens + payload["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) + ) + payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: payload["stop"] = stop @@ -2809,7 +2837,10 @@ class LlamaCppBackend: write = 10, pool = 10, ) - with httpx.Client(timeout = stream_timeout) as client: + with httpx.Client( + timeout = stream_timeout, + limits = httpx.Limits(max_keepalive_connections = 0), + ) as client: with self._stream_with_retry( client, url, @@ -3422,8 +3453,12 @@ class LlamaCppBackend: ) if _reasoning_kw is not None: stream_payload["chat_template_kwargs"] = _reasoning_kw - if max_tokens is not None: - stream_payload["max_tokens"] = max_tokens + stream_payload["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) + ) + stream_payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: stream_payload["stop"] = stop stream_payload["stream_options"] = {"include_usage": True} @@ -3442,7 +3477,9 @@ class LlamaCppBackend: _auth_headers = ( {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None ) - with httpx.Client(timeout = stream_timeout) as client: + with httpx.Client( + timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0) + ) as client: with self._stream_with_retry( client, url, diff --git a/studio/backend/main.py b/studio/backend/main.py index 05adcaa2e..9212404b3 100644 --- a/studio/backend/main.py +++ b/studio/backend/main.py @@ -62,6 +62,7 @@ from routes import ( datasets_router, export_router, inference_router, + inference_studio_router, models_router, training_history_router, training_router, @@ -207,6 +208,9 @@ app.include_router(auth_router, prefix = "/api/auth", tags = ["auth"]) app.include_router(training_router, prefix = "/api/train", tags = ["training"]) app.include_router(models_router, prefix = "/api/models", tags = ["models"]) app.include_router(inference_router, prefix = "/api/inference", tags = ["inference"]) +# Studio-only inference endpoints (cancel, etc.) are intentionally NOT +# exposed on the /v1 OpenAI-compat prefix below. +app.include_router(inference_studio_router, prefix = "/api/inference", tags = ["inference"]) # OpenAI-compatible endpoints: mount the same inference router at /v1 # so external tools (Open WebUI, SillyTavern, etc.) 
can use the diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index e5b037755..bf0177efb 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -531,6 +531,10 @@ class ChatCompletionRequest(BaseModel): None, description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.", ) + cancel_id: Optional[str] = Field( + None, + description = "[x-unsloth] Per-request cancellation token. Frontend sends a fresh UUID per run so /inference/cancel matches one specific generation.", + ) # ── Streaming response chunks ──────────────────────────────────── @@ -992,6 +996,7 @@ class AnthropicMessagesRequest(BaseModel): enable_tools: Optional[bool] = None enabled_tools: Optional[list[str]] = None session_id: Optional[str] = None + cancel_id: Optional[str] = None model_config = {"extra": "allow"} diff --git a/studio/backend/routes/__init__.py b/studio/backend/routes/__init__.py index e79f6553f..cf4586281 100644 --- a/studio/backend/routes/__init__.py +++ b/studio/backend/routes/__init__.py @@ -8,6 +8,7 @@ API Routes from routes.training import router as training_router from routes.models import router as models_router from routes.inference import router as inference_router +from routes.inference import studio_router as inference_studio_router from routes.datasets import router as datasets_router from routes.auth import router as auth_router from routes.data_recipe import router as data_recipe_router @@ -18,6 +19,7 @@ __all__ = [ "training_router", "models_router", "inference_router", + "inference_studio_router", "datasets_router", "auth_router", "data_recipe_router", diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index ed331a566..cf3b37a2f 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -113,7 +113,12 @@ if str(backend_path) not in sys.path: # Import backend functions try: from core.inference import get_inference_backend - from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags + from core.inference.llama_cpp import ( + LlamaCppBackend, + _DEFAULT_MAX_TOKENS_FLOOR, + _DEFAULT_T_MAX_PREDICT_MS, + detect_reasoning_flags, + ) from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults @@ -122,7 +127,12 @@ except ImportError: if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from core.inference import get_inference_backend - from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags + from core.inference.llama_cpp import ( + LlamaCppBackend, + _DEFAULT_MAX_TOKENS_FLOOR, + _DEFAULT_T_MAX_PREDICT_MS, + detect_reasoning_flags, + ) from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults @@ -185,6 +195,126 @@ import numpy as np from datetime import date as _date router = APIRouter() +# Studio-only router (not mounted on /v1 OpenAI-compat). +studio_router = APIRouter() + + +# Cancel registry. Proxies (e.g. Colab) can swallow client fetch aborts +# so is_disconnected() never fires. POST /inference/cancel looks up +# in-flight cancel_events here by cancel_id (per-run) or session_id / +# completion_id (fallbacks). 
+_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {} +_CANCEL_LOCK = threading.Lock() + +# Cancel POSTs that arrive before registration are stashed; the next +# matching __enter__ replays set() within the TTL. +_PENDING_CANCELS: dict[str, float] = {} +_PENDING_CANCEL_TTL_S = 30.0 + + +def _prune_pending(now: float) -> None: + for k in [ + k for k, ts in _PENDING_CANCELS.items() if now - ts > _PENDING_CANCEL_TTL_S + ]: + _PENDING_CANCELS.pop(k, None) + + +class _TrackedCancel: + """Register cancel_event in _CANCEL_REGISTRY for the block's duration.""" + + def __init__(self, event: threading.Event, *keys): + self.event = event + self.keys = tuple(k for k in keys if k) + + def __enter__(self): + # Register + consume-pending must be one critical section to close + # the TOCTOU race against a concurrent cancel POST. + should_cancel = False + with _CANCEL_LOCK: + for k in self.keys: + _CANCEL_REGISTRY.setdefault(k, set()).add(self.event) + now = time.monotonic() + _prune_pending(now) + for k in self.keys: + if k and _PENDING_CANCELS.pop(k, None) is not None: + should_cancel = True + if should_cancel: + self.event.set() + return self.event + + def __exit__(self, *exc): + with _CANCEL_LOCK: + for k in self.keys: + bucket = _CANCEL_REGISTRY.get(k) + if bucket is None: + continue + bucket.discard(self.event) + if not bucket: + _CANCEL_REGISTRY.pop(k, None) + return False + + +def _cancel_by_keys(keys) -> int: + """Set cancel_event for matching registry entries; no stash. + session_id/completion_id are shared across runs on the same thread, + so stashing them would ghost-cancel the user's next request. Only + cancel_id is per-run unique (see _cancel_by_cancel_id_or_stash).""" + if not keys: + return 0 + events: set[threading.Event] = set() + with _CANCEL_LOCK: + _prune_pending(time.monotonic()) + for k in keys: + bucket = _CANCEL_REGISTRY.get(k) + if bucket: + events.update(bucket) + for ev in events: + ev.set() + return len(events) + + +def _cancel_by_cancel_id_or_stash(cancel_id: str) -> int: + """Atomic lookup-or-stash; pairs with _TrackedCancel.__enter__ to + close the TOCTOU race.""" + now = time.monotonic() + events: set[threading.Event] = set() + with _CANCEL_LOCK: + _prune_pending(now) + bucket = _CANCEL_REGISTRY.get(cancel_id) + if bucket: + events.update(bucket) + else: + _PENDING_CANCELS[cancel_id] = now + for ev in events: + ev.set() + return len(events) + + +async def _await_cancel_then_close(cancel_event, resp) -> None: + """Watch a threading.Event from asyncio and close ``resp`` when it fires. + + Used by the passthrough streamers so a /cancel POST can interrupt + while the async iterator is blocked waiting for llama-server prefill. + Without this watcher the in-loop ``cancel_event.is_set()`` check is + unreachable until the first SSE chunk arrives, which is exactly the + proxy/Colab scenario the cancel POST exists to handle. + + Polls a threading.Event because the cancel registry is keyed by + threading.Event so the synchronous /cancel handler can call .set(). + 50ms cadence adds at most that much latency to a prefill cancel; the + common-case streaming cancel path still observes the event in the + iterator's first iteration after the next chunk. 
+ """ + try: + while not cancel_event.is_set(): + await asyncio.sleep(0.05) + try: + await resp.aclose() + except Exception: + pass + except asyncio.CancelledError: + return + # Appended to tool-use nudge to discourage plan-without-action _TOOL_ACTION_NUDGE = ( @@ -706,6 +836,48 @@ async def unload_model( raise HTTPException(status_code = 500, detail = f"Failed to unload model: {str(e)}") +@studio_router.post("/cancel") +async def cancel_inference( + request: Request, + current_subject: str = Depends(get_current_subject), +): + """Cancel in-flight inference requests. + + Body (JSON, at least one key required): + cancel_id - preferred: per-run UUID, matched exclusively. + session_id - fallback when cancel_id is absent. + completion_id - fallback when cancel_id is absent. + + A cancel_id arriving before its stream registers is stashed briefly + and replayed on registration. Returns {"cancelled": N}. + """ + try: + body = await request.json() + if not isinstance(body, dict): + body = {} + except Exception as e: + logger.debug("Failed to parse cancel request body: %s", e) + body = {} + + cancel_id = body.get("cancel_id") + if isinstance(cancel_id, str) and cancel_id: + return {"cancelled": _cancel_by_cancel_id_or_stash(cancel_id)} + + keys = [] + # `message_id` is the Anthropic passthrough's per-run identifier -- + # included so /v1/messages clients can cancel by their native id. + for k in ("completion_id", "session_id", "message_id"): + v = body.get(k) + if isinstance(v, str) and v: + keys.append(v) + + if not keys: + return {"cancelled": 0} + + n = _cancel_by_keys(keys) + return {"cancelled": n} + + @router.post("/generate/stream") async def generate_stream( request: GenerateRequest, @@ -1169,6 +1341,9 @@ async def openai_chat_completions( ) if payload.stream: + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() async def audio_input_stream(): try: @@ -1185,10 +1360,17 @@ async def openai_chat_completions( ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" - for chunk_text in audio_input_generate(): + gen = audio_input_generate() + _DONE = object() + while True: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() return + chunk_text = await asyncio.to_thread(next, gen, _DONE) + if chunk_text is _DONE: + break if chunk_text: chunk = ChatCompletionChunk( id = completion_id, @@ -1221,6 +1403,8 @@ async def openai_chat_completions( f"Error during audio input streaming: {e}", exc_info = True ) yield f"data: {json.dumps({'error': {'message': _friendly_error(e), 'type': 'server_error'}})}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( audio_input_stream(), @@ -1466,6 +1650,10 @@ async def openai_chat_completions( _tool_sentinel = object() + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() + async def gguf_tool_stream(): try: first_chunk = ChatCompletionChunk( @@ -1488,6 +1676,8 @@ async def openai_chat_completions( _stream_usage = None _stream_timings = None while True: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() return @@ -1595,6 +1785,8 @@ async def openai_chat_completions( }, } yield f"data: {json.dumps(error_chunk)}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( gguf_tool_stream(), @@ -1628,6 +1820,9 @@ async def 
openai_chat_completions( _gguf_sentinel = object() if payload.stream: + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() async def gguf_stream_chunks(): try: @@ -1652,6 +1847,8 @@ async def openai_chat_completions( _stream_usage = None _stream_timings = None while True: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() return @@ -1735,6 +1932,8 @@ async def openai_chat_completions( }, } yield f"data: {json.dumps(error_chunk)}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( gguf_stream_chunks(), @@ -1834,6 +2033,9 @@ async def openai_chat_completions( # ── Streaming response ──────────────────────────────────────── if payload.stream: + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() async def stream_chunks(): try: @@ -1861,6 +2063,9 @@ async def openai_chat_completions( loop = asyncio.get_event_loop() gen = generate() while True: + if cancel_event.is_set(): + backend.reset_generation_state() + break # next(gen, _DONE) returns _DONE instead of raising # StopIteration — StopIteration cannot propagate # through asyncio futures (Python limitation). @@ -1916,6 +2121,8 @@ async def openai_chat_completions( }, } yield f"data: {json.dumps(error_chunk)}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( stream_chunks(), @@ -2596,7 +2803,9 @@ async def _responses_stream( ), ) - body = _build_openai_passthrough_body(chat_req) + body = _build_openai_passthrough_body( + chat_req, backend_ctx = llama_backend.context_length + ) target_url = f"{llama_backend.base_url}/v1/chat/completions" async def event_generator(): @@ -3081,6 +3290,8 @@ async def anthropic_messages( repetition_penalty = repetition_penalty, presence_penalty = presence_penalty, tool_choice = openai_tool_choice, + session_id = payload.session_id, + cancel_id = payload.cancel_id, ) return await _anthropic_passthrough_non_streaming( llama_backend, @@ -3441,6 +3652,7 @@ def _build_passthrough_payload( repetition_penalty = None, presence_penalty = None, tool_choice = "auto", + backend_ctx = None, ): body = { "messages": openai_messages, @@ -3453,8 +3665,12 @@ def _build_passthrough_payload( } if stream: body["stream_options"] = {"include_usage": True} - if max_tokens is not None: - body["max_tokens"] = max_tokens + body["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (backend_ctx or _DEFAULT_MAX_TOKENS_FLOOR) + ) + body["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: body["stop"] = stop if min_p is not None: @@ -3484,6 +3700,8 @@ async def _anthropic_passthrough_stream( repetition_penalty = None, presence_penalty = None, tool_choice = "auto", + session_id = None, + cancel_id = None, ): """Streaming client-side pass-through: forward tools to llama-server and translate its streaming response to Anthropic SSE without executing anything.""" @@ -3501,8 +3719,14 @@ async def _anthropic_passthrough_stream( repetition_penalty = repetition_penalty, presence_penalty = presence_penalty, tool_choice = tool_choice, + backend_ctx = llama_backend.context_length, ) + # cancel_id mirrors the OpenAI passthrough so a per-run cancel POST + # works without the caller having to know the local message_id. 
+ _tracker = _TrackedCancel(cancel_event, cancel_id, session_id, message_id) + _tracker.__enter__() + async def _stream(): emitter = AnthropicPassthroughEmitter() for line in emitter.start(message_id, model_name): @@ -3535,15 +3759,28 @@ async def _anthropic_passthrough_stream( # has anything orphaned to finalize. Each aclose is wrapped in # `try: ... except Exception: pass` so anyio cleanup noise from # nested aclose paths can't bubble out. - client = httpx.AsyncClient(timeout = 600) + client = httpx.AsyncClient( + timeout = 600, + limits = httpx.Limits(max_keepalive_connections = 0), + ) resp = None lines_iter = None + cancel_watcher = None try: req = client.build_request("POST", target_url, json = body) resp = await client.send(req, stream = True) + # See _openai_passthrough_stream for rationale: aiter_lines() + # blocks during llama-server prefill, so the in-loop cancel + # check is unreachable until the first SSE chunk arrives. + # The watcher closes `resp` on cancel, raising in aiter_lines. + cancel_watcher = asyncio.create_task( + _await_cancel_then_close(cancel_event, resp) + ) lines_iter = resp.aiter_lines() async for raw_line in lines_iter: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() break @@ -3558,9 +3795,18 @@ async def _anthropic_passthrough_stream( continue for line in emitter.feed_chunk(chunk): yield line + except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError): + if not cancel_event.is_set(): + raise except Exception as e: logger.error("anthropic_messages passthrough stream error: %s", e) finally: + if cancel_watcher is not None: + cancel_watcher.cancel() + try: + await cancel_watcher + except (asyncio.CancelledError, Exception): + pass if lines_iter is not None: try: await lines_iter.aclose() @@ -3575,6 +3821,7 @@ async def _anthropic_passthrough_stream( await client.aclose() except Exception: pass + _tracker.__exit__(None, None, None) for line in emitter.finish(): yield line @@ -3621,6 +3868,7 @@ async def _anthropic_passthrough_non_streaming( repetition_penalty = repetition_penalty, presence_penalty = presence_penalty, tool_choice = tool_choice, + backend_ctx = llama_backend.context_length, ) async with httpx.AsyncClient() as client: @@ -3742,7 +3990,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]: return messages -def _build_openai_passthrough_body(payload) -> dict: +def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict: """Assemble the llama-server request body from a ChatCompletionRequest. Only explicitly-known OpenAI / llama-server fields are forwarded so that @@ -3764,6 +4012,7 @@ def _build_openai_passthrough_body(payload) -> dict: repetition_penalty = payload.repetition_penalty, presence_penalty = payload.presence_penalty, tool_choice = tool_choice, + backend_ctx = backend_ctx, ) @@ -3784,103 +4033,56 @@ async def _openai_passthrough_stream( observes a standard OpenAI response. """ target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_openai_passthrough_body(payload) + body = _build_openai_passthrough_body( + payload, backend_ctx = llama_backend.context_length + ) - # Dispatch the upstream request BEFORE returning StreamingResponse so - # transport errors and non-200 upstream statuses surface as real HTTP - # errors to the client. OpenAI SDKs rely on status codes to raise - # ``APIError``/``BadRequestError``/...; burying the failure inside a - # 200 SSE ``error`` frame silently breaks their error handling. 
- client = httpx.AsyncClient(timeout = 600) - resp = None + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() + + # Outer guard: asyncio.CancelledError at `await client.send(...)` is + # a BaseException that bypasses `except httpx.RequestError`; without + # this the tracker leaks. The generator's finally only runs once + # iteration starts. try: - req = client.build_request("POST", target_url, json = body) - resp = await client.send(req, stream = True) - except httpx.RequestError as e: - # llama-server subprocess crashed / still starting / unreachable. - logger.error("openai passthrough stream: upstream unreachable: %s", e) - if resp is not None: - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - raise HTTPException( - status_code = 502, - detail = _friendly_error(e), + # Dispatch BEFORE returning StreamingResponse so transport errors + # and non-200 upstream statuses surface as real HTTP errors -- + # OpenAI SDKs rely on status codes to raise APIError/BadRequestError. + client = httpx.AsyncClient( + timeout = 600, + limits = httpx.Limits(max_keepalive_connections = 0), ) - - if resp.status_code != 200: - err_bytes = await resp.aread() - err_text = err_bytes.decode("utf-8", errors = "replace") - logger.error( - "openai passthrough upstream error: status=%s body=%s", - resp.status_code, - err_text[:500], - ) - upstream_status = resp.status_code + resp = None try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - raise HTTPException( - status_code = upstream_status, - detail = f"llama-server error: {err_text[:500]}", - ) - - async def _stream(): - # Same httpx lifecycle pattern as _anthropic_passthrough_stream: - # avoid `async with` on the client/response AND explicitly save - # resp.aiter_lines() so we can close it ourselves in the finally - # block. See the long comment there for the full rationale on - # why the anonymous `async for raw_line in resp.aiter_lines():` - # pattern leaks an unclosed async generator that Python's - # asyncgen GC hook then finalizes in a different asyncio task, - # producing "Exception ignored in:" / "async generator ignored - # GeneratorExit" / anyio cancel-scope traces on Python 3.13 + - # httpcore 1.0.x. - lines_iter = None - try: - lines_iter = resp.aiter_lines() - async for raw_line in lines_iter: - if await request.is_disconnected(): - cancel_event.set() - break - if not raw_line: - continue - if not raw_line.startswith("data: "): - continue - # Relay the llama-server SSE chunk verbatim so the client - # sees its native `id`, `finish_reason`, `delta.tool_calls`, - # and final `usage` unchanged. - yield raw_line + "\n\n" - if raw_line[6:].strip() == "[DONE]": - break - except Exception as e: - # Mid-stream failures still have to be reported inside the SSE - # body because the 200 response headers have already been - # committed by the time the first chunk flushes. - logger.error("openai passthrough stream error: %s", e) - err = { - "error": { - "message": _friendly_error(e), - "type": "server_error", - }, - } - yield f"data: {json.dumps(err)}\n\n" - finally: - if lines_iter is not None: + req = client.build_request("POST", target_url, json = body) + resp = await client.send(req, stream = True) + except httpx.RequestError as e: + # llama-server subprocess crashed / still starting / unreachable. 
+ logger.error("openai passthrough stream: upstream unreachable: %s", e) + if resp is not None: try: - await lines_iter.aclose() + await resp.aclose() except Exception: pass + try: + await client.aclose() + except Exception: + pass + raise HTTPException( + status_code = 502, + detail = _friendly_error(e), + ) + + if resp.status_code != 200: + err_bytes = await resp.aread() + err_text = err_bytes.decode("utf-8", errors = "replace") + logger.error( + "openai passthrough upstream error: status=%s body=%s", + resp.status_code, + err_text[:500], + ) + upstream_status = resp.status_code try: await resp.aclose() except Exception: @@ -3889,16 +4091,91 @@ async def _openai_passthrough_stream( await client.aclose() except Exception: pass + raise HTTPException( + status_code = upstream_status, + detail = f"llama-server error: {err_text[:500]}", + ) - return StreamingResponse( - _stream(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) + async def _stream(): + # Same httpx lifecycle pattern as _anthropic_passthrough_stream: + # save resp.aiter_lines() so the finally block can aclose() it + # on our task. See that function for full rationale. + lines_iter = None + # During llama-server prefill, `aiter_lines()` blocks until the + # first SSE chunk arrives. The in-loop `cancel_event` check + # cannot fire until then, which is the exact proxy/Colab + # scenario the cancel POST is meant to recover from. Run a + # tiny watcher that closes `resp` as soon as cancel fires, + # unblocking the iterator with a RemoteProtocolError caught + # in the except clause below. + cancel_watcher = asyncio.create_task( + _await_cancel_then_close(cancel_event, resp) + ) + try: + lines_iter = resp.aiter_lines() + async for raw_line in lines_iter: + if cancel_event.is_set(): + break + if await request.is_disconnected(): + cancel_event.set() + break + if not raw_line: + continue + if not raw_line.startswith("data: "): + continue + # Relay verbatim to preserve llama-server's native id, + # finish_reason, delta.tool_calls, and usage chunks. + yield raw_line + "\n\n" + if raw_line[6:].strip() == "[DONE]": + break + except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError): + # Watcher closed resp on cancel. Emit nothing extra; the + # client either initiated the cancel or already disconnected. + if not cancel_event.is_set(): + raise + except Exception as e: + # 200 headers are already flushed; errors must be in the SSE body. + logger.error("openai passthrough stream error: %s", e) + err = { + "error": { + "message": _friendly_error(e), + "type": "server_error", + }, + } + yield f"data: {json.dumps(err)}\n\n" + finally: + cancel_watcher.cancel() + try: + await cancel_watcher + except (asyncio.CancelledError, Exception): + pass + if lines_iter is not None: + try: + await lines_iter.aclose() + except Exception: + pass + try: + await resp.aclose() + except Exception: + pass + try: + await client.aclose() + except Exception: + pass + _tracker.__exit__(None, None, None) + + return StreamingResponse( + _stream(), + media_type = "text/event-stream", + headers = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + except BaseException: + _tracker.__exit__(None, None, None) + raise async def _openai_passthrough_non_streaming( @@ -3914,7 +4191,9 @@ async def _openai_passthrough_non_streaming( token counts. 
""" target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_openai_passthrough_body(payload) + body = _build_openai_passthrough_body( + payload, backend_ctx = llama_backend.context_length + ) try: async with httpx.AsyncClient() as client: diff --git a/studio/frontend/src/features/chat/api/chat-adapter.ts b/studio/frontend/src/features/chat/api/chat-adapter.ts index a93120397..dde597ca1 100644 --- a/studio/frontend/src/features/chat/api/chat-adapter.ts +++ b/studio/frontend/src/features/chat/api/chat-adapter.ts @@ -4,6 +4,8 @@ import type { ChatModelAdapter } from "@assistant-ui/react"; import type { MessageTiming, ToolCallMessagePart } from "@assistant-ui/core"; import { toast } from "sonner"; +import { getAuthToken } from "@/features/auth/session"; +import { apiUrl } from "@/lib/api-base"; import { generateAudio, listCachedGguf, @@ -707,7 +709,42 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { const toolCallParts: ToolCallMessagePart[] = []; let serverMetadata: { usage?: ServerUsage; timings?: ServerTimings } | null = null; + // Per-run cancellation token so a delayed stop POST cannot match + // the next run on the same thread. + const cancelId = + typeof crypto !== "undefined" && "randomUUID" in crypto + ? crypto.randomUUID() + : `${Date.now()}-${Math.random().toString(36).slice(2)}`; + + // Colab-style proxies can swallow fetch aborts, so also POST + // /inference/cancel explicitly on abort. + const onAbortCancel = () => { + const body: Record = { cancel_id: cancelId }; + if (resolvedThreadId) body.session_id = resolvedThreadId; + // Plain fetch, not authFetch: authFetch redirects to login on + // 401, which would kick the user out mid-stop. + const token = getAuthToken(); + // Use apiUrl so the cancel POST reaches the right origin in + // Tauri production builds (where the webview origin is not the + // backend at 127.0.0.1:). Browser/dev builds get the empty + // base, so the path is unchanged there. + void fetch(apiUrl("/api/inference/cancel"), { + method: "POST", + headers: { + "Content-Type": "application/json", + ...(token ? { Authorization: `Bearer ${token}` } : {}), + }, + body: JSON.stringify(body), + keepalive: true, + }).catch(() => {}); + }; try { + if (abortSignal.aborted) { + onAbortCancel(); + } else { + abortSignal.addEventListener("abort", onAbortCancel, { once: true }); + } + const { supportsReasoning, reasoningEnabled, @@ -730,6 +767,8 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { presence_penalty: params.presencePenalty, image_base64: imageBase64, audio_base64: audioBase64, + cancel_id: cancelId, + ...(resolvedThreadId ? { session_id: resolvedThreadId } : {}), ...(useAdapter === undefined ? {} : { use_adapter: useAdapter }), ...(supportsReasoning ? reasoningStyle === "reasoning_effort" @@ -750,7 +789,6 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { const mins = useChatRuntimeStore.getState().toolCallTimeout; return mins >= 9999 ? 
             9999
           : mins * 60;
       })(),
-      session_id: resolvedThreadId,
     }
   : {}),
 },
@@ -948,6 +986,7 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
       }
       throw err;
     } finally {
+      abortSignal.removeEventListener("abort", onAbortCancel);
       runtime.setGeneratingStatus(null);
       runtime.setToolStatus(null);
       clearTimeout(warmupTimer);
diff --git a/studio/frontend/src/features/chat/types/api.ts b/studio/frontend/src/features/chat/types/api.ts
index ccfc8b2bf..25957f4a7 100644
--- a/studio/frontend/src/features/chat/types/api.ts
+++ b/studio/frontend/src/features/chat/types/api.ts
@@ -179,6 +179,7 @@ export interface OpenAIChatCompletionsRequest {
   max_tool_calls_per_message?: number;
   tool_call_timeout?: number;
   session_id?: string;
+  cancel_id?: string;
 }
 
 export interface OpenAIChatDelta {
diff --git a/tests/studio/test_cancel_atomicity.py b/tests/studio/test_cancel_atomicity.py
new file mode 100644
index 000000000..a8d483945
--- /dev/null
+++ b/tests/studio/test_cancel_atomicity.py
@@ -0,0 +1,289 @@
+"""
+TOCTOU atomicity guards for the cancel path.
+
+Structural: cancel_inference, _cancel_by_cancel_id_or_stash, and
+_TrackedCancel.__enter__ must each use a single _CANCEL_LOCK critical
+section over lookup + stash / register + consume-pending.
+
+Behavioral: parallel cancel-POST vs __enter__ must never drop a cancel.
+"""
+
+from __future__ import annotations
+
+import ast
+import random
+import threading
+from pathlib import Path
+
+
+SOURCE_PATH = (
+    Path(__file__).resolve().parents[2]
+    / "studio"
+    / "backend"
+    / "routes"
+    / "inference.py"
+)
+_SRC = SOURCE_PATH.read_text()
+_TREE = ast.parse(_SRC)
+
+
+def _find_function(name: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
+    for node in ast.walk(_TREE):
+        if (
+            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+            and node.name == name
+        ):
+            return node
+    raise AssertionError(f"function {name!r} not found")
+
+
+def _find_class(name: str) -> ast.ClassDef:
+    for node in ast.walk(_TREE):
+        if isinstance(node, ast.ClassDef) and node.name == name:
+            return node
+    raise AssertionError(f"class {name!r} not found")
+
+
+def _count_with_cancel_lock_blocks(node: ast.AST) -> int:
+    n = 0
+    for sub in ast.walk(node):
+        if not isinstance(sub, ast.With):
+            continue
+        for item in sub.items:
+            ctx = item.context_expr
+            if isinstance(ctx, ast.Name) and ctx.id == "_CANCEL_LOCK":
+                n += 1
+                break
+    return n
+
+
+def test_cancel_by_cancel_id_or_stash_is_single_lock_critical_section():
+    fn = _find_function("_cancel_by_cancel_id_or_stash")
+    assert _count_with_cancel_lock_blocks(fn) == 1, (
+        "_cancel_by_cancel_id_or_stash must use exactly one `with "
+        "_CANCEL_LOCK:` block; splitting into two acquisitions reopens "
+        "the TOCTOU race with _TrackedCancel.__enter__"
+    )
+    src = ast.unparse(fn)
+    assert "_CANCEL_REGISTRY.get(cancel_id)" in src
+    assert "_PENDING_CANCELS[cancel_id]" in src
+
+
+def test_tracked_cancel_enter_registers_and_consumes_pending_under_one_lock():
+    cls = _find_class("_TrackedCancel")
+    enter = None
+    for n in cls.body:
+        if isinstance(n, ast.FunctionDef) and n.name == "__enter__":
+            enter = n
+            break
+    assert enter is not None
+    assert _count_with_cancel_lock_blocks(enter) == 1, (
+        "_TrackedCancel.__enter__ must acquire _CANCEL_LOCK exactly once. "
+        "A second acquisition for consume-pending lets a concurrent "
+        "cancel POST stash after consume sees an empty map, silently "
+        "dropping the cancel"
+    )
+    with_block = None
+    for sub in ast.walk(enter):
+        if isinstance(sub, ast.With) and any(
+            isinstance(i.context_expr, ast.Name) and i.context_expr.id == "_CANCEL_LOCK"
+            for i in sub.items
+        ):
+            with_block = sub
+            break
+    assert with_block is not None
+    block_src = "\n".join(ast.unparse(s) for s in with_block.body)
+    assert "_CANCEL_REGISTRY.setdefault" in block_src
+    assert "_PENDING_CANCELS.pop" in block_src, (
+        "__enter__ critical section must consume from _PENDING_CANCELS "
+        "inside the same lock, not a later re-acquisition"
+    )
+
+
+def test_cancel_inference_uses_atomic_helper_for_cancel_id_path():
+    fn = _find_function("cancel_inference")
+    src = ast.unparse(fn)
+    assert "_cancel_by_cancel_id_or_stash" in src
+    # The pre-fix two-step idiom must be gone.
+    assert "_remember_pending_cancel(cancel_id)" not in src, (
+        "two-step _cancel_by_keys + _remember_pending_cancel produced "
+        "the TOCTOU race and must not return"
+    )
+
+
+_WANTED = {
+    "_CANCEL_REGISTRY",
+    "_CANCEL_LOCK",
+    "_PENDING_CANCELS",
+    "_PENDING_CANCEL_TTL_S",
+    "_prune_pending",
+    "_remember_pending_cancel",
+    "_TrackedCancel",
+    "_cancel_by_keys",
+    "_cancel_by_cancel_id_or_stash",
+}
+
+
+def _load_registry_module():
+    chunks = []
+    for n in _TREE.body:
+        seg = ast.get_source_segment(_SRC, n)
+        if seg is None:
+            continue
+        if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED:
+            chunks.append(seg)
+        elif isinstance(n, ast.Assign):
+            names = [t.id for t in n.targets if isinstance(t, ast.Name)]
+            if any(name in _WANTED for name in names):
+                chunks.append(seg)
+        elif (
+            isinstance(n, ast.AnnAssign)
+            and isinstance(n.target, ast.Name)
+            and n.target.id in _WANTED
+        ):
+            chunks.append(seg)
+    mod = {}
+    exec(
+        "import threading, time\nfrom typing import Optional\n" + "\n\n".join(chunks),
+        mod,
+    )
+    return mod
+
+
+def test_parallel_cancel_vs_register_never_drops():
+    m = _load_registry_module()
+    trials = 500
+    dropped = 0
+    for i in range(trials):
+        m["_CANCEL_REGISTRY"].clear()
+        m["_PENDING_CANCELS"].clear()
+        cid = f"cid-{i}"
+        ev = threading.Event()
+        tracker = m["_TrackedCancel"](ev, cid, "thread")
+        start = threading.Event()
+
+        def do_cancel():
+            start.wait()
+            m["_cancel_by_cancel_id_or_stash"](cid)
+
+        def do_enter():
+            start.wait()
+            tracker.__enter__()
+
+        threads = [
+            threading.Thread(target=do_cancel),
+            threading.Thread(target=do_enter),
+        ]
+        random.shuffle(threads)
+        for t in threads:
+            t.start()
+        start.set()
+        for t in threads:
+            t.join(timeout=5.0)
+            assert not t.is_alive()
+
+        if not ev.is_set():
+            dropped += 1
+        tracker.__exit__(None, None, None)
+
+    assert dropped == 0, (
+        f"TOCTOU regression: {dropped}/{trials} parallel trials silently "
+        f"dropped the cancel"
+    )
+
+
+def test_cancel_before_register_replays_atomically():
+    m = _load_registry_module()
+    cid = "early-cid"
+    ev = threading.Event()
+    tracker = m["_TrackedCancel"](ev, cid, "thread-x")
+
+    assert m["_cancel_by_cancel_id_or_stash"](cid) == 0
+    assert cid in m["_PENDING_CANCELS"]
+
+    tracker.__enter__()
+    assert ev.is_set()
+    assert cid not in m["_PENDING_CANCELS"]
+    tracker.__exit__(None, None, None)
+
+
+def test_cancel_after_register_signals_without_stash():
+    m = _load_registry_module()
+    cid = "post-cid"
+    ev = threading.Event()
+    tracker = m["_TrackedCancel"](ev, cid, "thread-y")
+    tracker.__enter__()
+
+    assert m["_cancel_by_cancel_id_or_stash"](cid) == 1
+    assert ev.is_set()
+    assert cid not in m["_PENDING_CANCELS"]
+    tracker.__exit__(None, None, None)
+
+
+def test_cancel_by_keys_tolerates_empty_and_falsy_keys():
+    m = _load_registry_module()
+    m["_CANCEL_REGISTRY"].clear()
+    m["_PENDING_CANCELS"].clear()
+    assert m["_cancel_by_keys"]([]) == 0
+    assert m["_cancel_by_keys"](["", None, "unknown"]) == 0
+    # Non-stashing fallback must never leak into _PENDING_CANCELS.
+    assert m["_PENDING_CANCELS"] == {}
+
+
+def test_cancel_by_keys_fans_out_to_all_streams_on_same_session():
+    # Compare mode and other flows launch concurrent streams under a
+    # shared session_id; a single session cancel POST must hit all of them.
+    m = _load_registry_module()
+    m["_CANCEL_REGISTRY"].clear()
+    m["_PENDING_CANCELS"].clear()
+    session = "shared-thread"
+    ev_a = threading.Event()
+    ev_b = threading.Event()
+    tracker_a = m["_TrackedCancel"](ev_a, "cancel-a", session, "chatcmpl-a")
+    tracker_b = m["_TrackedCancel"](ev_b, "cancel-b", session, "chatcmpl-b")
+    tracker_a.__enter__()
+    tracker_b.__enter__()
+    try:
+        assert m["_cancel_by_keys"]([session]) == 2
+        assert ev_a.is_set() and ev_b.is_set()
+    finally:
+        tracker_a.__exit__(None, None, None)
+        tracker_b.__exit__(None, None, None)
+    assert session not in m["_CANCEL_REGISTRY"]
+
+
+def test_cancel_by_cancel_id_is_exclusive_to_single_run():
+    # cancel_id is per-run unique; cancelling run A must not touch run B
+    # even when both share a session_id.
+    m = _load_registry_module()
+    m["_CANCEL_REGISTRY"].clear()
+    m["_PENDING_CANCELS"].clear()
+    session = "shared-thread-2"
+    ev_a = threading.Event()
+    ev_b = threading.Event()
+    tracker_a = m["_TrackedCancel"](ev_a, "cancel-only-a", session, "chatcmpl-a")
+    tracker_b = m["_TrackedCancel"](ev_b, "cancel-only-b", session, "chatcmpl-b")
+    tracker_a.__enter__()
+    tracker_b.__enter__()
+    try:
+        assert m["_cancel_by_cancel_id_or_stash"]("cancel-only-a") == 1
+        assert ev_a.is_set()
+        assert not ev_b.is_set()
+    finally:
+        tracker_a.__exit__(None, None, None)
+        tracker_b.__exit__(None, None, None)
+
+
+def test_tracked_cancel_exit_is_idempotent():
+    # Outer except BaseException + the generator's finally may both call
+    # __exit__ under certain race combos; must not raise.
+    m = _load_registry_module()
+    m["_CANCEL_REGISTRY"].clear()
+    m["_PENDING_CANCELS"].clear()
+    ev = threading.Event()
+    tracker = m["_TrackedCancel"](ev, "cid", "sess", "chatcmpl-x")
+    tracker.__enter__()
+    tracker.__exit__(None, None, None)
+    tracker.__exit__(None, None, None)
+    tracker.__exit__(None, None, None)
+    assert not m["_CANCEL_REGISTRY"]
diff --git a/tests/studio/test_cancel_id_wiring.py b/tests/studio/test_cancel_id_wiring.py
new file mode 100644
index 000000000..5fd76ded9
--- /dev/null
+++ b/tests/studio/test_cancel_id_wiring.py
@@ -0,0 +1,169 @@
+"""
+Wiring tests for the per-run cancel_id field.
+
+A chat-thread-scoped session_id is not safe as a cancel key because a
+late stop POST can match a subsequent run on the same thread. The fix
+adds cancel_id (a fresh UUID per generation) that is sent both in the
+completion payload and in the /api/inference/cancel body.
+
+Verifies:
+  - ChatCompletionRequest exposes an Optional[str] `cancel_id` field.
+  - /api/inference/cancel accepts `cancel_id` as the first-preferred key.
+  - OpenAIChatCompletionsRequest (frontend type) includes cancel_id.
+ - chat-adapter.ts generates a per-run cancelId (crypto.randomUUID + with a Math.random fallback), sends it in the completion payload, + and includes it in the /inference/cancel body on abort. +""" + +from __future__ import annotations + +import ast +import re +from pathlib import Path + + +WORKSPACE = Path(__file__).resolve().parents[2] +MODELS_SRC = (WORKSPACE / "studio/backend/models/inference.py").read_text() +ROUTES_SRC = (WORKSPACE / "studio/backend/routes/inference.py").read_text() +ADAPTER_SRC = ( + WORKSPACE / "studio/frontend/src/features/chat/api/chat-adapter.ts" +).read_text() +API_TYPES_SRC = ( + WORKSPACE / "studio/frontend/src/features/chat/types/api.ts" +).read_text() + + +def _find_class(tree: ast.AST, name: str) -> ast.ClassDef | None: + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == name: + return node + return None + + +def test_chat_completion_request_has_cancel_id_field(): + tree = ast.parse(MODELS_SRC) + cls = _find_class(tree, "ChatCompletionRequest") + assert cls is not None + fields = { + n.target.id + for n in cls.body + if isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name) + } + assert "cancel_id" in fields, ( + "ChatCompletionRequest must expose a cancel_id field for per-run " + "cancellation routing" + ) + + +def test_cancel_route_matches_cancel_id_exclusively_when_present(): + # A stale cancel POST carrying cancel_id AND session_id must not + # cancel a later run on the same thread via the shared session_id. + # Enforce this by requiring the handler to early-return through an + # exclusive-cancel_id path -- either an atomic helper or a keys + # list containing ONLY cancel_id (never session_id). + for node in ast.walk(ast.parse(ROUTES_SRC)): + if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference": + break + else: + raise AssertionError("cancel_inference handler missing") + + cancel_id_exclusive_branch = False + for sub in ast.walk(node): + if not isinstance(sub, ast.If): + continue + test_src = ast.unparse(sub.test) + if "cancel_id" not in test_src or "isinstance" not in test_src: + continue + branch_src = "\n".join(ast.unparse(s) for s in sub.body) + before_return = branch_src.split("return", 1)[0] + matches_cancel_id_only = ( + "_cancel_by_cancel_id_or_stash(cancel_id)" in branch_src + or "_cancel_by_keys([cancel_id])" in branch_src + ) + if matches_cancel_id_only and "session_id" not in before_return: + cancel_id_exclusive_branch = True + break + assert cancel_id_exclusive_branch, ( + "cancel_inference must early-return with an exclusive cancel_id " + "match when a cancel_id is supplied, so a stale stop POST " + "cannot cancel a later run on the same thread via session_id" + ) + + +def test_cancel_route_falls_back_to_session_or_completion_when_no_cancel_id(): + for node in ast.walk(ast.parse(ROUTES_SRC)): + if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference": + break + else: + raise AssertionError("cancel_inference handler missing") + + src = ast.unparse(node) + assert "session_id" in src and "completion_id" in src, ( + "cancel_inference must still accept session_id / completion_id as " + "fallback keys when cancel_id is absent" + ) + + +def test_frontend_request_type_has_cancel_id(): + assert re.search( + r"cancel_id\?\s*:\s*string\s*;", API_TYPES_SRC + ), "OpenAIChatCompletionsRequest must expose an optional cancel_id" + + +def test_chat_adapter_generates_cancel_id_per_run(): + m = re.search( + r"const\s+cancelId\s*=\s*([^;]+);", + ADAPTER_SRC, + ) + assert 
m, "chat-adapter.ts must declare a per-run `cancelId` constant" + rhs = m.group(1) + assert ( + "randomUUID" in rhs + ), "cancelId should prefer crypto.randomUUID() for uniqueness" + + +def test_chat_adapter_sends_cancel_id_in_completion_payload(): + assert "cancel_id: cancelId" in ADAPTER_SRC, ( + "chat-adapter.ts must include cancel_id in the streamChatCompletions " + "request payload so the backend registers under that key" + ) + + +def test_chat_adapter_sends_cancel_id_in_abort_cancel_post(): + m = re.search( + r"const\s+onAbortCancel\s*=\s*\(\)\s*=>\s*\{(.*?)\};", + ADAPTER_SRC, + flags = re.DOTALL, + ) + assert m, "onAbortCancel arrow function missing" + body = m.group(1) + assert re.search(r"cancel_id\s*:\s*cancelId", body), ( + "onAbortCancel must include cancel_id in the /inference/cancel body " + "so a stop POST matches the specific run, not the whole thread" + ) + + +def test_abort_cancel_post_uses_plain_fetch_with_manual_auth_header(): + # authFetch redirects to login on 401, which would kick the user to + # the login page mid-stop if the access token expired during a long + # stream. Use plain fetch + manual Authorization header for a + # best-effort cancel that never triggers the refresh/redirect flow. + start = ADAPTER_SRC.find("const onAbortCancel") + assert start >= 0, "onAbortCancel handler missing" + rest = ADAPTER_SRC[start:] + end = rest.find("\n try {") + body = rest if end < 0 else rest[:end] + assert "/api/inference/cancel" in body + assert "authFetch(" not in body, ( + "onAbortCancel must NOT call authFetch; a 401 from it would " + "redirect the user to the login page during a stop click" + ) + assert "fetch(" in body, "onAbortCancel must use plain fetch(...)" + assert "getAuthToken" in body, ( + "onAbortCancel must read the bearer token via getAuthToken() " + "rather than relying on authFetch's 401 flow" + ) + assert "Authorization" in body + assert ( + "keepalive: true" in body + ), "keepalive is required so the fetch survives page unload during stop" diff --git a/tests/studio/test_llama_cpp_wall_clock_cap.py b/tests/studio/test_llama_cpp_wall_clock_cap.py new file mode 100644 index 000000000..671abea82 --- /dev/null +++ b/tests/studio/test_llama_cpp_wall_clock_cap.py @@ -0,0 +1,123 @@ +""" +Tests for the llama-server wall-clock cap (t_max_predict_ms). + +The UI always sends max_tokens = context_length, so gating +t_max_predict_ms on `max_tokens is None` makes the safety net dead +code. The fix applies the wall-clock cap unconditionally on all three +streaming payload sites and raises the default to 10 minutes so slow +CPU / macOS / Windows installs are not cut off mid-generation. + +Verifies: + - t_max_predict_ms is assigned unconditionally at the three + payload-builder sites (not inside an `if max_tokens is None` else + branch). + - _DEFAULT_T_MAX_PREDICT_MS is at least 10 minutes (previously + 120_000). + - The default max_tokens path still applies _DEFAULT_MAX_TOKENS. + - The three payload variable names (payload x2, stream_payload x1) + each get both `max_tokens` and `t_max_predict_ms`. 
+""" + +from __future__ import annotations + +import ast +from pathlib import Path + + +SOURCE_PATH = ( + Path(__file__).resolve().parents[2] + / "studio" + / "backend" + / "core" + / "inference" + / "llama_cpp.py" +) +SRC = SOURCE_PATH.read_text() +TREE = ast.parse(SRC) + + +def _is_subscript_assign(stmt: ast.stmt, target_name: str, key: str) -> bool: + if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1: + return False + t = stmt.targets[0] + if not isinstance(t, ast.Subscript): + return False + if not (isinstance(t.value, ast.Name) and t.value.id == target_name): + return False + slc = t.slice + return isinstance(slc, ast.Constant) and slc.value == key + + +def _collect_assignments(tree, target_name, key): + """Return list of (node, stack_of_enclosing_ifs) for each match.""" + hits = [] + + def visit(node, stack): + if _is_subscript_assign(node, target_name, key): + hits.append((node, stack)) + for child in ast.iter_child_nodes(node): + if isinstance(child, ast.If): + for sub in child.body: + visit(sub, stack + [(child, "body")]) + for sub in child.orelse: + visit(sub, stack + [(child, "orelse")]) + else: + visit(child, stack) + + visit(tree, []) + return hits + + +def test_default_t_max_predict_ms_is_at_least_ten_minutes(): + for node in TREE.body: + if isinstance(node, ast.Assign) and len(node.targets) == 1: + t = node.targets[0] + if isinstance(t, ast.Name) and t.id == "_DEFAULT_T_MAX_PREDICT_MS": + value = node.value + assert isinstance(value, ast.Constant) + assert value.value >= 600_000, ( + f"_DEFAULT_T_MAX_PREDICT_MS must be >= 10 minutes " + f"(600_000 ms) to avoid cutting off slow-CPU generations; " + f"got {value.value}" + ) + return + raise AssertionError("_DEFAULT_T_MAX_PREDICT_MS constant missing") + + +def test_t_max_predict_ms_set_unconditionally_at_three_sites(): + hits_payload = _collect_assignments(TREE, "payload", "t_max_predict_ms") + hits_stream = _collect_assignments(TREE, "stream_payload", "t_max_predict_ms") + total = len(hits_payload) + len(hits_stream) + assert total == 3, ( + f"expected 3 total t_max_predict_ms assignments " + f"(payload x2 + stream_payload x1), got {total}" + ) + for node, stack in hits_payload + hits_stream: + for parent_if, branch in stack: + # The assignment must not be gated by a test that checks + # `max_tokens is None` (which would make it dead code for + # the UI path where max_tokens is always set). + test_src = ast.unparse(parent_if.test) + assert "max_tokens" not in test_src, ( + f"t_max_predict_ms at line {node.lineno} is nested under " + f"`if {test_src}:` -- it must be applied unconditionally so " + f"the wall-clock cap is not dead code for callers that set " + f"max_tokens" + ) + + +def test_max_tokens_default_cap_still_applied(): + # _DEFAULT_MAX_TOKENS must still kick in when caller passes None. + # We check the conditional expression `max_tokens if max_tokens is not + # None else _DEFAULT_MAX_TOKENS` appears at each site. 
+ matches = 0 + for node in ast.walk(TREE): + if not isinstance(node, ast.IfExp): + continue + src = ast.unparse(node) + if "max_tokens" in src and "_DEFAULT_MAX_TOKENS" in src: + matches += 1 + assert matches >= 3, ( + f"expected >=3 `max_tokens if max_tokens is not None else " + f"_DEFAULT_MAX_TOKENS` expressions; got {matches}" + ) diff --git a/tests/studio/test_stream_cancel_registration_timing.py b/tests/studio/test_stream_cancel_registration_timing.py new file mode 100644 index 000000000..40ec3d6e1 --- /dev/null +++ b/tests/studio/test_stream_cancel_registration_timing.py @@ -0,0 +1,718 @@ +""" +Tests that the cancel tracker is registered BEFORE StreamingResponse is +returned, and that cleanup runs via a `finally` block inside each +async generator. + +The zombie-generation scenario is: user clicks Stop during prefill / +warmup / proxy buffering, before the first SSE chunk. If _tracker +__enter__ lives inside the async generator body, the registry is empty +at the moment /api/inference/cancel lands -- so cancel returns 0 and +the decode runs to completion. + +The fix moves _tracker = _TrackedCancel(...) and _tracker.__enter__() +to the synchronous body of openai_chat_completions (before the +StreamingResponse is returned) and places _tracker.__exit__ inside +each generator's `finally` block. Using a generator `finally` (rather +than a Starlette BackgroundTask) guarantees cleanup on every +termination path -- normal exhaustion, CancelledError from +ClientDisconnect, and OSError / BrokenPipeError during send() -- +because Starlette skips `background` callbacks when stream_response +raises. + +Structural verifies: + - No `async def ...:` body contains `_tracker.__enter__()` in + routes/inference.py (registration moved to sync body). + - Each of the four async generators (gguf_tool_stream, + gguf_stream_chunks, stream_chunks, audio_input_stream) contains + `_tracker.__exit__(None, None, None)` inside a try/finally block. + - No StreamingResponse in openai_chat_completions passes + `background=` (cleanup now lives in the generator finally). + +Behavioral verifies (extracting `_TrackedCancel` from source and +exercising the actual runtime semantics): + - `finally: _tracker.__exit__(...)` runs on normal completion, + mid-stream exception (OSError / BrokenPipeError from send()), + and aclose() from Starlette ClientDisconnect. + - A pre-set cancel_event (from `_TrackedCancel.__enter__` replaying + a pending cancel POST) lets the GGUF while-loop break cleanly + and emit final_chunk + [DONE] instead of propagating + `GeneratorExit` out of `_stream_with_retry` into the async + generator's `except Exception` (which would not catch it). 
+""" + +from __future__ import annotations + +import ast +import asyncio +import threading +import time +from pathlib import Path + + +SOURCE_PATH = ( + Path(__file__).resolve().parents[2] + / "studio" + / "backend" + / "routes" + / "inference.py" +) +SRC = SOURCE_PATH.read_text() +_TREE = ast.parse(SRC) + + +# ── Structural (AST) helpers ───────────────────────────────── + + +def _collect_async_functions(tree: ast.AST): + return [n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef)] + + +def _has_tracker_enter_call(node: ast.AST) -> bool: + for sub in ast.walk(node): + if not isinstance(sub, ast.Call): + continue + fn = sub.func + if ( + isinstance(fn, ast.Attribute) + and fn.attr == "__enter__" + and isinstance(fn.value, ast.Name) + and fn.value.id.startswith("_tracker") + ): + return True + return False + + +def _finalbody_has_tracker_exit(finalbody) -> bool: + for stmt in finalbody: + if not isinstance(stmt, ast.Expr): + continue + call = stmt.value + if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)): + continue + fn = call.func + if ( + fn.attr == "__exit__" + and isinstance(fn.value, ast.Name) + and fn.value.id.startswith("_tracker") + ): + return True + return False + + +# ── Structural tests ───────────────────────────────────────── + + +def test_no_tracker_enter_inside_async_generators(): + offenders = [] + for fn in _collect_async_functions(_TREE): + if fn.name in { + "gguf_tool_stream", + "gguf_stream_chunks", + "stream_chunks", + "audio_input_stream", + }: + if _has_tracker_enter_call(fn): + offenders.append(fn.name) + assert not offenders, ( + f"Cancel tracker registration must live OUTSIDE the async generator " + f"body so a stop POST can find the registry entry before the first " + f"SSE chunk. Offending generators: {offenders}" + ) + + +def test_tracker_enter_exists_in_sync_body_of_chat_completions(): + top = None + for n in ast.walk(_TREE): + if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": + top = n + break + assert top is not None, "openai_chat_completions handler missing" + count = 0 + for sub in ast.walk(top): + if not isinstance(sub, ast.Call): + continue + fn = sub.func + if ( + isinstance(fn, ast.Attribute) + and fn.attr == "__enter__" + and isinstance(fn.value, ast.Name) + and fn.value.id.startswith("_tracker") + ): + count += 1 + assert count >= 3, ( + f"expected >=3 _tracker.__enter__() calls in openai_chat_completions " + f"(one per streaming path), got {count}" + ) + + +def test_async_generators_cleanup_tracker_in_finally(): + required = { + "gguf_tool_stream", + "gguf_stream_chunks", + "stream_chunks", + "audio_input_stream", + } + found: set[str] = set() + for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]: + if fn.name not in required: + continue + for sub in ast.walk(fn): + if isinstance(sub, ast.Try) and sub.finalbody: + if _finalbody_has_tracker_exit(sub.finalbody): + found.add(fn.name) + break + missing = required - found + assert not missing, ( + f"Cleanup must run via `finally: _tracker.__exit__(None, None, None)` " + f"inside each streaming generator so ClientDisconnect / OSError paths " + f"also release registry entries (Starlette skips `background` callbacks " + f"when stream_response raises). 
Missing in: {sorted(missing)}" + ) + + +def test_streaming_responses_have_no_background_task(): + top = None + for n in ast.walk(_TREE): + if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": + top = n + break + assert top is not None + for sub in ast.walk(top): + if not (isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)): + continue + if sub.func.id != "StreamingResponse": + continue + kwargs = {kw.arg for kw in sub.keywords if kw.arg} + assert "background" not in kwargs, ( + "StreamingResponse in openai_chat_completions must not pass " + "`background=` -- cleanup now lives in the generator's finally " + "block; a BackgroundTask would be skipped on abrupt disconnect" + ) + + +# ── Behavioral helpers ─────────────────────────────────────── + +_WANTED = { + "_CANCEL_REGISTRY", + "_CANCEL_LOCK", + "_PENDING_CANCELS", + "_PENDING_CANCEL_TTL_S", + "_prune_pending", + "_TrackedCancel", + "_cancel_by_keys", + "_cancel_by_cancel_id_or_stash", +} + + +def _load_registry_module(): + chunks = [] + for n in _TREE.body: + seg = ast.get_source_segment(SRC, n) + if seg is None: + continue + if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED: + chunks.append(seg) + elif isinstance(n, ast.Assign): + names = [t.id for t in n.targets if isinstance(t, ast.Name)] + if any(name in _WANTED for name in names): + chunks.append(seg) + elif ( + isinstance(n, ast.AnnAssign) + and isinstance(n.target, ast.Name) + and n.target.id in _WANTED + ): + chunks.append(seg) + mod = {} + exec("import threading, time\n" + "\n\n".join(chunks), mod) + return mod + + +def _make_stream(tracker, raise_exc): + async def gen(): + try: + try: + yield "data: first\n\n" + if raise_exc is not None: + raise raise_exc + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + raise + except Exception: + yield "data: error\n\n" + finally: + tracker.__exit__(None, None, None) + except BaseException: + raise + + return gen() + + +async def _consume(agen): + out = [] + try: + async for ch in agen: + out.append(ch) + except BaseException as e: + out.append(type(e).__name__) + return out + + +def _llama_stub_raises_on_preset_cancel(cancel_event): + # Reproduces llama_cpp.py _stream_with_retry:2240 `raise GeneratorExit` + # when cancel_event is already set at entry. + if cancel_event.is_set(): + raise GeneratorExit + yield "cumulative-1" + yield "cumulative-2" + + +async def _post_fix_gguf_loop(cancel_event): + yield "first_chunk" + gen = _llama_stub_raises_on_preset_cancel(cancel_event) + sentinel = object() + while True: + if cancel_event.is_set(): + break + cumulative = await asyncio.to_thread(next, gen, sentinel) + if cumulative is sentinel: + break + yield cumulative + yield "final_chunk" + yield "[DONE]" + + +# ── Behavioral tests ───────────────────────────────────────── + + +def test_finally_cleanup_on_normal_completion(): + m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + ev = threading.Event() + tr = m["_TrackedCancel"](ev, "cid-ok", "sid-ok") + tr.__enter__() + assert "cid-ok" in m["_CANCEL_REGISTRY"] + chunks = asyncio.run(_consume(_make_stream(tr, None))) + assert chunks == ["data: first\n\n", "data: [DONE]\n\n"] + assert "cid-ok" not in m["_CANCEL_REGISTRY"] + assert "sid-ok" not in m["_CANCEL_REGISTRY"] + + +def test_finally_cleanup_on_mid_stream_exception(): + # Simulates OSError / BrokenPipeError from Starlette send() mid-stream -- + # the exact case where pre-fix `background = BackgroundTask(...)` was + # skipped and leaked the registry entry. 
+ m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + ev = threading.Event() + tr = m["_TrackedCancel"](ev, "cid-err", "sid-err") + tr.__enter__() + assert "cid-err" in m["_CANCEL_REGISTRY"] + asyncio.run(_consume(_make_stream(tr, OSError("disconnect")))) + assert "cid-err" not in m["_CANCEL_REGISTRY"] + assert "sid-err" not in m["_CANCEL_REGISTRY"] + + +def test_finally_cleanup_on_aclose(): + # Starlette calls aclose() on the async generator when the client + # disconnects mid-stream. The generator's finally block must run. + m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + ev = threading.Event() + tr = m["_TrackedCancel"](ev, "cid-abort", "sid-abort") + tr.__enter__() + assert "cid-abort" in m["_CANCEL_REGISTRY"] + + async def run(): + gen = _make_stream(tr, None) + it = gen.__aiter__() + await it.__anext__() + await gen.aclose() + + asyncio.run(run()) + assert "cid-abort" not in m["_CANCEL_REGISTRY"] + assert "sid-abort" not in m["_CANCEL_REGISTRY"] + + +def test_preset_cancel_event_exits_cleanly_with_done(): + # Pending-replay: POST /cancel arrived before the stream registered, + # was stashed, then consumed by _TrackedCancel.__enter__ which set + # cancel_event. The generator must break out of the loop cleanly + # and emit final_chunk + [DONE] rather than calling next(gen) and + # propagating `GeneratorExit` out of the GGUF stream wrapper. + ev = threading.Event() + ev.set() + chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev))) + assert "first_chunk" in chunks + assert "final_chunk" in chunks + assert "[DONE]" in chunks + assert "GeneratorExit" not in chunks + assert "cumulative-1" not in chunks + assert "cumulative-2" not in chunks + + +def test_normal_path_streams_all_tokens(): + # Regression: the top-of-loop cancel_event check must not short-circuit + # when cancel_event is unset. + ev = threading.Event() + chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev))) + assert chunks == [ + "first_chunk", + "cumulative-1", + "cumulative-2", + "final_chunk", + "[DONE]", + ] + + +def test_cancel_during_streaming_stops_iteration_promptly(): + # Setting cancel_event between yields breaks out on the next iteration + # rather than draining the stub generator. + ev = threading.Event() + + async def _run(): + gen = _post_fix_gguf_loop(ev) + seen = [] + async for ch in gen: + seen.append(ch) + if ch == "cumulative-1": + ev.set() + return seen + + seen = asyncio.run(_run()) + assert "first_chunk" in seen + assert "cumulative-1" in seen + assert "cumulative-2" not in seen + assert "final_chunk" in seen + assert "[DONE]" in seen + + +# ── Cancel-event responsiveness in the streaming loops ─────── + + +def _loop_has_cancel_event_check(fn) -> bool: + # An `if cancel_event.is_set():` statement anywhere inside a + # `while`/`for` loop body is sufficient -- without it, a cancel POST + # cannot interrupt the loop because Colab-style proxies do not + # propagate request.is_disconnected(). 
+ for sub in ast.walk(fn): + if not isinstance(sub, (ast.While, ast.For, ast.AsyncFor)): + continue + for stmt in ast.walk(sub): + if not isinstance(stmt, ast.If): + continue + t = stmt.test + if ( + isinstance(t, ast.Call) + and isinstance(t.func, ast.Attribute) + and t.func.attr == "is_set" + and isinstance(t.func.value, ast.Name) + and t.func.value.id == "cancel_event" + ): + return True + return False + + +def test_streaming_generators_check_cancel_event_in_loop(): + required = { + "gguf_tool_stream", + "gguf_stream_chunks", + "stream_chunks", + "audio_input_stream", + } + missing = [] + for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]: + if fn.name not in required: + continue + if not _loop_has_cancel_event_check(fn): + missing.append(fn.name) + assert not missing, ( + f"Each streaming generator must check `cancel_event.is_set()` inside " + f"its main loop so `POST /api/inference/cancel` can interrupt the " + f"stream through proxies that do not forward fetch aborts. " + f"Missing in: {sorted(missing)}" + ) + + +def test_audio_input_stream_offloads_blocking_next_to_thread(): + # Guards against regression back to `for chunk_text in + # audio_input_generate():` -- which blocks the event loop on each + # whisper chunk and prevents POST /api/inference/cancel from being + # serviced until the chunk yields. + audio = None + for fn in ast.walk(_TREE): + if isinstance(fn, ast.AsyncFunctionDef) and fn.name == "audio_input_stream": + audio = fn + break + assert audio is not None, "audio_input_stream generator missing" + + for sub in ast.walk(audio): + if isinstance(sub, (ast.For, ast.AsyncFor)): + it_src = ast.unparse(sub.iter) + assert "audio_input_generate" not in it_src, ( + "audio_input_stream must not iterate audio_input_generate() " + "directly -- that blocks the event loop. Use " + "`await asyncio.to_thread(next, gen, _DONE)` inside a " + "`while True` loop instead" + ) + + found_to_thread_next = False + for sub in ast.walk(audio): + if not isinstance(sub, ast.Call): + continue + fn_expr = sub.func + if not ( + isinstance(fn_expr, ast.Attribute) + and fn_expr.attr == "to_thread" + and isinstance(fn_expr.value, ast.Name) + and fn_expr.value.id == "asyncio" + ): + continue + if sub.args and isinstance(sub.args[0], ast.Name) and sub.args[0].id == "next": + found_to_thread_next = True + break + assert found_to_thread_next, ( + "audio_input_stream must call `asyncio.to_thread(next, gen, ...)` " + "to keep the event loop free while whisper yields the next chunk" + ) + + +def test_stream_chunks_cancel_branch_resets_backend_state(): + # The Unsloth path's cancel branch must flush GPU / KV-cache state + # via `backend.reset_generation_state()` -- the orchestrator's + # internal cancel path does not do this, so a cancel-via-POST that + # only broke the loop would leave the subprocess in a dirty state + # for the next request. 
+ fn = None + top = None + for n in ast.walk(_TREE): + if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": + top = n + break + assert top is not None + for n in ast.walk(top): + if isinstance(n, ast.AsyncFunctionDef) and n.name == "stream_chunks": + fn = n + break + assert fn is not None, "stream_chunks generator missing" + + for sub in ast.walk(fn): + if not isinstance(sub, ast.If): + continue + t = sub.test + if not ( + isinstance(t, ast.Call) + and isinstance(t.func, ast.Attribute) + and t.func.attr == "is_set" + and isinstance(t.func.value, ast.Name) + and t.func.value.id == "cancel_event" + ): + continue + body_src = "\n".join(ast.unparse(s) for s in sub.body) + if "backend.reset_generation_state()" in body_src: + return + raise AssertionError( + "stream_chunks `if cancel_event.is_set():` branch must call " + "backend.reset_generation_state() -- matches the existing " + "request.is_disconnected() / CancelledError cleanup paths and " + "prevents KV-cache drift after cancel-via-POST" + ) + + +# ── Behavioral simulations for the iter-1 fixes ────────────── + + +def test_unsloth_stream_loop_breaks_on_external_cancel_event(): + cancel_event = threading.Event() + reset_calls = [0] + + class _Backend: + def reset_generation_state(self): + reset_calls[0] += 1 + + backend = _Backend() + + def _generate(): + for i in range(200): + time.sleep(0.005) + yield f"cum-{i}" + + async def _loop(): + _DONE = object() + loop = asyncio.get_event_loop() + gen = _generate() + seen = [] + while True: + if cancel_event.is_set(): + backend.reset_generation_state() + break + cumulative = await loop.run_in_executor(None, next, gen, _DONE) + if cumulative is _DONE: + break + seen.append(cumulative) + return seen + + async def _fire(): + await asyncio.sleep(0.05) + cancel_event.set() + + async def _main(): + return await asyncio.gather(_loop(), _fire()) + + seen, _ = asyncio.run(_main()) + assert ( + len(seen) < 200 + ), f"loop must not drain the generator after cancel; got {len(seen)} tokens" + assert reset_calls[0] == 1, ( + f"backend.reset_generation_state() must be called exactly once on " + f"cancel-via-POST, got {reset_calls[0]}" + ) + + +def test_audio_stream_stays_responsive_under_blocking_next(): + # Regression guard: replace the post-fix loop with the pre-fix + # `for chunk in audio_input_generate()` pattern and assert it blocks + # the event loop; then confirm the post-fix pattern exits promptly. 
+ cancel_event = threading.Event() + + def _audio_gen(): + for i in range(8): + time.sleep(0.15) + yield f"chunk-{i}" + + async def _prefix_loop(): + seen = [] + for chunk_text in _audio_gen(): + if cancel_event.is_set(): + break + seen.append(chunk_text) + return seen + + async def _postfix_loop(): + _DONE = object() + gen = _audio_gen() + seen = [] + while True: + if cancel_event.is_set(): + break + chunk_text = await asyncio.to_thread(next, gen, _DONE) + if chunk_text is _DONE: + break + seen.append(chunk_text) + return seen + + async def _fire_early(): + await asyncio.sleep(0.05) + cancel_event.set() + + async def _run(loop_coro): + return await asyncio.gather(loop_coro, _fire_early()) + + cancel_event.clear() + t0 = time.monotonic() + prefix_seen, _ = asyncio.run(_run(_prefix_loop())) + prefix_elapsed = time.monotonic() - t0 + assert prefix_elapsed >= 0.13, ( + f"pre-fix pattern should block event loop for >=1 chunk time " + f"(~150ms); got {prefix_elapsed:.3f}s, {len(prefix_seen)} chunks" + ) + + cancel_event.clear() + t0 = time.monotonic() + postfix_seen, _ = asyncio.run(_run(_postfix_loop())) + postfix_elapsed = time.monotonic() - t0 + assert postfix_elapsed < prefix_elapsed, ( + f"post-fix pattern must exit faster than pre-fix (blocking) " + f"pattern; post={postfix_elapsed:.3f}s vs pre={prefix_elapsed:.3f}s" + ) + assert ( + len(postfix_seen) < 8 + ), f"post-fix loop must not drain all chunks; got {len(postfix_seen)}" + + +def test_unsloth_stream_loop_emits_zero_tokens_on_preset_cancel(): + # Pending-cancel replay: _TrackedCancel.__enter__ already set + # cancel_event before the generator body starts iterating. The + # top-of-loop check must short-circuit the very first iteration so + # no token is emitted. Catches a regression that moves the check + # below `next()` -- the mid-loop test would still pass but this + # test would observe one extra token leak. + cancel_event = threading.Event() + cancel_event.set() + reset_calls = [0] + + class _Backend: + def reset_generation_state(self): + reset_calls[0] += 1 + + backend = _Backend() + + next_calls = [0] + + def _generate(): + while True: + next_calls[0] += 1 + yield f"cum-{next_calls[0]}" + + async def _loop(): + _DONE = object() + loop = asyncio.get_event_loop() + gen = _generate() + seen = [] + while True: + if cancel_event.is_set(): + backend.reset_generation_state() + break + cumulative = await loop.run_in_executor(None, next, gen, _DONE) + if cumulative is _DONE: + break + seen.append(cumulative) + return seen + + seen = asyncio.run(_loop()) + assert seen == [], ( + f"loop must emit zero tokens when cancel_event is pre-set " + f"(pending-replay path); got {seen}" + ) + assert next_calls[0] == 0, ( + f"loop must not call next() at all on pre-set cancel; got " + f"{next_calls[0]} calls" + ) + assert reset_calls[0] == 1, ( + f"backend.reset_generation_state() must still fire exactly once " + f"on pre-set cancel; got {reset_calls[0]}" + ) + + +def test_audio_stream_emits_zero_chunks_on_preset_cancel(): + # Symmetric to the Unsloth pre-set test: the audio loop's top-of-loop + # cancel check must skip the asyncio.to_thread(next, ...) call when + # cancel_event was already set via pending-replay. 
+    cancel_event = threading.Event()
+    cancel_event.set()
+
+    next_calls = [0]
+
+    def _audio_gen():
+        while True:
+            next_calls[0] += 1
+            yield f"chunk-{next_calls[0]}"
+
+    async def _loop():
+        _DONE = object()
+        gen = _audio_gen()
+        seen = []
+        while True:
+            if cancel_event.is_set():
+                break
+            chunk_text = await asyncio.to_thread(next, gen, _DONE)
+            if chunk_text is _DONE:
+                break
+            seen.append(chunk_text)
+        return seen
+
+    seen = asyncio.run(_loop())
+    assert seen == [], f"audio loop must emit zero chunks on pre-set cancel; got {seen}"
+    assert next_calls[0] == 0, (
+        f"audio loop must not call next() on pre-set cancel; got "
+        f"{next_calls[0]} calls"
+    )