mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-26 10:31:03 +00:00
Studio: make stop button actually stop generation (#5069)
* Studio: make stop button actually stop generation

  The UI stop button routes through assistant-ui's cancelRun, which aborts
  the frontend fetch. Four issues combined to let llama-server keep decoding
  long after the user clicked stop:

  1. request.is_disconnected() does not fire reliably behind proxies
     (e.g. Colab) that don't propagate fetch aborts.
  2. llama-server defaults n_predict to n_ctx when max_tokens is not sent,
     so a cancelled request keeps producing tokens up to 262144.
  3. The httpx.Client pool keeps TCP keep-alive, so even a cleanly closed
     stream reuses the same connection and llama-server's liveness poll
     never sees a disconnect.
  4. No explicit backend route to cancel - every cancel path relied on
     is_disconnected.

  Changes:

  - Add POST /api/inference/cancel keyed by session_id/completion_id, with
    a registry populated for the lifetime of each streaming response.
  - Have the frontend (chat-adapter.ts) POST /inference/cancel on
    AbortController abort, alongside the existing fetch teardown.
  - Send max_tokens=4096 + t_max_predict_ms=120000 as defaults on every
    outbound chat completion to llama-server; user overrides are honoured.
  - Disable httpx keep-alive on the streaming client so connection close
    reaches llama-server and its 1s liveness check fires.

  No behaviour changes for non-streaming paths or for existing callers that
  already pass max_tokens/session_id.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* studio: harden stop-button cancel path and scope cancel route

  - Require at least one identifier for /api/inference/cancel so a missing
    thread id cannot silently cancel every in-flight generation.
  - Scope /cancel to a dedicated studio_router so it is not exposed under
    the /v1 OpenAI-compat prefix as a surprise endpoint.
  - Store a set of cancel events per key in _CANCEL_REGISTRY so concurrent
    requests on the same session_id do not overwrite each other, and
    deduplicate in _cancel_by_keys so the cancelled count reflects unique
    requests.
  - Always send session_id with chat completions (not only when tools are
    enabled) so non-tool GGUF streams register under it and are reachable
    from /cancel.
  - Register the non-GGUF stream_chunks path in the cancel registry too, so
    the transformers-based stop button works behind proxies that swallow
    fetch aborts.
  - Only apply the 2-minute t_max_predict_ms wall-clock cap when the caller
    did not pass max_tokens, so legitimate long generations on slow
    CPU/macOS/Windows installs are not silently truncated.
  - Remove the abort listener on normal stream completion so reused
    AbortSignals cannot fire a spurious cancel POST after the fact.

* studio: close cancel-race and stale-cancel gaps in stop path

  - Register the cancel tracker before returning StreamingResponse so a
    stop POST that arrives during prefill / warmup / proxy buffering finds
    an entry in _CANCEL_REGISTRY. Cleanup now runs via a Starlette
    BackgroundTask instead of a finally inside the async generator body.
  - Add a per-run cancel_id on the frontend (crypto.randomUUID) and in
    ChatCompletionRequest so /api/inference/cancel matches one specific
    generation. Removes the stale-cancel bug where pressing stop then
    starting a new run in the same thread would cancel the retry.
  - Apply t_max_predict_ms unconditionally in all three llama-server
    payload builders (previously gated on max_tokens=None, which made it
    dead code for UI callers that always send params.maxTokens). Raise the
    default to 10 minutes so slow CPU / macOS / Windows installs are not
    cut off mid-generation.
  - Make _cancel_by_keys refuse empty input (return 0) so a future internal
    caller cannot accidentally mass-cancel every in-flight request.
  - Accept cancel_id (primary), session_id, and completion_id on the
    /api/inference/cancel route.
  - Unify the three streaming sites on the same _cancel_keys / _tracker
    variable names.
  - Annotate _CANCEL_REGISTRY as dict[str, set[threading.Event]].

* Add review tests for PR #5069

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* studio: harden stop-button cancel semantics and wall-clock cap

  - Make /inference/cancel match cancel_id EXCLUSIVELY when supplied.
    Previously the handler iterated ('cancel_id','session_id',
    'completion_id') and unioned matches, so a stale cancel POST carrying
    {cancel_id:old, session_id:thr} would still cancel a later run on the
    same thread via the shared session_id. cancel_id is now a per-run
    exclusive key; session_id / completion_id are only used as fallbacks
    when cancel_id is absent.
  - Close the early-cancel race. If /inference/cancel lands before the
    streaming handler reaches _TrackedCancel.__enter__() (stop clicked
    during prefill / warmup / proxy buffering), the cancel was silently
    dropped. Stash unmatched cancel_ids in _PENDING_CANCELS with a 30 s
    TTL; _TrackedCancel.__enter__() now replays any matching pending
    cancel by set()-ing the event immediately after registration.
  - Make t_max_predict_ms = _DEFAULT_T_MAX_PREDICT_MS conditional on
    max_tokens is None at all three llama-server payload sites. The cap is
    a safety net for callers who leave max_tokens unset (otherwise
    llama-server defaults n_predict to n_ctx, up to 262144). Callers who
    set an explicit max_tokens are already self-limiting, and their
    legitimate long generations on slow CPU / macOS / Windows installs
    must not be silently truncated at 10 minutes.
  - Guard each StreamingResponse return with try/except BaseException so
    _tracker.__exit__ runs even if StreamingResponse construction or any
    preceding statement raises between _tracker.__enter__() and the
    BackgroundTask attachment. Prevents a registry leak in that narrow
    window.
* studio: close TOCTOU race and restore wall-clock backstop on UI path

  - Close the TOCTOU race in the pending-cancel mechanism. The previous
    fix split cancel_inference's work (cancel_by_keys +
    remember_pending_cancel) and _TrackedCancel.__enter__'s work
    (register + consume_pending) into four separate lock acquisitions.
    Under contention a cancel POST could acquire-then-release the lock,
    find the registry empty, and stash ONLY AFTER __enter__ had already
    registered and consumed an empty pending map -- silently dropping the
    cancel. Both call sites now do their work inside a single _CANCEL_LOCK
    critical section, via the new atomic helper
    _cancel_by_cancel_id_or_stash() and an inlined consume-pending step in
    __enter__. Reproduced the race under forced interleaving pre-fix;
    0/2000 drops post-fix under parallel stress.
  - Apply t_max_predict_ms UNCONDITIONALLY at all three llama-server
    payload sites. The previous iteration gated the cap on `max_tokens is
    None`, which turned out to be dead code on the primary Studio UI path:
    chat-adapter.ts sets maxTokens=loadResp.context_length after every
    model load, so every chat request carries an explicit max_tokens and
    the wall-clock safety net never fired. The cap's original purpose is
    to bound stuck decodes regardless of the token budget; it must always
    apply.
  - Raise _DEFAULT_T_MAX_PREDICT_MS from 10 minutes to 1 hour. 10 minutes
    was too aggressive for legitimate slow-CPU chat responses (a
    4096-token reply at 2 tok/s takes ~34 min); 1 hour accommodates that
    and still catches genuine zombie decodes.
  - Prune _PENDING_CANCELS inside _cancel_by_keys as well, so stashed
    entries expire proportionally to overall cancel traffic rather than
    only to cancel_id-specific POSTs.
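The single-critical-section pattern described above can be sketched with stdlib primitives. This is a minimal illustration of the lookup-or-stash idea, not the PR's actual code: `cancel_or_stash` and `register` are illustrative names standing in for `_cancel_by_cancel_id_or_stash` and `_TrackedCancel.__enter__`.

```python
import threading
import time

_LOCK = threading.Lock()
_REGISTRY: dict[str, set[threading.Event]] = {}
_PENDING: dict[str, float] = {}
_TTL_S = 30.0


def cancel_or_stash(cancel_id: str) -> int:
    """Fire matching cancel events, or stash the id for later replay.

    Lookup and stash happen under one lock acquisition, so a concurrent
    register() cannot slip between them -- the TOCTOU race the commit
    message describes.
    """
    with _LOCK:
        bucket = _REGISTRY.get(cancel_id)
        if bucket:
            for ev in bucket:
                ev.set()
            return len(bucket)
        _PENDING[cancel_id] = time.monotonic()
        return 0


def register(cancel_id: str, event: threading.Event) -> None:
    """Register an in-flight run and replay a pending cancel atomically."""
    with _LOCK:
        _REGISTRY.setdefault(cancel_id, set()).add(event)
        ts = _PENDING.pop(cancel_id, None)
        if ts is not None and time.monotonic() - ts <= _TTL_S:
            event.set()  # cancel arrived before registration: replay it
```

Because both paths take the same lock for their entire read-modify-write, the interleaving "cancel finds registry empty / stream registers and finds pending empty" cannot occur.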
* studio: trim verbose comments and docstrings in cancel path

* studio/llama_cpp: drop upstream PR hashes from benchmark comment

* Add review tests for Studio stop button

* Consolidate review tests for Studio stop button

* Align cancel-route test with exclusive cancel_id semantics

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* studio: move cancel cleanup to generator finally; drop dead helper

  - Move _tracker.__exit__ from a Starlette BackgroundTask into each
    streaming generator's finally block. Starlette skips the background
    callback when stream_response raises (OSError / ClientDisconnect),
    which leaked _CANCEL_REGISTRY entries on abrupt disconnect.
  - Check cancel_event.is_set() at the top of each GGUF while loop so a
    pending-replay cancel falls through to final_chunk + [DONE] instead of
    propagating GeneratorExit out of _stream_with_retry.
  - Remove unused _remember_pending_cancel; _cancel_by_cancel_id_or_stash
    superseded it.

* Add review tests for Studio stop-button

* studio: wire audio-input stream into cancel registry

  - Register cancel_event with _TrackedCancel on the audio-input streaming
    path so POST /api/inference/cancel can stop whisper / audio-input GGUF
    runs. Previously the registry stayed empty on this branch, so the stop
    button returned {"cancelled":0} and the decode ran to completion.
  - Apply the same finally-based cleanup and pre-iteration cancel-event
    check used on the other three streaming paths.
  - Update the _CANCEL_REGISTRY block comment to list cancel_id as the
    primary key (was stale "session_id preferred").

* Consolidate review tests for Studio stop-button cancel flow

  - Merge the 6 behavioral tests from test_stream_cleanup_on_disconnect.py
    (finally cleanup on normal/exception/aclose, pre-set cancel_event
    pattern, and its regressions) into
    test_stream_cancel_registration_timing.py, which is the PR's existing
    file covering the same area.
  - Extend structural invariants to include audio_input_stream alongside
    the three GGUF / Unsloth streaming generators: no _tracker.__enter__
    inside the async gen body, cleanup via try/finally, no background= on
    StreamingResponse.
  - Delete test_stream_cleanup_on_disconnect.py (now empty).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* studio: make cancel-via-POST interrupt Unsloth and audio-input streams

  Close two remaining gaps in the stop-button cancellation wiring:

  - stream_chunks (Unsloth path): add a top-of-loop cancel_event check and
    call backend.reset_generation_state() so cancel POSTs flush GPU state
    and close the SSE cleanly instead of relying on
    request.is_disconnected (which does not fire through proxies like
    Colab's).
  - audio_input_stream: run the synchronous audio_input_generate() via
    asyncio.to_thread so blocking whisper chunks do not freeze the event
    loop, matching the pattern already used by the GGUF streaming paths.

* Add review tests for Studio stop-button cancel flow

* Consolidate review tests for Studio stop-button cancel flow

  - Delete the standalone test_cancel_registry.py at the repo root: its
    tests duplicated test_cancel_atomicity.py / test_cancel_id_wiring.py
    and re-implemented registry primitives inline (scaffolding).
  - Extend tests/studio/test_stream_cancel_registration_timing.py with
    regression guards for the iter-1 cancel-loop fixes:
    - structural: each streaming generator checks cancel_event in its
      loop; audio_input_stream offloads next() via asyncio.to_thread;
      the stream_chunks cancel branch calls reset_generation_state().
    - runtime: the Unsloth loop breaks on external cancel and resets
      state; the audio loop stays responsive under blocking next(); both
      loops emit zero tokens on pre-set cancel (replay path).
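The loop shape those changes describe (a blocking token generator offloaded with asyncio.to_thread, with a cancel check at the top of every iteration, emitting zero tokens on a pre-set cancel) can be sketched as follows. `stream_tokens` is an illustrative name, not the PR's function.

```python
import asyncio
import threading

_SENTINEL = object()  # distinguishes exhaustion from a real chunk


async def stream_tokens(blocking_gen, cancel_event: threading.Event):
    """Yield chunks from a synchronous generator without blocking the
    event loop, stopping promptly when cancel_event is set."""
    it = iter(blocking_gen)
    while not cancel_event.is_set():
        # next() may block on model decode; run it off the event loop so
        # other coroutines (including the /cancel handler) keep running.
        chunk = await asyncio.to_thread(next, it, _SENTINEL)
        if chunk is _SENTINEL:
            break
        yield chunk
```

With the check at the top of the loop, a cancel replayed before the first iteration (the pending-cancel path) short-circuits the stream before any token is produced.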
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* studio: extend stop-path to passthrough streams; tighten wall-clock cap

  - Lower _DEFAULT_T_MAX_PREDICT_MS from 1 hour to 10 minutes so the
    wall-clock backstop actually bounds runaway decodes when cancel
    signaling fails.
  - Wire _TrackedCancel and cancel_event.is_set() into
    _openai_passthrough_stream and _anthropic_passthrough_stream and
    disable httpx keepalive so stop requests from /v1 and /v1/messages
    tool-calling clients reach llama-server.
  - Apply t_max_predict_ms to the tool-passthrough request body so the
    backstop covers passthrough paths as well.
  - Add a symmetric pre-registration stash for session_id/completion_id
    cancels (_cancel_by_keys_or_stash) so early cancels by those keys
    replay on later registration like cancel_id.
  - Drop dead except BaseException guards around StreamingResponse() at
    four streaming sites; cleanup lives in the generator's finally.

* studio: harden cancel registry against ghost-cancel and leak paths

  - Revert the session_id/completion_id stash in the fallback cancel
    helper. session_id is thread-scoped and reused across runs, so
    stashing it on an unmatched POST would fire cancel_event for the
    user's next unrelated request via _TrackedCancel.__enter__. cancel_id
    remains the only per-run unique key that gets stashed.
  - Default max_tokens to _DEFAULT_MAX_TOKENS in the tool-passthrough
    body. Mirror the direct GGUF path so OpenAI/Anthropic passthrough
    callers who omit max_tokens get the same zombie-decode cap instead of
    relying on the wall-clock backstop alone.
  - Wrap _openai_passthrough_stream setup with an outer try/except
    BaseException. The inner except httpx.RequestError does not catch
    asyncio.CancelledError at await client.send, which would otherwise
    leave _tracker registered in _CANCEL_REGISTRY indefinitely.
  - Make the frontend stop POST use plain fetch + a manual Authorization
    header instead of authFetch, so a 401 on the cancel POST no longer
    refreshes tokens or redirects the user to the login page mid-stop.

* Add review tests for Studio stop-button cancel flow

* studio: trim comments on stop-button review changes

  Collapse multi-paragraph rationale blocks on the cancel registry,
  _openai_passthrough_stream, and the frontend onAbortCancel handler into
  one-line explanations of why the non-obvious behaviour exists. Drop the
  authFetch import that became unused when the cancel POST switched to
  plain fetch.

* Consolidate review tests for Studio stop-button cancel flow

  Move review-added tests out of test_cancel_dispatch_edges.py into the
  existing PR test files that already cover the same areas:

  - backend registry fan-out / exclusivity / idempotency / falsy-keys edge
    cases moved into tests/studio/test_cancel_atomicity.py
  - frontend plain-fetch (not authFetch) + manual Authorization header
    moved into tests/studio/test_cancel_id_wiring.py

  Delete the now-empty test_cancel_dispatch_edges.py.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Studio: stop default-capping responses at 4096 tokens (follow-up to
  #5069) (#5174)

* Studio: stop default-capping responses at 4096 tokens

  Follow-up to #5069. The 4096 default introduced for runaway-decode
  defense silently truncates any caller that omits max_tokens. The Studio
  chat UI sets params.maxTokens = loadResp.context_length after a GGUF
  load, so it is fine, but every other consumer is not:

  - OpenAI-API direct callers (/v1/chat/completions, /v1/responses,
    /v1/messages, /v1/completions), where the OpenAI default is
    effectively unlimited per response. langchain, llama-index, raw curl,
    and the openai SDK all rely on that.
  - Reasoning models. Qwen3 / gpt-oss reasoning traces routinely exceed
    4096 tokens before the model emits a single visible content token. The
    user sees the trace cut off mid-thought.
  - Long-form generation ("write a chapter", "produce a full SVG").
  Reproduced on this branch: gemma-4-E2B-it-GGUF Q8_0, prompt asking for a
  10000-word story, no max_tokens in the request:

    finish_reason: stop (misleading -- should be 'length')
    content_chars: 19772
    content_tail: ...'a comforting, yet immense, pressure.\n\n*"'

  The body ended mid-sentence on a stray opening quote, right at the 4096
  token mark. After this patch the same request returns 38357 chars ending
  with '...held in a perfect, dynamic equilibrium.' -- a natural stop, not
  a truncation.

  Implementation: rename the constant to _DEFAULT_MAX_TOKENS_FLOOR and set
  it to 32768. Each call site now uses the model's effective context
  length when known, falling back to the floor:

    default_cap = self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR

  The 10-minute t_max_predict_ms wall-clock backstop from #5069 is
  preserved as the second line of defense. Plumbed
  _build_passthrough_payload + _build_openai_passthrough_body through the
  routes layer so the Anthropic and OpenAI passthrough paths also respect
  the model's context length.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Studio: cancel passthrough streams during llama-server prefill + route
  through apiUrl for Tauri

  Three reviewer-flagged correctness gaps in the stop-button mechanism.

  1) `_openai_passthrough_stream` could not honor cancel during prefill.
     The cancel check ran inside the `async for raw_line in lines_iter`
     body, so a cancel POST that arrived before llama-server emitted the
     first SSE line was unobservable until prefill completed. With a long
     prompt under proxy/Colab conditions -- the exact target scenario for
     this PR -- that left the model decoding for a long time after the
     user clicked Stop. Add an asyncio watcher task that closes `resp` as
     soon as `cancel_event` is set, raising in `aiter_lines` so the
     generator can exit. The watcher polls a threading.Event because the
     cancel registry is keyed by threading.Event for the synchronous
     /cancel handler.

  2) `_anthropic_passthrough_stream` had the same blocking-prefill
     pattern. Same fix.

  3) The frontend's stop-button cancel POST used a bare relative
     `fetch("/api/inference/cancel", ...)`, which targets the webview
     origin in Tauri production builds (where the backend is at
     `http://127.0.0.1:8888`). Route through the existing `apiUrl()`
     helper from `lib/api-base.ts` to match every other Studio call.
     Browser/dev builds get the empty base, so behavior is unchanged
     there.

  Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: cancel
  during prefill terminates within ~250ms on both passthrough paths (was
  145s+ on the Anthropic path before this change), and the standard
  non-passthrough chat path still cancels with no regression.

* Studio: log cancel-body parse errors instead of silently swallowing

  Reviewer-flagged defensive logging gap. The bare `except Exception:
  pass` in `cancel_inference` would mask malformed payloads that hint at a
  buggy client or a transport issue. Log at debug so future investigation
  isn't left guessing whether `body={}` came from a missing body or a
  parse failure. Behavior is unchanged: an unparseable body still falls
  through to the empty-dict path and the cancel call returns
  `{"cancelled": 0}`.

* Studio: Anthropic passthrough cancel parity with OpenAI passthrough

  Two reviewer-flagged consistency gaps in the cancel surface for
  /v1/messages.

  1) The Anthropic passthrough did not register cancel_id, so a per-run
     cancel POST (the cleanest Studio-style cancel path) silently missed
     when the route hit `_anthropic_passthrough_stream`. The OpenAI
     passthrough has registered (cancel_id, session_id, completion_id)
     since this PR was first opened; mirror that here. Also add
     `cancel_id` to `AnthropicMessagesRequest` so the route handler can
     plumb it through.
  2) The cancel handler's fallback key list checked only completion_id and
     session_id, never message_id. Anthropic clients that send their
     native `id` (returned in the SSE message_start event) for cancel had
     no way to hit the registry. Add message_id to the fallback list.

  Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: P2 now
  cancels by cancel_id in 137ms (was hanging pre-fix), and the new P2b
  case cancels by message_id in 77ms. P1 (OpenAI) and P3 (standard chat)
  still pass with no regression.

---------

Co-authored-by: danielhanchen <michaelhan2050@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
This commit is contained in:
parent 8264e80dd9
commit eb8b0dee2e

11 changed files with 1785 additions and 119 deletions
@@ -53,6 +53,21 @@ _INTENT_SIGNAL = re.compile(
     r")"
 )
 _MAX_REPROMPTS = 3
+
+# Without max_tokens, llama-server defaults to n_predict = n_ctx (up to
+# 262144 for Qwen3.5), producing many-minute zombie decodes when cancel
+# fails. t_max_predict_ms is a wall-clock backstop applied unconditionally,
+# but the llama.cpp README notes it ONLY fires after a newline has been
+# generated -- a model stuck in a long unbroken non-newline sequence is
+# unbounded by it. So we still want a token cap as the front-line limiter.
+#
+# The cap is the model's effective context length when we know it,
+# falling back to a generous floor when metadata is unavailable. 4096 was
+# too low: Qwen3 / gpt-oss reasoning traces routinely exceed it, and any
+# OpenAI-API caller that omits max_tokens (langchain, llama-index, raw
+# curl) sees responses silently truncated mid-sentence.
+_DEFAULT_MAX_TOKENS_FLOOR = 32768
+_DEFAULT_T_MAX_PREDICT_MS = 600_000  # 10 min
 _REPROMPT_MAX_CHARS = 2000

 # ── Pre-compiled patterns for GGUF shard detection ───────────
@@ -1636,7 +1651,7 @@ class LlamaCppBackend:
         # existing text (code refactoring, summarization, reasoning).
         # For general chat with low repetition, overhead is ~5 ms.
         #
-        # Benchmarks from llama.cpp PRs #18471, #19164:
+        # Benchmarks from upstream llama.cpp speculative-decoding PRs:
         # Scenario | Without | With | Speedup
         # gpt-oss-120b code refactor | 181 t/s | 446 t/s | 2.5x
         # Qwen3-235B offloaded | 12 t/s | 21 t/s | 1.8x
@@ -2549,8 +2564,15 @@ class LlamaCppBackend:
         )
         if _reasoning_kw is not None:
             payload["chat_template_kwargs"] = _reasoning_kw
-        if max_tokens is not None:
-            payload["max_tokens"] = max_tokens
+        # Default cap to the model's effective context length when known,
+        # otherwise the conservative floor. The wall-clock backstop below
+        # keeps a stuck model from running indefinitely either way.
+        payload["max_tokens"] = (
+            max_tokens
+            if max_tokens is not None
+            else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR)
+        )
+        payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
         if stop:
             payload["stop"] = stop
         payload["stream_options"] = {"include_usage": True}
@@ -2570,7 +2592,9 @@ class LlamaCppBackend:
         _auth_headers = (
             {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None
         )
-        with httpx.Client(timeout = stream_timeout) as client:
+        with httpx.Client(
+            timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0)
+        ) as client:
             with self._stream_with_retry(
                 client,
                 url,
@@ -2769,8 +2793,12 @@ class LlamaCppBackend:
         )
         if _reasoning_kw is not None:
             payload["chat_template_kwargs"] = _reasoning_kw
-        if max_tokens is not None:
-            payload["max_tokens"] = max_tokens
+        payload["max_tokens"] = (
+            max_tokens
+            if max_tokens is not None
+            else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR)
+        )
+        payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
         if stop:
             payload["stop"] = stop

@@ -2809,7 +2837,10 @@ class LlamaCppBackend:
             write = 10,
             pool = 10,
         )
-        with httpx.Client(timeout = stream_timeout) as client:
+        with httpx.Client(
+            timeout = stream_timeout,
+            limits = httpx.Limits(max_keepalive_connections = 0),
+        ) as client:
             with self._stream_with_retry(
                 client,
                 url,
@@ -3422,8 +3453,12 @@ class LlamaCppBackend:
         )
         if _reasoning_kw is not None:
             stream_payload["chat_template_kwargs"] = _reasoning_kw
-        if max_tokens is not None:
-            stream_payload["max_tokens"] = max_tokens
+        stream_payload["max_tokens"] = (
+            max_tokens
+            if max_tokens is not None
+            else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR)
+        )
+        stream_payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
         if stop:
             stream_payload["stop"] = stop
         stream_payload["stream_options"] = {"include_usage": True}
@@ -3442,7 +3477,9 @@ class LlamaCppBackend:
         _auth_headers = (
             {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None
         )
-        with httpx.Client(timeout = stream_timeout) as client:
+        with httpx.Client(
+            timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0)
+        ) as client:
             with self._stream_with_retry(
                 client,
                 url,
@@ -62,6 +62,7 @@ from routes import (
     datasets_router,
     export_router,
     inference_router,
+    inference_studio_router,
     models_router,
     training_history_router,
     training_router,
@@ -207,6 +208,9 @@ app.include_router(auth_router, prefix = "/api/auth", tags = ["auth"])
 app.include_router(training_router, prefix = "/api/train", tags = ["training"])
 app.include_router(models_router, prefix = "/api/models", tags = ["models"])
 app.include_router(inference_router, prefix = "/api/inference", tags = ["inference"])
+# Studio-only inference endpoints (cancel, etc.) are intentionally NOT
+# exposed on the /v1 OpenAI-compat prefix below.
+app.include_router(inference_studio_router, prefix = "/api/inference", tags = ["inference"])

 # OpenAI-compatible endpoints: mount the same inference router at /v1
 # so external tools (Open WebUI, SillyTavern, etc.) can use the
@@ -531,6 +531,10 @@ class ChatCompletionRequest(BaseModel):
         None,
         description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.",
     )
+    cancel_id: Optional[str] = Field(
+        None,
+        description = "[x-unsloth] Per-request cancellation token. Frontend sends a fresh UUID per run so /inference/cancel matches one specific generation.",
+    )


 # ── Streaming response chunks ────────────────────────────────────
@@ -992,6 +996,7 @@ class AnthropicMessagesRequest(BaseModel):
     enable_tools: Optional[bool] = None
     enabled_tools: Optional[list[str]] = None
     session_id: Optional[str] = None
+    cancel_id: Optional[str] = None
     model_config = {"extra": "allow"}

@@ -8,6 +8,7 @@ API Routes
 from routes.training import router as training_router
 from routes.models import router as models_router
 from routes.inference import router as inference_router
+from routes.inference import studio_router as inference_studio_router
 from routes.datasets import router as datasets_router
 from routes.auth import router as auth_router
 from routes.data_recipe import router as data_recipe_router
@@ -18,6 +19,7 @@ __all__ = [
     "training_router",
     "models_router",
     "inference_router",
+    "inference_studio_router",
     "datasets_router",
     "auth_router",
     "data_recipe_router",
@@ -113,7 +113,12 @@ if str(backend_path) not in sys.path:
 # Import backend functions
 try:
     from core.inference import get_inference_backend
-    from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags
+    from core.inference.llama_cpp import (
+        LlamaCppBackend,
+        _DEFAULT_MAX_TOKENS_FLOOR,
+        _DEFAULT_T_MAX_PREDICT_MS,
+        detect_reasoning_flags,
+    )
     from utils.models import ModelConfig
     from utils.inference import load_inference_config
     from utils.models.model_config import load_model_defaults
@@ -122,7 +127,12 @@ except ImportError:
     if str(parent_backend) not in sys.path:
         sys.path.insert(0, str(parent_backend))
     from core.inference import get_inference_backend
-    from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags
+    from core.inference.llama_cpp import (
+        LlamaCppBackend,
+        _DEFAULT_MAX_TOKENS_FLOOR,
+        _DEFAULT_T_MAX_PREDICT_MS,
+        detect_reasoning_flags,
+    )
     from utils.models import ModelConfig
     from utils.inference import load_inference_config
     from utils.models.model_config import load_model_defaults
@ -185,6 +195,126 @@ import numpy as np
|
|||
from datetime import date as _date
|
||||
|
||||
router = APIRouter()
|
||||
# Studio-only router (not mounted on /v1 OpenAI-compat).
|
||||
studio_router = APIRouter()
|
||||
|
||||
|
||||
# Cancel registry. Proxies (e.g. Colab) can swallow client fetch aborts
|
||||
# so is_disconnected() never fires. POST /inference/cancel looks up
|
||||
# in-flight cancel_events here by cancel_id (per-run) or session_id /
|
||||
# completion_id (fallbacks).
|
||||
_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {}
|
||||
_CANCEL_LOCK = threading.Lock()
|
||||
|
||||
# Cancel POSTs that arrive before registration are stashed; the next
|
||||
# matching __enter__ replays set() within the TTL.
|
||||
_PENDING_CANCELS: dict[str, float] = {}
|
||||
_PENDING_CANCEL_TTL_S = 30.0
|
||||
|
||||
|
||||
def _prune_pending(now: float) -> None:
|
||||
for k in [
|
||||
k for k, ts in _PENDING_CANCELS.items() if now - ts > _PENDING_CANCEL_TTL_S
|
||||
]:
|
||||
_PENDING_CANCELS.pop(k, None)
|
||||
|
||||
|
||||
class _TrackedCancel:
|
||||
"""Register cancel_event in _CANCEL_REGISTRY for the block's duration."""
|
||||
|
||||
def __init__(self, event: threading.Event, *keys):
|
||||
self.event = event
|
||||
self.keys = tuple(k for k in keys if k)
|
||||
|
||||
def __enter__(self):
|
||||
# Register + consume-pending must be one critical section to close
|
||||
# the TOCTOU race against a concurrent cancel POST.
|
||||
should_cancel = False
|
||||
with _CANCEL_LOCK:
|
||||
for k in self.keys:
|
||||
_CANCEL_REGISTRY.setdefault(k, set()).add(self.event)
|
||||
now = time.monotonic()
|
||||
_prune_pending(now)
|
||||
for k in self.keys:
|
||||
if k and _PENDING_CANCELS.pop(k, None) is not None:
|
||||
should_cancel = True
|
||||
if should_cancel:
|
||||
self.event.set()
|
||||
return self.event
|
||||
|
||||
def __exit__(self, *exc):
|
||||
with _CANCEL_LOCK:
|
||||
for k in self.keys:
|
||||
bucket = _CANCEL_REGISTRY.get(k)
|
||||
if bucket is None:
|
||||
continue
|
||||
bucket.discard(self.event)
|
||||
if not bucket:
|
||||
_CANCEL_REGISTRY.pop(k, None)
|
||||
return False
|
||||
|
||||
|
||||
def _cancel_by_keys(keys) -> int:
|
||||
"""Set cancel_event for matching registry entries; no stash.
|
||||
session_id/completion_id are shared across runs on the same thread,
|
||||
so stashing them would ghost-cancel the user's next request. Only
|
||||
cancel_id is per-run unique (see _cancel_by_cancel_id_or_stash)."""
|
||||
if not keys:
|
||||
return 0
|
||||
events: set[threading.Event] = set()
|
||||
with _CANCEL_LOCK:
|
||||
_prune_pending(time.monotonic())
|
||||
for k in keys:
|
||||
bucket = _CANCEL_REGISTRY.get(k)
|
||||
if bucket:
|
||||
events.update(bucket)
|
||||
for ev in events:
|
||||
ev.set()
|
||||
return len(events)
|
||||
|
||||
|
||||
def _cancel_by_cancel_id_or_stash(cancel_id: str) -> int:
|
||||
"""Atomic lookup-or-stash; pairs with _TrackedCancel.__enter__ to
|
||||
close the TOCTOU race."""
|
||||
now = time.monotonic()
|
||||
events: set[threading.Event] = set()
|
||||
with _CANCEL_LOCK:
|
||||
_prune_pending(now)
|
||||
bucket = _CANCEL_REGISTRY.get(cancel_id)
|
||||
if bucket:
|
||||
events.update(bucket)
|
||||
else:
|
||||
_PENDING_CANCELS[cancel_id] = now
|
||||
for ev in events:
|
||||
ev.set()
|
||||
return len(events)
|


async def _await_cancel_then_close(cancel_event, resp) -> None:
    """Watch a threading.Event from asyncio and close ``resp`` when it fires.

    Used by the passthrough streamers so a /cancel POST can interrupt
    while the async iterator is blocked waiting for llama-server prefill.
    Without this watcher the in-loop ``cancel_event.is_set()`` check is
    unreachable until the first SSE chunk arrives, which is exactly the
    proxy/Colab scenario the cancel POST exists to handle.

    Polls a threading.Event because the cancel registry is keyed by
    threading.Event so the synchronous /cancel handler can call .set().
    50ms cadence adds at most that much latency to a prefill cancel; the
    common-case streaming cancel path still observes the event in the
    iterator's first iteration after the next chunk.
    """
    try:
        while not cancel_event.is_set():
            await asyncio.sleep(0.05)
        try:
            await resp.aclose()
        except Exception:
            pass
    except asyncio.CancelledError:
        return
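The watcher pattern is small enough to demonstrate end to end. This is a minimal sketch with a hypothetical `FakeResp` standing in for the httpx response; a second thread (like the synchronous /cancel handler) sets the `threading.Event`, and the asyncio-side poller closes the resource:

```python
import asyncio
import threading


class FakeResp:
    """Hypothetical stand-in for an httpx streaming response."""

    def __init__(self):
        self.closed = False

    async def aclose(self):
        self.closed = True


async def watch(cancel_event: threading.Event, resp: FakeResp) -> None:
    # Poll the thread-safe event; calling .wait() would block the loop.
    while not cancel_event.is_set():
        await asyncio.sleep(0.05)
    await resp.aclose()


async def main() -> bool:
    ev, resp = threading.Event(), FakeResp()
    task = asyncio.create_task(watch(ev, resp))
    # The /cancel route runs on another thread and only calls ev.set().
    threading.Timer(0.1, ev.set).start()
    await task
    return resp.closed
```

Running `asyncio.run(main())` returns `True` once the watcher has closed the response.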


# Appended to tool-use nudge to discourage plan-without-action
_TOOL_ACTION_NUDGE = (

@@ -706,6 +836,48 @@ async def unload_model(
        raise HTTPException(status_code = 500, detail = f"Failed to unload model: {str(e)}")


@studio_router.post("/cancel")
async def cancel_inference(
    request: Request,
    current_subject: str = Depends(get_current_subject),
):
    """Cancel in-flight inference requests.

    Body (JSON, at least one key required):
        cancel_id - preferred: per-run UUID, matched exclusively.
        session_id - fallback when cancel_id is absent.
        completion_id - fallback when cancel_id is absent.

    A cancel_id arriving before its stream registers is stashed briefly
    and replayed on registration. Returns {"cancelled": N}.
    """
    try:
        body = await request.json()
        if not isinstance(body, dict):
            body = {}
    except Exception as e:
        logger.debug("Failed to parse cancel request body: %s", e)
        body = {}

    cancel_id = body.get("cancel_id")
    if isinstance(cancel_id, str) and cancel_id:
        return {"cancelled": _cancel_by_cancel_id_or_stash(cancel_id)}

    keys = []
    # `message_id` is the Anthropic passthrough's per-run identifier --
    # included so /v1/messages clients can cancel by their native id.
    for k in ("completion_id", "session_id", "message_id"):
        v = body.get(k)
        if isinstance(v, str) and v:
            keys.append(v)

    if not keys:
        return {"cancelled": 0}

    n = _cancel_by_keys(keys)
    return {"cancelled": n}


@router.post("/generate/stream")
async def generate_stream(
    request: GenerateRequest,

@@ -1169,6 +1341,9 @@ async def openai_chat_completions(
        )

    if payload.stream:
        _cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
        _tracker = _TrackedCancel(cancel_event, *_cancel_keys)
        _tracker.__enter__()

        async def audio_input_stream():
            try:

@@ -1185,10 +1360,17 @@ async def openai_chat_completions(
                )
                yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n"

                for chunk_text in audio_input_generate():
                gen = audio_input_generate()
                _DONE = object()
                while True:
                    if cancel_event.is_set():
                        break
                    if await request.is_disconnected():
                        cancel_event.set()
                        return
                    chunk_text = await asyncio.to_thread(next, gen, _DONE)
                    if chunk_text is _DONE:
                        break
                    if chunk_text:
                        chunk = ChatCompletionChunk(
                            id = completion_id,
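The sentinel-draining idiom in the hunk above is worth isolating: `next(gen, sentinel)` returns the sentinel instead of raising StopIteration, which cannot propagate through an asyncio future. A minimal sketch with a plain generator standing in for `audio_input_generate()`:

```python
import asyncio


def blocking_gen():
    # Hypothetical stand-in for the blocking audio generator.
    yield from ("a", "b", "c")


async def drain() -> list[str]:
    gen = blocking_gen()
    done = object()
    out = []
    while True:
        # next(gen, done) hands back the sentinel on exhaustion, so no
        # StopIteration ever crosses the asyncio.to_thread boundary.
        chunk = await asyncio.to_thread(next, gen, done)
        if chunk is done:
            break
        out.append(chunk)
    return out
```

Each `await` in the loop is also a natural point to check a cancel event or client disconnect, as the real stream does.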

@@ -1221,6 +1403,8 @@ async def openai_chat_completions(
                    f"Error during audio input streaming: {e}", exc_info = True
                )
                yield f"data: {json.dumps({'error': {'message': _friendly_error(e), 'type': 'server_error'}})}\n\n"
            finally:
                _tracker.__exit__(None, None, None)

        return StreamingResponse(
            audio_input_stream(),

@@ -1466,6 +1650,10 @@ async def openai_chat_completions(

    _tool_sentinel = object()

    _cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
    _tracker = _TrackedCancel(cancel_event, *_cancel_keys)
    _tracker.__enter__()

    async def gguf_tool_stream():
        try:
            first_chunk = ChatCompletionChunk(

@@ -1488,6 +1676,8 @@ async def openai_chat_completions(
            _stream_usage = None
            _stream_timings = None
            while True:
                if cancel_event.is_set():
                    break
                if await request.is_disconnected():
                    cancel_event.set()
                    return

@@ -1595,6 +1785,8 @@ async def openai_chat_completions(
                },
            }
            yield f"data: {json.dumps(error_chunk)}\n\n"
        finally:
            _tracker.__exit__(None, None, None)

    return StreamingResponse(
        gguf_tool_stream(),

@@ -1628,6 +1820,9 @@ async def openai_chat_completions(
    _gguf_sentinel = object()

    if payload.stream:
        _cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
        _tracker = _TrackedCancel(cancel_event, *_cancel_keys)
        _tracker.__enter__()

        async def gguf_stream_chunks():
            try:

@@ -1652,6 +1847,8 @@ async def openai_chat_completions(
                _stream_usage = None
                _stream_timings = None
                while True:
                    if cancel_event.is_set():
                        break
                    if await request.is_disconnected():
                        cancel_event.set()
                        return

@@ -1735,6 +1932,8 @@ async def openai_chat_completions(
                    },
                }
                yield f"data: {json.dumps(error_chunk)}\n\n"
            finally:
                _tracker.__exit__(None, None, None)

        return StreamingResponse(
            gguf_stream_chunks(),

@@ -1834,6 +2033,9 @@ async def openai_chat_completions(

    # ── Streaming response ────────────────────────────────────────
    if payload.stream:
        _cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
        _tracker = _TrackedCancel(cancel_event, *_cancel_keys)
        _tracker.__enter__()

        async def stream_chunks():
            try:

@@ -1861,6 +2063,9 @@ async def openai_chat_completions(
                loop = asyncio.get_event_loop()
                gen = generate()
                while True:
                    if cancel_event.is_set():
                        backend.reset_generation_state()
                        break
                    # next(gen, _DONE) returns _DONE instead of raising
                    # StopIteration — StopIteration cannot propagate
                    # through asyncio futures (Python limitation).

@@ -1916,6 +2121,8 @@ async def openai_chat_completions(
                    },
                }
                yield f"data: {json.dumps(error_chunk)}\n\n"
            finally:
                _tracker.__exit__(None, None, None)

        return StreamingResponse(
            stream_chunks(),

@@ -2596,7 +2803,9 @@ async def _responses_stream(
        ),
    )

    body = _build_openai_passthrough_body(chat_req)
    body = _build_openai_passthrough_body(
        chat_req, backend_ctx = llama_backend.context_length
    )
    target_url = f"{llama_backend.base_url}/v1/chat/completions"

    async def event_generator():

@@ -3081,6 +3290,8 @@ async def anthropic_messages(
        repetition_penalty = repetition_penalty,
        presence_penalty = presence_penalty,
        tool_choice = openai_tool_choice,
        session_id = payload.session_id,
        cancel_id = payload.cancel_id,
    )
    return await _anthropic_passthrough_non_streaming(
        llama_backend,

@@ -3441,6 +3652,7 @@ def _build_passthrough_payload(
    repetition_penalty = None,
    presence_penalty = None,
    tool_choice = "auto",
    backend_ctx = None,
):
    body = {
        "messages": openai_messages,

@@ -3453,8 +3665,12 @@ def _build_passthrough_payload(
    }
    if stream:
        body["stream_options"] = {"include_usage": True}
    if max_tokens is not None:
        body["max_tokens"] = max_tokens
    body["max_tokens"] = (
        max_tokens
        if max_tokens is not None
        else (backend_ctx or _DEFAULT_MAX_TOKENS_FLOOR)
    )
    body["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
    if stop:
        body["stop"] = stop
    if min_p is not None:
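The default-resolution rule in this hunk is: honour an explicit `max_tokens`, else fall back to the backend context length, else a floor. A sketch of that logic; `DEFAULT_FLOOR = 4096` is an illustrative stand-in for `_DEFAULT_MAX_TOKENS_FLOOR`, whose real value lives elsewhere in the module:

```python
DEFAULT_FLOOR = 4096  # illustrative stand-in for _DEFAULT_MAX_TOKENS_FLOOR


def resolve_max_tokens(max_tokens, backend_ctx):
    # Caller override wins; otherwise cap at the loaded model's context
    # length so llama-server never defaults n_predict to n_ctx silently.
    return max_tokens if max_tokens is not None else (backend_ctx or DEFAULT_FLOOR)
```

Pairing this with an unconditional `t_max_predict_ms` wall-clock cap is what stops a cancelled request from decoding to the full context window.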

@@ -3484,6 +3700,8 @@ async def _anthropic_passthrough_stream(
    repetition_penalty = None,
    presence_penalty = None,
    tool_choice = "auto",
    session_id = None,
    cancel_id = None,
):
    """Streaming client-side pass-through: forward tools to llama-server and
    translate its streaming response to Anthropic SSE without executing anything."""

@@ -3501,8 +3719,14 @@ async def _anthropic_passthrough_stream(
        repetition_penalty = repetition_penalty,
        presence_penalty = presence_penalty,
        tool_choice = tool_choice,
        backend_ctx = llama_backend.context_length,
    )

    # cancel_id mirrors the OpenAI passthrough so a per-run cancel POST
    # works without the caller having to know the local message_id.
    _tracker = _TrackedCancel(cancel_event, cancel_id, session_id, message_id)
    _tracker.__enter__()

    async def _stream():
        emitter = AnthropicPassthroughEmitter()
        for line in emitter.start(message_id, model_name):

@@ -3535,15 +3759,28 @@ async def _anthropic_passthrough_stream(
        # has anything orphaned to finalize. Each aclose is wrapped in
        # `try: ... except Exception: pass` so anyio cleanup noise from
        # nested aclose paths can't bubble out.
        client = httpx.AsyncClient(timeout = 600)
        client = httpx.AsyncClient(
            timeout = 600,
            limits = httpx.Limits(max_keepalive_connections = 0),
        )
        resp = None
        lines_iter = None
        cancel_watcher = None
        try:
            req = client.build_request("POST", target_url, json = body)
            resp = await client.send(req, stream = True)

            # See _openai_passthrough_stream for rationale: aiter_lines()
            # blocks during llama-server prefill, so the in-loop cancel
            # check is unreachable until the first SSE chunk arrives.
            # The watcher closes `resp` on cancel, raising in aiter_lines.
            cancel_watcher = asyncio.create_task(
                _await_cancel_then_close(cancel_event, resp)
            )
            lines_iter = resp.aiter_lines()
            async for raw_line in lines_iter:
                if cancel_event.is_set():
                    break
                if await request.is_disconnected():
                    cancel_event.set()
                    break

@@ -3558,9 +3795,18 @@ async def _anthropic_passthrough_stream(
                    continue
                for line in emitter.feed_chunk(chunk):
                    yield line
        except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError):
            if not cancel_event.is_set():
                raise
        except Exception as e:
            logger.error("anthropic_messages passthrough stream error: %s", e)
        finally:
            if cancel_watcher is not None:
                cancel_watcher.cancel()
                try:
                    await cancel_watcher
                except (asyncio.CancelledError, Exception):
                    pass
            if lines_iter is not None:
                try:
                    await lines_iter.aclose()

@@ -3575,6 +3821,7 @@ async def _anthropic_passthrough_stream(
                await client.aclose()
            except Exception:
                pass
            _tracker.__exit__(None, None, None)

        for line in emitter.finish():
            yield line

@@ -3621,6 +3868,7 @@ async def _anthropic_passthrough_non_streaming(
        repetition_penalty = repetition_penalty,
        presence_penalty = presence_penalty,
        tool_choice = tool_choice,
        backend_ctx = llama_backend.context_length,
    )

    async with httpx.AsyncClient() as client:

@@ -3742,7 +3990,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
    return messages


def _build_openai_passthrough_body(payload) -> dict:
def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict:
    """Assemble the llama-server request body from a ChatCompletionRequest.

    Only explicitly-known OpenAI / llama-server fields are forwarded so that

@@ -3764,6 +4012,7 @@ def _build_openai_passthrough_body(payload) -> dict:
        repetition_penalty = payload.repetition_penalty,
        presence_penalty = payload.presence_penalty,
        tool_choice = tool_choice,
        backend_ctx = backend_ctx,
    )

@@ -3784,103 +4033,56 @@ async def _openai_passthrough_stream(
    observes a standard OpenAI response.
    """
    target_url = f"{llama_backend.base_url}/v1/chat/completions"
    body = _build_openai_passthrough_body(payload)
    body = _build_openai_passthrough_body(
        payload, backend_ctx = llama_backend.context_length
    )

    # Dispatch the upstream request BEFORE returning StreamingResponse so
    # transport errors and non-200 upstream statuses surface as real HTTP
    # errors to the client. OpenAI SDKs rely on status codes to raise
    # ``APIError``/``BadRequestError``/...; burying the failure inside a
    # 200 SSE ``error`` frame silently breaks their error handling.
    client = httpx.AsyncClient(timeout = 600)
    resp = None
    _cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
    _tracker = _TrackedCancel(cancel_event, *_cancel_keys)
    _tracker.__enter__()

    # Outer guard: asyncio.CancelledError at `await client.send(...)` is
    # a BaseException that bypasses `except httpx.RequestError`; without
    # this the tracker leaks. The generator's finally only runs once
    # iteration starts.
    try:
        req = client.build_request("POST", target_url, json = body)
        resp = await client.send(req, stream = True)
    except httpx.RequestError as e:
        # llama-server subprocess crashed / still starting / unreachable.
        logger.error("openai passthrough stream: upstream unreachable: %s", e)
        if resp is not None:
            try:
                await resp.aclose()
            except Exception:
                pass
        try:
            await client.aclose()
        except Exception:
            pass
        raise HTTPException(
            status_code = 502,
            detail = _friendly_error(e),
        )
    # Dispatch BEFORE returning StreamingResponse so transport errors
    # and non-200 upstream statuses surface as real HTTP errors --
    # OpenAI SDKs rely on status codes to raise APIError/BadRequestError.
    client = httpx.AsyncClient(
        timeout = 600,
        limits = httpx.Limits(max_keepalive_connections = 0),
    )

    if resp.status_code != 200:
        err_bytes = await resp.aread()
        err_text = err_bytes.decode("utf-8", errors = "replace")
        logger.error(
            "openai passthrough upstream error: status=%s body=%s",
            resp.status_code,
            err_text[:500],
        )
        upstream_status = resp.status_code
        try:
            await resp.aclose()
        except Exception:
            pass
        resp = None
        try:
            await client.aclose()
        except Exception:
            pass
        raise HTTPException(
            status_code = upstream_status,
            detail = f"llama-server error: {err_text[:500]}",
        )

    async def _stream():
        # Same httpx lifecycle pattern as _anthropic_passthrough_stream:
        # avoid `async with` on the client/response AND explicitly save
        # resp.aiter_lines() so we can close it ourselves in the finally
        # block. See the long comment there for the full rationale on
        # why the anonymous `async for raw_line in resp.aiter_lines():`
        # pattern leaks an unclosed async generator that Python's
        # asyncgen GC hook then finalizes in a different asyncio task,
        # producing "Exception ignored in:" / "async generator ignored
        # GeneratorExit" / anyio cancel-scope traces on Python 3.13 +
        # httpcore 1.0.x.
        lines_iter = None
        try:
            lines_iter = resp.aiter_lines()
            async for raw_line in lines_iter:
                if await request.is_disconnected():
                    cancel_event.set()
                    break
                if not raw_line:
                    continue
                if not raw_line.startswith("data: "):
                    continue
                # Relay the llama-server SSE chunk verbatim so the client
                # sees its native `id`, `finish_reason`, `delta.tool_calls`,
                # and final `usage` unchanged.
                yield raw_line + "\n\n"
                if raw_line[6:].strip() == "[DONE]":
                    break
        except Exception as e:
            # Mid-stream failures still have to be reported inside the SSE
            # body because the 200 response headers have already been
            # committed by the time the first chunk flushes.
            logger.error("openai passthrough stream error: %s", e)
            err = {
                "error": {
                    "message": _friendly_error(e),
                    "type": "server_error",
                },
            }
            yield f"data: {json.dumps(err)}\n\n"
        finally:
            if lines_iter is not None:
                try:
                    await lines_iter.aclose()
                except Exception:
                    pass

        req = client.build_request("POST", target_url, json = body)
        resp = await client.send(req, stream = True)
    except httpx.RequestError as e:
        # llama-server subprocess crashed / still starting / unreachable.
        logger.error("openai passthrough stream: upstream unreachable: %s", e)
        if resp is not None:
            try:
                await resp.aclose()
            except Exception:
                pass
        try:
            await client.aclose()
        except Exception:
            pass
        raise HTTPException(
            status_code = 502,
            detail = _friendly_error(e),
        )

    if resp.status_code != 200:
        err_bytes = await resp.aread()
        err_text = err_bytes.decode("utf-8", errors = "replace")
        logger.error(
            "openai passthrough upstream error: status=%s body=%s",
            resp.status_code,
            err_text[:500],
        )
        upstream_status = resp.status_code
        try:
            await resp.aclose()
        except Exception:
@@ -3889,16 +4091,91 @@ async def _openai_passthrough_stream(
            await client.aclose()
        except Exception:
            pass
        raise HTTPException(
            status_code = upstream_status,
            detail = f"llama-server error: {err_text[:500]}",
        )

    return StreamingResponse(
        _stream(),
        media_type = "text/event-stream",
        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
        async def _stream():
            # Same httpx lifecycle pattern as _anthropic_passthrough_stream:
            # save resp.aiter_lines() so the finally block can aclose() it
            # on our task. See that function for full rationale.
            lines_iter = None
            # During llama-server prefill, `aiter_lines()` blocks until the
            # first SSE chunk arrives. The in-loop `cancel_event` check
            # cannot fire until then, which is the exact proxy/Colab
            # scenario the cancel POST is meant to recover from. Run a
            # tiny watcher that closes `resp` as soon as cancel fires,
            # unblocking the iterator with a RemoteProtocolError caught
            # in the except clause below.
            cancel_watcher = asyncio.create_task(
                _await_cancel_then_close(cancel_event, resp)
            )
            try:
                lines_iter = resp.aiter_lines()
                async for raw_line in lines_iter:
                    if cancel_event.is_set():
                        break
                    if await request.is_disconnected():
                        cancel_event.set()
                        break
                    if not raw_line:
                        continue
                    if not raw_line.startswith("data: "):
                        continue
                    # Relay verbatim to preserve llama-server's native id,
                    # finish_reason, delta.tool_calls, and usage chunks.
                    yield raw_line + "\n\n"
                    if raw_line[6:].strip() == "[DONE]":
                        break
            except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError):
                # Watcher closed resp on cancel. Emit nothing extra; the
                # client either initiated the cancel or already disconnected.
                if not cancel_event.is_set():
                    raise
            except Exception as e:
                # 200 headers are already flushed; errors must be in the SSE body.
                logger.error("openai passthrough stream error: %s", e)
                err = {
                    "error": {
                        "message": _friendly_error(e),
                        "type": "server_error",
                    },
                }
                yield f"data: {json.dumps(err)}\n\n"
            finally:
                cancel_watcher.cancel()
                try:
                    await cancel_watcher
                except (asyncio.CancelledError, Exception):
                    pass
                if lines_iter is not None:
                    try:
                        await lines_iter.aclose()
                    except Exception:
                        pass
                try:
                    await resp.aclose()
                except Exception:
                    pass
                try:
                    await client.aclose()
                except Exception:
                    pass
                _tracker.__exit__(None, None, None)

        return StreamingResponse(
            _stream(),
            media_type = "text/event-stream",
            headers = {
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except BaseException:
        _tracker.__exit__(None, None, None)
        raise


async def _openai_passthrough_non_streaming(

@@ -3914,7 +4191,9 @@ async def _openai_passthrough_non_streaming(
    token counts.
    """
    target_url = f"{llama_backend.base_url}/v1/chat/completions"
    body = _build_openai_passthrough_body(payload)
    body = _build_openai_passthrough_body(
        payload, backend_ctx = llama_backend.context_length
    )

    try:
        async with httpx.AsyncClient() as client:

@@ -4,6 +4,8 @@
import type { ChatModelAdapter } from "@assistant-ui/react";
import type { MessageTiming, ToolCallMessagePart } from "@assistant-ui/core";
import { toast } from "sonner";
import { getAuthToken } from "@/features/auth/session";
import { apiUrl } from "@/lib/api-base";
import {
  generateAudio,
  listCachedGguf,

@@ -707,7 +709,42 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
      const toolCallParts: ToolCallMessagePart[] = [];
      let serverMetadata: { usage?: ServerUsage; timings?: ServerTimings } | null = null;

      // Per-run cancellation token so a delayed stop POST cannot match
      // the next run on the same thread.
      const cancelId =
        typeof crypto !== "undefined" && "randomUUID" in crypto
          ? crypto.randomUUID()
          : `${Date.now()}-${Math.random().toString(36).slice(2)}`;

      // Colab-style proxies can swallow fetch aborts, so also POST
      // /inference/cancel explicitly on abort.
      const onAbortCancel = () => {
        const body: Record<string, string> = { cancel_id: cancelId };
        if (resolvedThreadId) body.session_id = resolvedThreadId;
        // Plain fetch, not authFetch: authFetch redirects to login on
        // 401, which would kick the user out mid-stop.
        const token = getAuthToken();
        // Use apiUrl so the cancel POST reaches the right origin in
        // Tauri production builds (where the webview origin is not the
        // backend at 127.0.0.1:<port>). Browser/dev builds get the empty
        // base, so the path is unchanged there.
        void fetch(apiUrl("/api/inference/cancel"), {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
            ...(token ? { Authorization: `Bearer ${token}` } : {}),
          },
          body: JSON.stringify(body),
          keepalive: true,
        }).catch(() => {});
      };
      try {
        if (abortSignal.aborted) {
          onAbortCancel();
        } else {
          abortSignal.addEventListener("abort", onAbortCancel, { once: true });
        }

        const {
          supportsReasoning,
          reasoningEnabled,

@@ -730,6 +767,8 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
          presence_penalty: params.presencePenalty,
          image_base64: imageBase64,
          audio_base64: audioBase64,
          cancel_id: cancelId,
          ...(resolvedThreadId ? { session_id: resolvedThreadId } : {}),
          ...(useAdapter === undefined ? {} : { use_adapter: useAdapter }),
          ...(supportsReasoning
            ? reasoningStyle === "reasoning_effort"

@@ -750,7 +789,6 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
                  const mins = useChatRuntimeStore.getState().toolCallTimeout;
                  return mins >= 9999 ? 9999 : mins * 60;
                })(),
                session_id: resolvedThreadId,
              }
            : {}),
        },

@@ -948,6 +986,7 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
        }
        throw err;
      } finally {
        abortSignal.removeEventListener("abort", onAbortCancel);
        runtime.setGeneratingStatus(null);
        runtime.setToolStatus(null);
        clearTimeout(warmupTimer);

@@ -179,6 +179,7 @@ export interface OpenAIChatCompletionsRequest {
  max_tool_calls_per_message?: number;
  tool_call_timeout?: number;
  session_id?: string;
  cancel_id?: string;
}

export interface OpenAIChatDelta {

289  tests/studio/test_cancel_atomicity.py  (new file)

@@ -0,0 +1,289 @@
"""
TOCTOU atomicity guards for the cancel path.

Structural: cancel_inference, _cancel_by_cancel_id_or_stash, and
_TrackedCancel.__enter__ must each use a single _CANCEL_LOCK critical
section over lookup + stash / register + consume-pending.

Behavioral: parallel cancel-POST vs __enter__ must never drop a cancel.
"""

from __future__ import annotations

import ast
import random
import threading
from pathlib import Path


SOURCE_PATH = (
    Path(__file__).resolve().parents[2]
    / "studio"
    / "backend"
    / "routes"
    / "inference.py"
)
_SRC = SOURCE_PATH.read_text()
_TREE = ast.parse(_SRC)


def _find_function(name: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
    for node in ast.walk(_TREE):
        if (
            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == name
        ):
            return node
    raise AssertionError(f"function {name!r} not found")


def _find_class(name: str) -> ast.ClassDef:
    for node in ast.walk(_TREE):
        if isinstance(node, ast.ClassDef) and node.name == name:
            return node
    raise AssertionError(f"class {name!r} not found")


def _count_with_cancel_lock_blocks(node: ast.AST) -> int:
    n = 0
    for sub in ast.walk(node):
        if not isinstance(sub, ast.With):
            continue
        for item in sub.items:
            ctx = item.context_expr
            if isinstance(ctx, ast.Name) and ctx.id == "_CANCEL_LOCK":
                n += 1
                break
    return n


def test_cancel_by_cancel_id_or_stash_is_single_lock_critical_section():
    fn = _find_function("_cancel_by_cancel_id_or_stash")
    assert _count_with_cancel_lock_blocks(fn) == 1, (
        "_cancel_by_cancel_id_or_stash must use exactly one `with "
        "_CANCEL_LOCK:` block; splitting into two acquisitions reopens "
        "the TOCTOU race with _TrackedCancel.__enter__"
    )
    src = ast.unparse(fn)
    assert "_CANCEL_REGISTRY.get(cancel_id)" in src
    assert "_PENDING_CANCELS[cancel_id]" in src


def test_tracked_cancel_enter_registers_and_consumes_pending_under_one_lock():
    cls = _find_class("_TrackedCancel")
    enter = None
    for n in cls.body:
        if isinstance(n, ast.FunctionDef) and n.name == "__enter__":
            enter = n
            break
    assert enter is not None
    assert _count_with_cancel_lock_blocks(enter) == 1, (
        "_TrackedCancel.__enter__ must acquire _CANCEL_LOCK exactly once. "
        "A second acquisition for consume-pending lets a concurrent "
        "cancel POST stash after consume sees an empty map, silently "
        "dropping the cancel"
    )
    with_block = None
    for sub in ast.walk(enter):
        if isinstance(sub, ast.With) and any(
            isinstance(i.context_expr, ast.Name) and i.context_expr.id == "_CANCEL_LOCK"
            for i in sub.items
        ):
            with_block = sub
            break
    assert with_block is not None
    block_src = "\n".join(ast.unparse(s) for s in with_block.body)
    assert "_CANCEL_REGISTRY.setdefault" in block_src
    assert "_PENDING_CANCELS.pop" in block_src, (
        "__enter__ critical section must consume from _PENDING_CANCELS "
        "inside the same lock, not a later re-acquisition"
    )


def test_cancel_inference_uses_atomic_helper_for_cancel_id_path():
    fn = _find_function("cancel_inference")
    src = ast.unparse(fn)
    assert "_cancel_by_cancel_id_or_stash" in src
    # The pre-fix two-step idiom must be gone.
    assert "_remember_pending_cancel(cancel_id)" not in src, (
        "two-step _cancel_by_keys + _remember_pending_cancel produced "
        "the TOCTOU race and must not return"
    )


_WANTED = {
    "_CANCEL_REGISTRY",
    "_CANCEL_LOCK",
    "_PENDING_CANCELS",
    "_PENDING_CANCEL_TTL_S",
    "_prune_pending",
    "_remember_pending_cancel",
    "_TrackedCancel",
    "_cancel_by_keys",
    "_cancel_by_cancel_id_or_stash",
}


def _load_registry_module():
    chunks = []
    for n in _TREE.body:
        seg = ast.get_source_segment(_SRC, n)
        if seg is None:
            continue
        if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED:
            chunks.append(seg)
        elif isinstance(n, ast.Assign):
            names = [t.id for t in n.targets if isinstance(t, ast.Name)]
            if any(name in _WANTED for name in names):
                chunks.append(seg)
        elif (
            isinstance(n, ast.AnnAssign)
            and isinstance(n.target, ast.Name)
            and n.target.id in _WANTED
        ):
            chunks.append(seg)
    mod = {}
    exec(
        "import threading, time\nfrom typing import Optional\n" + "\n\n".join(chunks),
        mod,
    )
    return mod


def test_parallel_cancel_vs_register_never_drops():
    m = _load_registry_module()
    trials = 500
    dropped = 0
    for i in range(trials):
        m["_CANCEL_REGISTRY"].clear()
        m["_PENDING_CANCELS"].clear()
        cid = f"cid-{i}"
        ev = threading.Event()
        tracker = m["_TrackedCancel"](ev, cid, "thread")
        start = threading.Event()

        def do_cancel():
            start.wait()
            m["_cancel_by_cancel_id_or_stash"](cid)

        def do_enter():
            start.wait()
            tracker.__enter__()

        threads = [
            threading.Thread(target = do_cancel),
            threading.Thread(target = do_enter),
        ]
        random.shuffle(threads)
        for t in threads:
            t.start()
        start.set()
        for t in threads:
            t.join(timeout = 5.0)
            assert not t.is_alive()

        if not ev.is_set():
            dropped += 1
        tracker.__exit__(None, None, None)

    assert dropped == 0, (
        f"TOCTOU regression: {dropped}/{trials} parallel trials silently "
        f"dropped the cancel"
    )


def test_cancel_before_register_replays_atomically():
    m = _load_registry_module()
    cid = "early-cid"
    ev = threading.Event()
    tracker = m["_TrackedCancel"](ev, cid, "thread-x")

    assert m["_cancel_by_cancel_id_or_stash"](cid) == 0
    assert cid in m["_PENDING_CANCELS"]

    tracker.__enter__()
    assert ev.is_set()
    assert cid not in m["_PENDING_CANCELS"]
    tracker.__exit__(None, None, None)


def test_cancel_after_register_signals_without_stash():
    m = _load_registry_module()
    cid = "post-cid"
    ev = threading.Event()
    tracker = m["_TrackedCancel"](ev, cid, "thread-y")
    tracker.__enter__()

    assert m["_cancel_by_cancel_id_or_stash"](cid) == 1
    assert ev.is_set()
    assert cid not in m["_PENDING_CANCELS"]
    tracker.__exit__(None, None, None)


def test_cancel_by_keys_tolerates_empty_and_falsy_keys():
    m = _load_registry_module()
    m["_CANCEL_REGISTRY"].clear()
    m["_PENDING_CANCELS"].clear()
    assert m["_cancel_by_keys"]([]) == 0
    assert m["_cancel_by_keys"](["", None, "unknown"]) == 0
    # Non-stashing fallback must never leak into _PENDING_CANCELS.
    assert m["_PENDING_CANCELS"] == {}


def test_cancel_by_keys_fans_out_to_all_streams_on_same_session():
    # Compare mode and other flows launch concurrent streams under a
    # shared session_id; a single session cancel POST must hit all of them.
    m = _load_registry_module()
    m["_CANCEL_REGISTRY"].clear()
    m["_PENDING_CANCELS"].clear()
    session = "shared-thread"
    ev_a = threading.Event()
    ev_b = threading.Event()
    tracker_a = m["_TrackedCancel"](ev_a, "cancel-a", session, "chatcmpl-a")
    tracker_b = m["_TrackedCancel"](ev_b, "cancel-b", session, "chatcmpl-b")
    tracker_a.__enter__()
    tracker_b.__enter__()
    try:
        assert m["_cancel_by_keys"]([session]) == 2
        assert ev_a.is_set() and ev_b.is_set()
    finally:
        tracker_a.__exit__(None, None, None)
        tracker_b.__exit__(None, None, None)
|
||||
assert session not in m["_CANCEL_REGISTRY"]
|
||||
|
||||
|
||||
def test_cancel_by_cancel_id_is_exclusive_to_single_run():
|
||||
# cancel_id is per-run unique; cancelling run A must not touch run B
|
||||
# even when both share a session_id.
|
||||
m = _load_registry_module()
|
||||
m["_CANCEL_REGISTRY"].clear()
|
||||
m["_PENDING_CANCELS"].clear()
|
||||
session = "shared-thread-2"
|
||||
ev_a = threading.Event()
|
||||
ev_b = threading.Event()
|
||||
tracker_a = m["_TrackedCancel"](ev_a, "cancel-only-a", session, "chatcmpl-a")
|
||||
tracker_b = m["_TrackedCancel"](ev_b, "cancel-only-b", session, "chatcmpl-b")
|
||||
tracker_a.__enter__()
|
||||
tracker_b.__enter__()
|
||||
try:
|
||||
assert m["_cancel_by_cancel_id_or_stash"]("cancel-only-a") == 1
|
||||
assert ev_a.is_set()
|
||||
assert not ev_b.is_set()
|
||||
finally:
|
||||
tracker_a.__exit__(None, None, None)
|
||||
tracker_b.__exit__(None, None, None)
|
||||
|
||||
|
||||
def test_tracked_cancel_exit_is_idempotent():
|
||||
# Outer except BaseException + the generator's finally may both call
|
||||
# __exit__ under certain race combos; must not raise.
|
||||
m = _load_registry_module()
|
||||
m["_CANCEL_REGISTRY"].clear()
|
||||
m["_PENDING_CANCELS"].clear()
|
||||
ev = threading.Event()
|
||||
tracker = m["_TrackedCancel"](ev, "cid", "sess", "chatcmpl-x")
|
||||
tracker.__enter__()
|
||||
tracker.__exit__(None, None, None)
|
||||
tracker.__exit__(None, None, None)
|
||||
tracker.__exit__(None, None, None)
|
||||
assert not m["_CANCEL_REGISTRY"]

tests/studio/test_cancel_id_wiring.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""
|
||||
Wiring tests for the per-run cancel_id field.
|
||||
|
||||
A chat-thread-scoped session_id is not safe as a cancel key because a
|
||||
late stop POST can match a subsequent run on the same thread. The fix
|
||||
adds cancel_id (a fresh UUID per generation) that is sent both in the
|
||||
completion payload and in the /api/inference/cancel body.
|
||||
|
||||
Verifies:
|
||||
- ChatCompletionRequest exposes an Optional[str] `cancel_id` field.
|
||||
- /api/inference/cancel accepts `cancel_id` as the first-preferred key.
|
||||
- OpenAIChatCompletionsRequest (frontend type) includes cancel_id.
|
||||
- chat-adapter.ts generates a per-run cancelId (crypto.randomUUID
|
||||
with a Math.random fallback), sends it in the completion payload,
|
||||
and includes it in the /inference/cancel body on abort.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
WORKSPACE = Path(__file__).resolve().parents[2]
|
||||
MODELS_SRC = (WORKSPACE / "studio/backend/models/inference.py").read_text()
|
||||
ROUTES_SRC = (WORKSPACE / "studio/backend/routes/inference.py").read_text()
|
||||
ADAPTER_SRC = (
|
||||
WORKSPACE / "studio/frontend/src/features/chat/api/chat-adapter.ts"
|
||||
).read_text()
|
||||
API_TYPES_SRC = (
|
||||
WORKSPACE / "studio/frontend/src/features/chat/types/api.ts"
|
||||
).read_text()
|
||||
|
||||
|
||||
def _find_class(tree: ast.AST, name: str) -> ast.ClassDef | None:
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == name:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
def test_chat_completion_request_has_cancel_id_field():
|
||||
tree = ast.parse(MODELS_SRC)
|
||||
cls = _find_class(tree, "ChatCompletionRequest")
|
||||
assert cls is not None
|
||||
fields = {
|
||||
n.target.id
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name)
|
||||
}
|
||||
assert "cancel_id" in fields, (
|
||||
"ChatCompletionRequest must expose a cancel_id field for per-run "
|
||||
"cancellation routing"
|
||||
)
|
||||
|
||||
|
||||
def test_cancel_route_matches_cancel_id_exclusively_when_present():
|
||||
# A stale cancel POST carrying cancel_id AND session_id must not
|
||||
# cancel a later run on the same thread via the shared session_id.
|
||||
# Enforce this by requiring the handler to early-return through an
|
||||
# exclusive-cancel_id path -- either an atomic helper or a keys
|
||||
# list containing ONLY cancel_id (never session_id).
|
||||
for node in ast.walk(ast.parse(ROUTES_SRC)):
|
||||
if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference":
|
||||
break
|
||||
else:
|
||||
raise AssertionError("cancel_inference handler missing")
|
||||
|
||||
cancel_id_exclusive_branch = False
|
||||
for sub in ast.walk(node):
|
||||
if not isinstance(sub, ast.If):
|
||||
continue
|
||||
test_src = ast.unparse(sub.test)
|
||||
if "cancel_id" not in test_src or "isinstance" not in test_src:
|
||||
continue
|
||||
branch_src = "\n".join(ast.unparse(s) for s in sub.body)
|
||||
before_return = branch_src.split("return", 1)[0]
|
||||
matches_cancel_id_only = (
|
||||
"_cancel_by_cancel_id_or_stash(cancel_id)" in branch_src
|
||||
or "_cancel_by_keys([cancel_id])" in branch_src
|
||||
)
|
||||
if matches_cancel_id_only and "session_id" not in before_return:
|
||||
cancel_id_exclusive_branch = True
|
||||
break
|
||||
assert cancel_id_exclusive_branch, (
|
||||
"cancel_inference must early-return with an exclusive cancel_id "
|
||||
"match when a cancel_id is supplied, so a stale stop POST "
|
||||
"cannot cancel a later run on the same thread via session_id"
|
||||
)
|
||||
|
||||
|
||||
def test_cancel_route_falls_back_to_session_or_completion_when_no_cancel_id():
|
||||
for node in ast.walk(ast.parse(ROUTES_SRC)):
|
||||
if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference":
|
||||
break
|
||||
else:
|
||||
raise AssertionError("cancel_inference handler missing")
|
||||
|
||||
src = ast.unparse(node)
|
||||
assert "session_id" in src and "completion_id" in src, (
|
||||
"cancel_inference must still accept session_id / completion_id as "
|
||||
"fallback keys when cancel_id is absent"
|
||||
)
|
||||
|
||||
|
||||
def test_frontend_request_type_has_cancel_id():
|
||||
assert re.search(
|
||||
r"cancel_id\?\s*:\s*string\s*;", API_TYPES_SRC
|
||||
), "OpenAIChatCompletionsRequest must expose an optional cancel_id"
|
||||
|
||||
|
||||
def test_chat_adapter_generates_cancel_id_per_run():
|
||||
m = re.search(
|
||||
r"const\s+cancelId\s*=\s*([^;]+);",
|
||||
ADAPTER_SRC,
|
||||
)
|
||||
assert m, "chat-adapter.ts must declare a per-run `cancelId` constant"
|
||||
rhs = m.group(1)
|
||||
assert (
|
||||
"randomUUID" in rhs
|
||||
), "cancelId should prefer crypto.randomUUID() for uniqueness"
|
||||
|
||||
|
||||
def test_chat_adapter_sends_cancel_id_in_completion_payload():
|
||||
assert "cancel_id: cancelId" in ADAPTER_SRC, (
|
||||
"chat-adapter.ts must include cancel_id in the streamChatCompletions "
|
||||
"request payload so the backend registers under that key"
|
||||
)
|
||||
|
||||
|
||||
def test_chat_adapter_sends_cancel_id_in_abort_cancel_post():
|
||||
m = re.search(
|
||||
r"const\s+onAbortCancel\s*=\s*\(\)\s*=>\s*\{(.*?)\};",
|
||||
ADAPTER_SRC,
|
||||
flags = re.DOTALL,
|
||||
)
|
||||
assert m, "onAbortCancel arrow function missing"
|
||||
body = m.group(1)
|
||||
assert re.search(r"cancel_id\s*:\s*cancelId", body), (
|
||||
"onAbortCancel must include cancel_id in the /inference/cancel body "
|
||||
"so a stop POST matches the specific run, not the whole thread"
|
||||
)
|
||||
|
||||
|
||||
def test_abort_cancel_post_uses_plain_fetch_with_manual_auth_header():
|
||||
# authFetch redirects to login on 401, which would kick the user to
|
||||
# the login page mid-stop if the access token expired during a long
|
||||
# stream. Use plain fetch + manual Authorization header for a
|
||||
# best-effort cancel that never triggers the refresh/redirect flow.
|
||||
start = ADAPTER_SRC.find("const onAbortCancel")
|
||||
assert start >= 0, "onAbortCancel handler missing"
|
||||
rest = ADAPTER_SRC[start:]
|
||||
end = rest.find("\n try {")
|
||||
body = rest if end < 0 else rest[:end]
|
||||
assert "/api/inference/cancel" in body
|
||||
assert "authFetch(" not in body, (
|
||||
"onAbortCancel must NOT call authFetch; a 401 from it would "
|
||||
"redirect the user to the login page during a stop click"
|
||||
)
|
||||
assert "fetch(" in body, "onAbortCancel must use plain fetch(...)"
|
||||
assert "getAuthToken" in body, (
|
||||
"onAbortCancel must read the bearer token via getAuthToken() "
|
||||
"rather than relying on authFetch's 401 flow"
|
||||
)
|
||||
assert "Authorization" in body
|
||||
assert (
|
||||
"keepalive: true" in body
|
||||
), "keepalive is required so the fetch survives page unload during stop"

tests/studio/test_llama_cpp_wall_clock_cap.py (new file, 123 lines)
@@ -0,0 +1,123 @@
"""
|
||||
Tests for the llama-server wall-clock cap (t_max_predict_ms).
|
||||
|
||||
The UI always sends max_tokens = context_length, so gating
|
||||
t_max_predict_ms on `max_tokens is None` makes the safety net dead
|
||||
code. The fix applies the wall-clock cap unconditionally on all three
|
||||
streaming payload sites and raises the default to 10 minutes so slow
|
||||
CPU / macOS / Windows installs are not cut off mid-generation.
|
||||
|
||||
Verifies:
|
||||
- t_max_predict_ms is assigned unconditionally at the three
|
||||
payload-builder sites (not inside an `if max_tokens is None` else
|
||||
branch).
|
||||
- _DEFAULT_T_MAX_PREDICT_MS is at least 10 minutes (previously
|
||||
120_000).
|
||||
- The default max_tokens path still applies _DEFAULT_MAX_TOKENS.
|
||||
- The three payload variable names (payload x2, stream_payload x1)
|
||||
each get both `max_tokens` and `t_max_predict_ms`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SOURCE_PATH = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "studio"
|
||||
/ "backend"
|
||||
/ "core"
|
||||
/ "inference"
|
||||
/ "llama_cpp.py"
|
||||
)
|
||||
SRC = SOURCE_PATH.read_text()
|
||||
TREE = ast.parse(SRC)
|
||||
|
||||
|
||||
def _is_subscript_assign(stmt: ast.stmt, target_name: str, key: str) -> bool:
|
||||
if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1:
|
||||
return False
|
||||
t = stmt.targets[0]
|
||||
if not isinstance(t, ast.Subscript):
|
||||
return False
|
||||
if not (isinstance(t.value, ast.Name) and t.value.id == target_name):
|
||||
return False
|
||||
slc = t.slice
|
||||
return isinstance(slc, ast.Constant) and slc.value == key
|
||||
|
||||
|
||||
def _collect_assignments(tree, target_name, key):
|
||||
"""Return list of (node, stack_of_enclosing_ifs) for each match."""
|
||||
hits = []
|
||||
|
||||
def visit(node, stack):
|
||||
if _is_subscript_assign(node, target_name, key):
|
||||
hits.append((node, stack))
|
||||
for child in ast.iter_child_nodes(node):
|
||||
if isinstance(child, ast.If):
|
||||
for sub in child.body:
|
||||
visit(sub, stack + [(child, "body")])
|
||||
for sub in child.orelse:
|
||||
visit(sub, stack + [(child, "orelse")])
|
||||
else:
|
||||
visit(child, stack)
|
||||
|
||||
visit(tree, [])
|
||||
return hits
|
||||
|
||||
|
||||
def test_default_t_max_predict_ms_is_at_least_ten_minutes():
|
||||
for node in TREE.body:
|
||||
if isinstance(node, ast.Assign) and len(node.targets) == 1:
|
||||
t = node.targets[0]
|
||||
if isinstance(t, ast.Name) and t.id == "_DEFAULT_T_MAX_PREDICT_MS":
|
||||
value = node.value
|
||||
assert isinstance(value, ast.Constant)
|
||||
assert value.value >= 600_000, (
|
||||
f"_DEFAULT_T_MAX_PREDICT_MS must be >= 10 minutes "
|
||||
f"(600_000 ms) to avoid cutting off slow-CPU generations; "
|
||||
f"got {value.value}"
|
||||
)
|
||||
return
|
||||
raise AssertionError("_DEFAULT_T_MAX_PREDICT_MS constant missing")
|
||||
|
||||
|
||||
def test_t_max_predict_ms_set_unconditionally_at_three_sites():
|
||||
hits_payload = _collect_assignments(TREE, "payload", "t_max_predict_ms")
|
||||
hits_stream = _collect_assignments(TREE, "stream_payload", "t_max_predict_ms")
|
||||
total = len(hits_payload) + len(hits_stream)
|
||||
assert total == 3, (
|
||||
f"expected 3 total t_max_predict_ms assignments "
|
||||
f"(payload x2 + stream_payload x1), got {total}"
|
||||
)
|
||||
for node, stack in hits_payload + hits_stream:
|
||||
for parent_if, branch in stack:
|
||||
# The assignment must not be gated by a test that checks
|
||||
# `max_tokens is None` (which would make it dead code for
|
||||
# the UI path where max_tokens is always set).
|
||||
test_src = ast.unparse(parent_if.test)
|
||||
assert "max_tokens" not in test_src, (
|
||||
f"t_max_predict_ms at line {node.lineno} is nested under "
|
||||
f"`if {test_src}:` -- it must be applied unconditionally so "
|
||||
f"the wall-clock cap is not dead code for callers that set "
|
||||
f"max_tokens"
|
||||
)
|
||||
|
||||
|
||||
def test_max_tokens_default_cap_still_applied():
|
||||
# _DEFAULT_MAX_TOKENS must still kick in when caller passes None.
|
||||
# We check the conditional expression `max_tokens if max_tokens is not
|
||||
# None else _DEFAULT_MAX_TOKENS` appears at each site.
|
||||
matches = 0
|
||||
for node in ast.walk(TREE):
|
||||
if not isinstance(node, ast.IfExp):
|
||||
continue
|
||||
src = ast.unparse(node)
|
||||
if "max_tokens" in src and "_DEFAULT_MAX_TOKENS" in src:
|
||||
matches += 1
|
||||
assert matches >= 3, (
|
||||
f"expected >=3 `max_tokens if max_tokens is not None else "
|
||||
f"_DEFAULT_MAX_TOKENS` expressions; got {matches}"
|
||||
)

tests/studio/test_stream_cancel_registration_timing.py (new file, 718 lines)
@@ -0,0 +1,718 @@
"""
|
||||
Tests that the cancel tracker is registered BEFORE StreamingResponse is
|
||||
returned, and that cleanup runs via a `finally` block inside each
|
||||
async generator.
|
||||
|
||||
The zombie-generation scenario is: user clicks Stop during prefill /
|
||||
warmup / proxy buffering, before the first SSE chunk. If _tracker
|
||||
__enter__ lives inside the async generator body, the registry is empty
|
||||
at the moment /api/inference/cancel lands -- so cancel returns 0 and
|
||||
the decode runs to completion.
|
||||
|
||||
The fix moves _tracker = _TrackedCancel(...) and _tracker.__enter__()
|
||||
to the synchronous body of openai_chat_completions (before the
|
||||
StreamingResponse is returned) and places _tracker.__exit__ inside
|
||||
each generator's `finally` block. Using a generator `finally` (rather
|
||||
than a Starlette BackgroundTask) guarantees cleanup on every
|
||||
termination path -- normal exhaustion, CancelledError from
|
||||
ClientDisconnect, and OSError / BrokenPipeError during send() --
|
||||
because Starlette skips `background` callbacks when stream_response
|
||||
raises.
|
||||
|
||||
Structural verifies:
|
||||
- No `async def ...:` body contains `_tracker.__enter__()` in
|
||||
routes/inference.py (registration moved to sync body).
|
||||
- Each of the four async generators (gguf_tool_stream,
|
||||
gguf_stream_chunks, stream_chunks, audio_input_stream) contains
|
||||
`_tracker.__exit__(None, None, None)` inside a try/finally block.
|
||||
- No StreamingResponse in openai_chat_completions passes
|
||||
`background=` (cleanup now lives in the generator finally).
|
||||
|
||||
Behavioral verifies (extracting `_TrackedCancel` from source and
|
||||
exercising the actual runtime semantics):
|
||||
- `finally: _tracker.__exit__(...)` runs on normal completion,
|
||||
mid-stream exception (OSError / BrokenPipeError from send()),
|
||||
and aclose() from Starlette ClientDisconnect.
|
||||
- A pre-set cancel_event (from `_TrackedCancel.__enter__` replaying
|
||||
a pending cancel POST) lets the GGUF while-loop break cleanly
|
||||
and emit final_chunk + [DONE] instead of propagating
|
||||
`GeneratorExit` out of `_stream_with_retry` into the async
|
||||
generator's `except Exception` (which would not catch it).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SOURCE_PATH = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "studio"
|
||||
/ "backend"
|
||||
/ "routes"
|
||||
/ "inference.py"
|
||||
)
|
||||
SRC = SOURCE_PATH.read_text()
|
||||
_TREE = ast.parse(SRC)
|
||||
|
||||
|
||||
# ── Structural (AST) helpers ─────────────────────────────────
|
||||
|
||||
|
||||
def _collect_async_functions(tree: ast.AST):
|
||||
return [n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef)]
|
||||
|
||||
|
||||
def _has_tracker_enter_call(node: ast.AST) -> bool:
|
||||
for sub in ast.walk(node):
|
||||
if not isinstance(sub, ast.Call):
|
||||
continue
|
||||
fn = sub.func
|
||||
if (
|
||||
isinstance(fn, ast.Attribute)
|
||||
and fn.attr == "__enter__"
|
||||
and isinstance(fn.value, ast.Name)
|
||||
and fn.value.id.startswith("_tracker")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _finalbody_has_tracker_exit(finalbody) -> bool:
|
||||
for stmt in finalbody:
|
||||
if not isinstance(stmt, ast.Expr):
|
||||
continue
|
||||
call = stmt.value
|
||||
if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
|
||||
continue
|
||||
fn = call.func
|
||||
if (
|
||||
fn.attr == "__exit__"
|
||||
and isinstance(fn.value, ast.Name)
|
||||
and fn.value.id.startswith("_tracker")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ── Structural tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_tracker_enter_inside_async_generators():
|
||||
offenders = []
|
||||
for fn in _collect_async_functions(_TREE):
|
||||
if fn.name in {
|
||||
"gguf_tool_stream",
|
||||
"gguf_stream_chunks",
|
||||
"stream_chunks",
|
||||
"audio_input_stream",
|
||||
}:
|
||||
if _has_tracker_enter_call(fn):
|
||||
offenders.append(fn.name)
|
||||
assert not offenders, (
|
||||
f"Cancel tracker registration must live OUTSIDE the async generator "
|
||||
f"body so a stop POST can find the registry entry before the first "
|
||||
f"SSE chunk. Offending generators: {offenders}"
|
||||
)
|
||||
|
||||
|
||||
def test_tracker_enter_exists_in_sync_body_of_chat_completions():
|
||||
top = None
|
||||
for n in ast.walk(_TREE):
|
||||
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
|
||||
top = n
|
||||
break
|
||||
assert top is not None, "openai_chat_completions handler missing"
|
||||
count = 0
|
||||
for sub in ast.walk(top):
|
||||
if not isinstance(sub, ast.Call):
|
||||
continue
|
||||
fn = sub.func
|
||||
if (
|
||||
isinstance(fn, ast.Attribute)
|
||||
and fn.attr == "__enter__"
|
||||
and isinstance(fn.value, ast.Name)
|
||||
and fn.value.id.startswith("_tracker")
|
||||
):
|
||||
count += 1
|
||||
assert count >= 3, (
|
||||
f"expected >=3 _tracker.__enter__() calls in openai_chat_completions "
|
||||
f"(one per streaming path), got {count}"
|
||||
)
|
||||
|
||||
|
||||
def test_async_generators_cleanup_tracker_in_finally():
|
||||
required = {
|
||||
"gguf_tool_stream",
|
||||
"gguf_stream_chunks",
|
||||
"stream_chunks",
|
||||
"audio_input_stream",
|
||||
}
|
||||
found: set[str] = set()
|
||||
for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
|
||||
if fn.name not in required:
|
||||
continue
|
||||
for sub in ast.walk(fn):
|
||||
if isinstance(sub, ast.Try) and sub.finalbody:
|
||||
if _finalbody_has_tracker_exit(sub.finalbody):
|
||||
found.add(fn.name)
|
||||
break
|
||||
missing = required - found
|
||||
assert not missing, (
|
||||
f"Cleanup must run via `finally: _tracker.__exit__(None, None, None)` "
|
||||
f"inside each streaming generator so ClientDisconnect / OSError paths "
|
||||
f"also release registry entries (Starlette skips `background` callbacks "
|
||||
f"when stream_response raises). Missing in: {sorted(missing)}"
|
||||
)
|
||||
|
||||
|
||||
def test_streaming_responses_have_no_background_task():
|
||||
top = None
|
||||
for n in ast.walk(_TREE):
|
||||
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
|
||||
top = n
|
||||
break
|
||||
assert top is not None
|
||||
for sub in ast.walk(top):
|
||||
if not (isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)):
|
||||
continue
|
||||
if sub.func.id != "StreamingResponse":
|
||||
continue
|
||||
kwargs = {kw.arg for kw in sub.keywords if kw.arg}
|
||||
assert "background" not in kwargs, (
|
||||
"StreamingResponse in openai_chat_completions must not pass "
|
||||
"`background=` -- cleanup now lives in the generator's finally "
|
||||
"block; a BackgroundTask would be skipped on abrupt disconnect"
|
||||
)
|
||||
|
||||
|
||||
# ── Behavioral helpers ───────────────────────────────────────
|
||||
|
||||
_WANTED = {
|
||||
"_CANCEL_REGISTRY",
|
||||
"_CANCEL_LOCK",
|
||||
"_PENDING_CANCELS",
|
||||
"_PENDING_CANCEL_TTL_S",
|
||||
"_prune_pending",
|
||||
"_TrackedCancel",
|
||||
"_cancel_by_keys",
|
||||
"_cancel_by_cancel_id_or_stash",
|
||||
}
|
||||
|
||||
|
||||
def _load_registry_module():
|
||||
chunks = []
|
||||
for n in _TREE.body:
|
||||
seg = ast.get_source_segment(SRC, n)
|
||||
if seg is None:
|
||||
continue
|
||||
if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED:
|
||||
chunks.append(seg)
|
||||
elif isinstance(n, ast.Assign):
|
||||
names = [t.id for t in n.targets if isinstance(t, ast.Name)]
|
||||
if any(name in _WANTED for name in names):
|
||||
chunks.append(seg)
|
||||
elif (
|
||||
isinstance(n, ast.AnnAssign)
|
||||
and isinstance(n.target, ast.Name)
|
||||
and n.target.id in _WANTED
|
||||
):
|
||||
chunks.append(seg)
|
||||
mod = {}
|
||||
exec("import threading, time\n" + "\n\n".join(chunks), mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_stream(tracker, raise_exc):
|
||||
async def gen():
|
||||
try:
|
||||
try:
|
||||
yield "data: first\n\n"
|
||||
if raise_exc is not None:
|
||||
raise raise_exc
|
||||
yield "data: [DONE]\n\n"
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
yield "data: error\n\n"
|
||||
finally:
|
||||
tracker.__exit__(None, None, None)
|
||||
except BaseException:
|
||||
raise
|
||||
|
||||
return gen()
|
||||
|
||||
|
||||
async def _consume(agen):
|
||||
out = []
|
||||
try:
|
||||
async for ch in agen:
|
||||
out.append(ch)
|
||||
except BaseException as e:
|
||||
out.append(type(e).__name__)
|
||||
return out
|
||||
|
||||
|
||||
def _llama_stub_raises_on_preset_cancel(cancel_event):
|
||||
# Reproduces llama_cpp.py _stream_with_retry:2240 `raise GeneratorExit`
|
||||
# when cancel_event is already set at entry.
|
||||
if cancel_event.is_set():
|
||||
raise GeneratorExit
|
||||
yield "cumulative-1"
|
||||
yield "cumulative-2"
|
||||
|
||||
|
||||
async def _post_fix_gguf_loop(cancel_event):
|
||||
yield "first_chunk"
|
||||
gen = _llama_stub_raises_on_preset_cancel(cancel_event)
|
||||
sentinel = object()
|
||||
while True:
|
||||
if cancel_event.is_set():
|
||||
break
|
||||
cumulative = await asyncio.to_thread(next, gen, sentinel)
|
||||
if cumulative is sentinel:
|
||||
break
|
||||
yield cumulative
|
||||
yield "final_chunk"
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
# ── Behavioral tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_finally_cleanup_on_normal_completion():
|
||||
m = _load_registry_module()
|
||||
m["_CANCEL_REGISTRY"].clear()
|
||||
ev = threading.Event()
|
||||
tr = m["_TrackedCancel"](ev, "cid-ok", "sid-ok")
|
||||
tr.__enter__()
|
||||
assert "cid-ok" in m["_CANCEL_REGISTRY"]
|
||||
chunks = asyncio.run(_consume(_make_stream(tr, None)))
|
||||
assert chunks == ["data: first\n\n", "data: [DONE]\n\n"]
|
||||
assert "cid-ok" not in m["_CANCEL_REGISTRY"]
|
||||
assert "sid-ok" not in m["_CANCEL_REGISTRY"]
|
||||
|
||||
|
||||
def test_finally_cleanup_on_mid_stream_exception():
|
||||
# Simulates OSError / BrokenPipeError from Starlette send() mid-stream --
|
||||
# the exact case where pre-fix `background = BackgroundTask(...)` was
|
||||
# skipped and leaked the registry entry.
|
||||
m = _load_registry_module()
|
||||
m["_CANCEL_REGISTRY"].clear()
|
||||
ev = threading.Event()
|
||||
tr = m["_TrackedCancel"](ev, "cid-err", "sid-err")
|
||||
tr.__enter__()
|
||||
assert "cid-err" in m["_CANCEL_REGISTRY"]
|
||||
asyncio.run(_consume(_make_stream(tr, OSError("disconnect"))))
|
||||
assert "cid-err" not in m["_CANCEL_REGISTRY"]
|
||||
assert "sid-err" not in m["_CANCEL_REGISTRY"]
|
||||
|
||||
|
||||
def test_finally_cleanup_on_aclose():
|
||||
# Starlette calls aclose() on the async generator when the client
|
||||
# disconnects mid-stream. The generator's finally block must run.
|
||||
m = _load_registry_module()
|
||||
m["_CANCEL_REGISTRY"].clear()
|
||||
ev = threading.Event()
|
||||
tr = m["_TrackedCancel"](ev, "cid-abort", "sid-abort")
|
||||
tr.__enter__()
|
||||
assert "cid-abort" in m["_CANCEL_REGISTRY"]
|
||||
|
||||
async def run():
|
||||
gen = _make_stream(tr, None)
|
||||
it = gen.__aiter__()
|
||||
await it.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
asyncio.run(run())
|
||||
assert "cid-abort" not in m["_CANCEL_REGISTRY"]
|
||||
assert "sid-abort" not in m["_CANCEL_REGISTRY"]
|
||||
|
||||
|
||||
def test_preset_cancel_event_exits_cleanly_with_done():
|
||||
# Pending-replay: POST /cancel arrived before the stream registered,
|
||||
# was stashed, then consumed by _TrackedCancel.__enter__ which set
|
||||
# cancel_event. The generator must break out of the loop cleanly
|
||||
# and emit final_chunk + [DONE] rather than calling next(gen) and
|
||||
# propagating `GeneratorExit` out of the GGUF stream wrapper.
|
||||
ev = threading.Event()
|
||||
ev.set()
|
||||
chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
|
||||
assert "first_chunk" in chunks
|
||||
assert "final_chunk" in chunks
|
||||
assert "[DONE]" in chunks
|
||||
assert "GeneratorExit" not in chunks
|
||||
assert "cumulative-1" not in chunks
|
||||
assert "cumulative-2" not in chunks
|
||||
|
||||
|
||||
def test_normal_path_streams_all_tokens():
|
||||
# Regression: the top-of-loop cancel_event check must not short-circuit
|
||||
# when cancel_event is unset.
|
||||
ev = threading.Event()
|
||||
chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
|
||||
assert chunks == [
|
||||
"first_chunk",
|
||||
"cumulative-1",
|
||||
"cumulative-2",
|
||||
"final_chunk",
|
||||
"[DONE]",
|
||||
]
|
||||
|
||||
|
||||
def test_cancel_during_streaming_stops_iteration_promptly():
|
||||
# Setting cancel_event between yields breaks out on the next iteration
|
||||
# rather than draining the stub generator.
|
||||
ev = threading.Event()
|
||||
|
||||
async def _run():
|
||||
gen = _post_fix_gguf_loop(ev)
|
||||
seen = []
|
||||
async for ch in gen:
|
||||
seen.append(ch)
|
||||
if ch == "cumulative-1":
|
||||
ev.set()
|
||||
return seen
|
||||
|
||||
seen = asyncio.run(_run())
|
||||
assert "first_chunk" in seen
|
||||
assert "cumulative-1" in seen
|
||||
assert "cumulative-2" not in seen
|
||||
assert "final_chunk" in seen
|
||||
assert "[DONE]" in seen
|
||||
|
||||
|
||||
# ── Cancel-event responsiveness in the streaming loops ───────
|
||||
|
||||
|
||||
def _loop_has_cancel_event_check(fn) -> bool:
|
||||
# An `if cancel_event.is_set():` statement anywhere inside a
|
||||
# `while`/`for` loop body is sufficient -- without it, a cancel POST
|
||||
# cannot interrupt the loop because Colab-style proxies do not
|
||||
# propagate request.is_disconnected().
|
||||
for sub in ast.walk(fn):
|
||||
if not isinstance(sub, (ast.While, ast.For, ast.AsyncFor)):
|
||||
continue
|
||||
for stmt in ast.walk(sub):
|
||||
if not isinstance(stmt, ast.If):
|
||||
continue
|
||||
t = stmt.test
|
||||
if (
|
||||
isinstance(t, ast.Call)
|
||||
and isinstance(t.func, ast.Attribute)
|
||||
and t.func.attr == "is_set"
|
||||
and isinstance(t.func.value, ast.Name)
|
||||
and t.func.value.id == "cancel_event"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def test_streaming_generators_check_cancel_event_in_loop():
|
||||
required = {
|
||||
"gguf_tool_stream",
|
||||
"gguf_stream_chunks",
|
||||
"stream_chunks",
|
||||
"audio_input_stream",
|
||||
}
|
||||
missing = []
|
||||
for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
|
||||
if fn.name not in required:
|
||||
continue
|
||||
if not _loop_has_cancel_event_check(fn):
|
||||
missing.append(fn.name)
|
||||
assert not missing, (
|
||||
f"Each streaming generator must check `cancel_event.is_set()` inside "
|
||||
f"its main loop so `POST /api/inference/cancel` can interrupt the "
|
||||
f"stream through proxies that do not forward fetch aborts. "
|
||||
f"Missing in: {sorted(missing)}"
|
||||
)


def test_audio_input_stream_offloads_blocking_next_to_thread():
    # Guards against regression back to `for chunk_text in
    # audio_input_generate():` -- which blocks the event loop on each
    # whisper chunk and prevents POST /api/inference/cancel from being
    # serviced until the chunk yields.
    audio = None
    for fn in ast.walk(_TREE):
        if isinstance(fn, ast.AsyncFunctionDef) and fn.name == "audio_input_stream":
            audio = fn
            break
    assert audio is not None, "audio_input_stream generator missing"

    for sub in ast.walk(audio):
        if isinstance(sub, (ast.For, ast.AsyncFor)):
            it_src = ast.unparse(sub.iter)
            assert "audio_input_generate" not in it_src, (
                "audio_input_stream must not iterate audio_input_generate() "
                "directly -- that blocks the event loop. Use "
                "`await asyncio.to_thread(next, gen, _DONE)` inside a "
                "`while True` loop instead"
            )

    found_to_thread_next = False
    for sub in ast.walk(audio):
        if not isinstance(sub, ast.Call):
            continue
        fn_expr = sub.func
        if not (
            isinstance(fn_expr, ast.Attribute)
            and fn_expr.attr == "to_thread"
            and isinstance(fn_expr.value, ast.Name)
            and fn_expr.value.id == "asyncio"
        ):
            continue
        if sub.args and isinstance(sub.args[0], ast.Name) and sub.args[0].id == "next":
            found_to_thread_next = True
            break
    assert found_to_thread_next, (
        "audio_input_stream must call `asyncio.to_thread(next, gen, ...)` "
        "to keep the event loop free while whisper yields the next chunk"
    )


def test_stream_chunks_cancel_branch_resets_backend_state():
    # The Unsloth path's cancel branch must flush GPU / KV-cache state
    # via `backend.reset_generation_state()` -- the orchestrator's
    # internal cancel path does not do this, so a cancel-via-POST that
    # only broke the loop would leave the subprocess in a dirty state
    # for the next request.
    fn = None
    top = None
    for n in ast.walk(_TREE):
        if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
            top = n
            break
    assert top is not None
    for n in ast.walk(top):
        if isinstance(n, ast.AsyncFunctionDef) and n.name == "stream_chunks":
            fn = n
            break
    assert fn is not None, "stream_chunks generator missing"

    for sub in ast.walk(fn):
        if not isinstance(sub, ast.If):
            continue
        t = sub.test
        if not (
            isinstance(t, ast.Call)
            and isinstance(t.func, ast.Attribute)
            and t.func.attr == "is_set"
            and isinstance(t.func.value, ast.Name)
            and t.func.value.id == "cancel_event"
        ):
            continue
        body_src = "\n".join(ast.unparse(s) for s in sub.body)
        if "backend.reset_generation_state()" in body_src:
            return
    raise AssertionError(
        "stream_chunks `if cancel_event.is_set():` branch must call "
        "backend.reset_generation_state() -- matches the existing "
        "request.is_disconnected() / CancelledError cleanup paths and "
        "prevents KV-cache drift after cancel-via-POST"
    )


# ── Behavioral simulations for the iter-1 fixes ──────────────


def test_unsloth_stream_loop_breaks_on_external_cancel_event():
    cancel_event = threading.Event()
    reset_calls = [0]

    class _Backend:
        def reset_generation_state(self):
            reset_calls[0] += 1

    backend = _Backend()

    def _generate():
        for i in range(200):
            time.sleep(0.005)
            yield f"cum-{i}"

    async def _loop():
        _DONE = object()
        loop = asyncio.get_event_loop()
        gen = _generate()
        seen = []
        while True:
            if cancel_event.is_set():
                backend.reset_generation_state()
                break
            cumulative = await loop.run_in_executor(None, next, gen, _DONE)
            if cumulative is _DONE:
                break
            seen.append(cumulative)
        return seen

    async def _fire():
        await asyncio.sleep(0.05)
        cancel_event.set()

    async def _main():
        return await asyncio.gather(_loop(), _fire())

    seen, _ = asyncio.run(_main())
    assert (
        len(seen) < 200
    ), f"loop must not drain the generator after cancel; got {len(seen)} tokens"
    assert reset_calls[0] == 1, (
        f"backend.reset_generation_state() must be called exactly once on "
        f"cancel-via-POST, got {reset_calls[0]}"
    )


def test_audio_stream_stays_responsive_under_blocking_next():
    # Regression guard: replace the post-fix loop with the pre-fix
    # `for chunk in audio_input_generate()` pattern and assert it blocks
    # the event loop; then confirm the post-fix pattern exits promptly.
    cancel_event = threading.Event()

    def _audio_gen():
        for i in range(8):
            time.sleep(0.15)
            yield f"chunk-{i}"

    async def _prefix_loop():
        seen = []
        for chunk_text in _audio_gen():
            if cancel_event.is_set():
                break
            seen.append(chunk_text)
        return seen

    async def _postfix_loop():
        _DONE = object()
        gen = _audio_gen()
        seen = []
        while True:
            if cancel_event.is_set():
                break
            chunk_text = await asyncio.to_thread(next, gen, _DONE)
            if chunk_text is _DONE:
                break
            seen.append(chunk_text)
        return seen

    async def _fire_early():
        await asyncio.sleep(0.05)
        cancel_event.set()

    async def _run(loop_coro):
        return await asyncio.gather(loop_coro, _fire_early())

    cancel_event.clear()
    t0 = time.monotonic()
    prefix_seen, _ = asyncio.run(_run(_prefix_loop()))
    prefix_elapsed = time.monotonic() - t0
    assert prefix_elapsed >= 0.13, (
        f"pre-fix pattern should block event loop for >=1 chunk time "
        f"(~150ms); got {prefix_elapsed:.3f}s, {len(prefix_seen)} chunks"
    )

    cancel_event.clear()
    t0 = time.monotonic()
    postfix_seen, _ = asyncio.run(_run(_postfix_loop()))
    postfix_elapsed = time.monotonic() - t0
    assert postfix_elapsed < prefix_elapsed, (
        f"post-fix pattern must exit faster than pre-fix (blocking) "
        f"pattern; post={postfix_elapsed:.3f}s vs pre={prefix_elapsed:.3f}s"
    )
    assert (
        len(postfix_seen) < 8
    ), f"post-fix loop must not drain all chunks; got {len(postfix_seen)}"


def test_unsloth_stream_loop_emits_zero_tokens_on_preset_cancel():
    # Pending-cancel replay: _TrackedCancel.__enter__ already set
    # cancel_event before the generator body starts iterating. The
    # top-of-loop check must short-circuit the very first iteration so
    # no token is emitted. Catches a regression that moves the check
    # below `next()` -- the mid-loop test would still pass but this
    # test would observe one extra token leak.
    cancel_event = threading.Event()
    cancel_event.set()
    reset_calls = [0]

    class _Backend:
        def reset_generation_state(self):
            reset_calls[0] += 1

    backend = _Backend()

    next_calls = [0]

    def _generate():
        while True:
            next_calls[0] += 1
            yield f"cum-{next_calls[0]}"

    async def _loop():
        _DONE = object()
        loop = asyncio.get_event_loop()
        gen = _generate()
        seen = []
        while True:
            if cancel_event.is_set():
                backend.reset_generation_state()
                break
            cumulative = await loop.run_in_executor(None, next, gen, _DONE)
            if cumulative is _DONE:
                break
            seen.append(cumulative)
        return seen

    seen = asyncio.run(_loop())
    assert seen == [], (
        f"loop must emit zero tokens when cancel_event is pre-set "
        f"(pending-replay path); got {seen}"
    )
    assert next_calls[0] == 0, (
        f"loop must not call next() at all on pre-set cancel; got "
        f"{next_calls[0]} calls"
    )
    assert reset_calls[0] == 1, (
        f"backend.reset_generation_state() must still fire exactly once "
        f"on pre-set cancel; got {reset_calls[0]}"
    )


def test_audio_stream_emits_zero_chunks_on_preset_cancel():
    # Symmetric to the Unsloth pre-set test: the audio loop's top-of-loop
    # cancel check must skip the asyncio.to_thread(next, ...) call when
    # cancel_event was already set via pending-replay.
    cancel_event = threading.Event()
    cancel_event.set()

    next_calls = [0]

    def _audio_gen():
        while True:
            next_calls[0] += 1
            yield f"chunk-{next_calls[0]}"

    async def _loop():
        _DONE = object()
        gen = _audio_gen()
        seen = []
        while True:
            if cancel_event.is_set():
                break
            chunk_text = await asyncio.to_thread(next, gen, _DONE)
            if chunk_text is _DONE:
                break
            seen.append(chunk_text)
        return seen

    seen = asyncio.run(_loop())
    assert seen == [], f"audio loop must emit zero chunks on pre-set cancel; got {seen}"
    assert next_calls[0] == 0, (
        f"audio loop must not call next() on pre-set cancel; got "
        f"{next_calls[0]} calls"
    )