mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-28 03:19:57 +00:00
* Studio: make stop button actually stop generation

The UI stop button routes through assistant-ui's cancelRun, which aborts the frontend fetch. Four issues combined to let llama-server keep decoding long after the user clicked stop:

1. request.is_disconnected() does not fire reliably behind proxies (e.g. Colab) that don't propagate fetch aborts.
2. llama-server defaults n_predict to n_ctx when max_tokens is not sent, so a cancelled request keeps producing tokens up to 262144.
3. The httpx.Client pool keeps TCP keep-alive, so even a cleanly closed stream reuses the same connection and llama-server's liveness poll never sees a disconnect.
4. No explicit backend route to cancel - every cancel path relied on is_disconnected.

Changes:

- Add POST /api/inference/cancel keyed by session_id/completion_id, with a registry populated for the lifetime of each streaming response.
- Have the frontend (chat-adapter.ts) POST /inference/cancel on AbortController abort, alongside the existing fetch teardown.
- Send max_tokens=4096 + t_max_predict_ms=120000 as defaults on every outbound chat completion to llama-server; user-supplied values still override them.
- Disable httpx keep-alive on the streaming client so connection close reaches llama-server and its 1s liveness check fires.

No behaviour changes for non-streaming paths or for existing callers that already pass max_tokens/session_id.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: harden stop-button cancel path and scope cancel route

- Require at least one identifier for /api/inference/cancel so a missing thread id cannot silently cancel every in-flight generation.
- Scope /cancel to a dedicated studio_router so it is not exposed under the /v1 OpenAI-compat prefix as a surprise endpoint.
- Store a set of cancel events per key in _CANCEL_REGISTRY so concurrent requests on the same session_id do not overwrite each other, and deduplicate in _cancel_by_keys so the cancelled count reflects unique requests.
- Always send session_id with chat completions (not only when tools are enabled) so non-tool GGUF streams register under it and are reachable from /cancel.
- Register the non-GGUF stream_chunks path in the cancel registry too, so the transformers-based stop button works behind proxies that swallow fetch aborts.
- Only apply the 2-minute t_max_predict_ms wall-clock cap when the caller did not pass max_tokens, so legitimate long generations on slow CPU/macOS/Windows supported installs are not silently truncated.
- Remove the abort listener on normal stream completion so reused AbortSignals cannot fire a spurious cancel POST after the fact.

* studio: close cancel-race and stale-cancel gaps in stop path

- Register the cancel tracker before returning StreamingResponse so a stop POST that arrives during prefill / warmup / proxy buffering finds an entry in _CANCEL_REGISTRY. Cleanup now runs via a Starlette BackgroundTask instead of a finally inside the async generator body.
- Add a per-run cancel_id on the frontend (crypto.randomUUID) and in ChatCompletionRequest so /api/inference/cancel matches one specific generation. Removes the stale-cancel bug where pressing stop then starting a new run in the same thread would cancel the retry.
- Apply t_max_predict_ms unconditionally in all three llama-server payload builders (previously gated on max_tokens=None, which made it dead code for UI callers that always send params.maxTokens). Raise the default to 10 minutes so slow CPU / macOS / Windows installs are not cut off mid-generation.
- Make _cancel_by_keys refuse empty input (return 0) so a future internal caller cannot accidentally mass-cancel every in-flight request.
- Accept cancel_id (primary), session_id, and completion_id on the /api/inference/cancel route. Unify the three streaming sites on the same _cancel_keys / _tracker variable names.
- Annotate _CANCEL_REGISTRY as dict[str, set[threading.Event]].
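The registry described above is easiest to picture as one threading.Event per run, fanned out under each identifier it can be cancelled by. The sketch below only illustrates that shape: _CANCEL_REGISTRY, _CANCEL_LOCK, _TrackedCancel, and _cancel_by_keys are the names used by the PR, but the bodies here are a simplified assumption; the real helpers in studio/backend/routes/inference.py add the pending-cancel stash, TTL pruning, and exclusive cancel_id matching discussed in the later commits.

```python
from __future__ import annotations

import threading

_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {}
_CANCEL_LOCK = threading.Lock()


class _TrackedCancel:
    """Register one generation's cancel_event under its ids for the
    lifetime of the streaming response (sketch only)."""

    def __init__(self, cancel_event: threading.Event, *keys: str | None):
        self.cancel_event = cancel_event
        # cancel_id, session_id, completion_id -- skip any that are missing
        self.keys = [k for k in keys if k]

    def __enter__(self):
        with _CANCEL_LOCK:
            for key in self.keys:
                _CANCEL_REGISTRY.setdefault(key, set()).add(self.cancel_event)
        return self

    def __exit__(self, *exc):
        with _CANCEL_LOCK:
            for key in self.keys:
                events = _CANCEL_REGISTRY.get(key)
                if events is not None:
                    events.discard(self.cancel_event)
                    if not events:
                        _CANCEL_REGISTRY.pop(key, None)
        return False


def _cancel_by_keys(keys: list[str]) -> int:
    """POST /api/inference/cancel resolves its ids to events and sets them."""
    if not keys:
        return 0  # refuse empty input: never mass-cancel everything
    with _CANCEL_LOCK:
        events: set[threading.Event] = set()
        for key in keys:
            events |= _CANCEL_REGISTRY.get(key, set())
    for ev in events:
        ev.set()
    return len(events)  # deduplicated: one count per unique in-flight request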
* Add review tests for PR #5069

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: harden stop-button cancel semantics and wall-clock cap

- Make /inference/cancel match cancel_id EXCLUSIVELY when supplied. Previously the handler iterated ('cancel_id', 'session_id', 'completion_id') and unioned matches, so a stale cancel POST carrying {cancel_id: old, session_id: thr} would still cancel a later run on the same thread via the shared session_id. cancel_id is now a per-run exclusive key; session_id / completion_id are only used as fallbacks when cancel_id is absent.
- Close the early-cancel race. If /inference/cancel lands before the streaming handler reaches _TrackedCancel.__enter__() (stop clicked during prefill / warmup / proxy buffering), the cancel was silently dropped. Stash unmatched cancel_ids in _PENDING_CANCELS with a 30 s TTL; _TrackedCancel.__enter__() now replays any matching pending cancel by set()-ing the event immediately after registration.
- Make t_max_predict_ms = _DEFAULT_T_MAX_PREDICT_MS conditional on max_tokens is None at all three llama-server payload sites. The cap is a safety net for callers who leave max_tokens unset (otherwise llama-server defaults n_predict to n_ctx, up to 262144). Callers who set an explicit max_tokens are already self-limiting, and legitimate long generations on slow CPU / macOS / Windows installs must not be silently truncated at 10 minutes.
- Guard each StreamingResponse return with try/except BaseException so _tracker.__exit__ runs even if StreamingResponse construction or any preceding statement raises between _tracker.__enter__() and the BackgroundTask attachment. Prevents a registry leak in that narrow window.

* studio: close TOCTOU race and restore wall-clock backstop on UI path

- Close the TOCTOU race in the pending-cancel mechanism. The previous fix split cancel_inference's work (cancel_by_keys + remember_pending_cancel) and _TrackedCancel.__enter__'s work (register + consume_pending) into four separate lock acquisitions. Under contention a cancel POST could acquire-then-release the lock, find the registry empty, and stash ONLY AFTER __enter__ had already registered and consumed an empty pending map -- silently dropping the cancel. Both call sites now do their work inside a single _CANCEL_LOCK critical section, via the new atomic helper _cancel_by_cancel_id_or_stash() and an inlined consume-pending step in __enter__. Reproduced the race under forced interleaving pre-fix; 0/2000 drops post-fix under parallel stress.
- Apply t_max_predict_ms UNCONDITIONALLY at all three llama-server payload sites. The previous iteration gated the cap on `max_tokens is None`, which turned out to be dead code on the primary Studio UI path: chat-adapter.ts sets maxTokens=loadResp.context_length after every model load, so every chat request carries an explicit max_tokens and the wall-clock safety net never fired. The cap's original purpose is to bound stuck decodes regardless of the token budget; it must always apply.
- Raise _DEFAULT_T_MAX_PREDICT_MS from 10 minutes to 1 hour. 10 minutes was too aggressive for legitimate slow-CPU chat responses (a 4096-token reply at 2 tok/s takes ~34 min); 1 hour accommodates that and still catches genuine zombie decodes.
- Prune _PENDING_CANCELS inside _cancel_by_keys as well, so stashed entries expire proportionally to overall cancel traffic rather than only to cancel_id-specific POSTs.

* studio: trim verbose comments and docstrings in cancel path

* studio/llama_cpp: drop upstream PR hashes from benchmark comment

* Add review tests for Studio stop button

* Consolidate review tests for Studio stop button

* Align cancel-route test with exclusive cancel_id semantics

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: move cancel cleanup to generator finally; drop dead helper

- Move _tracker.__exit__ from a Starlette BackgroundTask into each streaming generator's finally block. Starlette skips the background callback when stream_response raises (OSError / ClientDisconnect), which leaked _CANCEL_REGISTRY entries on abrupt disconnect.
- Check cancel_event.is_set() at the top of each GGUF while loop so a pending-replay cancel falls through to final_chunk + [DONE] instead of propagating GeneratorExit out of _stream_with_retry.
- Remove unused _remember_pending_cancel; _cancel_by_cancel_id_or_stash superseded it.

* Add review tests for Studio stop-button

* studio: wire audio-input stream into cancel registry

- Register cancel_event with _TrackedCancel on the audio-input streaming path so POST /api/inference/cancel can stop whisper / audio-input GGUF runs. Previously the registry stayed empty on this branch, so the stop button returned {"cancelled": 0} and the decode ran to completion.
- Apply the same finally-based cleanup and pre-iteration cancel-event check used on the other three streaming paths.
- Update the _CANCEL_REGISTRY block comment to list cancel_id as the primary key (the old "session_id preferred" note was stale).

* Consolidate review tests for Studio stop-button cancel flow

- Merge the 6 behavioral tests from test_stream_cleanup_on_disconnect.py (finally cleanup on normal/exception/aclose, the pre-set cancel_event pattern, and its regressions) into test_stream_cancel_registration_timing.py, which is the PR's existing file covering the same area.
- Extend structural invariants to include audio_input_stream alongside the three GGUF / Unsloth streaming generators: no _tracker.__enter__ inside the async gen body, cleanup via try/finally, no background= on StreamingResponse.
- Delete test_stream_cleanup_on_disconnect.py (now empty).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: make cancel-via-POST interrupt Unsloth and audio-input streams

Close two remaining gaps in the stop-button cancellation wiring (the shared loop shape is sketched below):

- stream_chunks (Unsloth path): add a top-of-loop cancel_event check and call backend.reset_generation_state() so cancel POSTs flush GPU state and close the SSE cleanly instead of relying on request.is_disconnected (which does not fire through proxies like Colab's).
- audio_input_stream: run the synchronous audio_input_generate() via asyncio.to_thread so blocking whisper chunks do not freeze the event loop, matching the pattern already used by the GGUF streaming paths.
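The cancel-responsive loop shared by the GGUF, Unsloth, and audio-input paths boils down to a top-of-loop cancel check plus offloading the blocking next() call to a worker thread, exactly the pattern the tests in the file below exercise. A minimal sketch; the generator and variable names here are illustrative, and the real generators also emit a final chunk, reset backend state, and unregister the cancel tracker in `finally`:

```python
import asyncio
import threading


async def _cancellable_stream(blocking_gen, cancel_event: threading.Event):
    """Yield chunks from a blocking generator until it is exhausted or cancelled."""
    _DONE = object()
    while True:
        if cancel_event.is_set():
            # Replayed or mid-stream cancel: stop before calling next() again,
            # so a pre-set event produces zero chunks.
            break
        chunk = await asyncio.to_thread(next, blocking_gen, _DONE)
        if chunk is _DONE:
            break
        yield chunk
```

Offloading `next()` keeps the event loop free, so a POST to /api/inference/cancel can be serviced even while the backend is busy producing the next chunk.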
* Add review tests for Studio stop-button cancel flow

* Consolidate review tests for Studio stop-button cancel flow

- Delete standalone test_cancel_registry.py at the repo root: its tests duplicated test_cancel_atomicity.py / test_cancel_id_wiring.py and re-implemented registry primitives inline (scaffolding).
- Extend tests/studio/test_stream_cancel_registration_timing.py with regression guards for the iter-1 cancel-loop fixes:
  - structural: each streaming generator checks cancel_event in its loop; audio_input_stream offloads next() via asyncio.to_thread; the stream_chunks cancel branch calls reset_generation_state().
  - runtime: the Unsloth loop breaks on external cancel and resets state; the audio loop stays responsive under blocking next(); both loops emit zero tokens on pre-set cancel (replay path).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: extend stop-path to passthrough streams; tighten wall-clock cap

- Lower _DEFAULT_T_MAX_PREDICT_MS from 1 hour to 10 minutes so the wall-clock backstop actually bounds runaway decodes when cancel signaling fails.
- Wire _TrackedCancel and cancel_event.is_set() into _openai_passthrough_stream and _anthropic_passthrough_stream and disable httpx keepalive so stop requests from /v1 and /v1/messages tool-calling clients reach llama-server.
- Apply t_max_predict_ms to the tool-passthrough request body so the backstop covers passthrough paths as well.
- Add a symmetric pre-registration stash for session_id/completion_id cancels (_cancel_by_keys_or_stash) so early cancels by those keys replay on later registration like cancel_id.
- Drop dead except BaseException guards around StreamingResponse() at four streaming sites; cleanup lives in the generator's finally.

* studio: harden cancel registry against ghost-cancel and leak paths

- Revert the session_id/completion_id stash in the fallback cancel helper. session_id is thread-scoped and reused across runs, so stashing it on an unmatched POST would fire cancel_event for the user's next unrelated request via _TrackedCancel.__enter__. cancel_id remains the only per-run unique key that gets stashed.
- Default max_tokens to _DEFAULT_MAX_TOKENS in the tool-passthrough body, mirroring the direct GGUF path so OpenAI/Anthropic passthrough callers who omit max_tokens get the same zombie-decode cap instead of relying on the wall-clock backstop alone.
- Wrap _openai_passthrough_stream setup in an outer try/except BaseException. The inner except httpx.RequestError does not catch asyncio.CancelledError at await client.send, which would otherwise leave _tracker registered in _CANCEL_REGISTRY indefinitely.
- Make the frontend stop POST use plain fetch + a manual Authorization header instead of authFetch, so a 401 on the cancel POST no longer refreshes tokens or redirects the user to the login page mid-stop.

* Add review tests for Studio stop-button cancel flow

* studio: trim comments on stop-button review changes

Collapse the multi-paragraph rationale blocks on the cancel registry, _openai_passthrough_stream, and the frontend onAbortCancel handler into one-line explanations of why the non-obvious behaviour exists. Drop the authFetch import that became unused when the cancel POST switched to plain fetch.
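For the passthrough wiring described in the two entries above, the llama-server-facing pieces are (a) a streaming client with keep-alive disabled so a closed stream is actually visible to llama-server's 1 s liveness poll, and (b) defaults in the outgoing body so a caller that omits max_tokens still gets a zombie-decode cap plus the wall-clock backstop. A rough sketch under assumed names: _DEFAULT_MAX_TOKENS and _DEFAULT_T_MAX_PREDICT_MS are the constants the commits refer to (their values changed across iterations), while _build_payload is a hypothetical stand-in for the real payload builders:

```python
import httpx

_DEFAULT_MAX_TOKENS = 4096            # later replaced by a context-length-based cap (#5174)
_DEFAULT_T_MAX_PREDICT_MS = 600_000   # 10-minute wall-clock backstop

# Keep-alive disabled so closing the stream really closes the TCP connection,
# which is what llama-server's liveness check watches for.
_client = httpx.AsyncClient(
    limits=httpx.Limits(max_keepalive_connections=0),
    timeout=None,
)


def _build_payload(body: dict) -> dict:
    """Add the runaway-decode defenses to an outgoing chat-completion body."""
    payload = dict(body)
    payload.setdefault("max_tokens", _DEFAULT_MAX_TOKENS)   # only if the caller omitted it
    payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS  # applied unconditionally
    return payload
```

The two limits are complementary: max_tokens bounds the token budget, t_max_predict_ms bounds wall-clock time when cancel signaling fails entirely.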
* Consolidate review tests for Studio stop-button cancel flow

Move review-added tests out of test_cancel_dispatch_edges.py into the existing PR test files that already cover the same areas:

- backend registry fan-out / exclusivity / idempotency / falsy-keys edge cases moved into tests/studio/test_cancel_atomicity.py
- frontend plain-fetch (not authFetch) + manual Authorization header moved into tests/studio/test_cancel_id_wiring.py

Delete the now-empty test_cancel_dispatch_edges.py.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: stop default-capping responses at 4096 tokens (follow-up to #5069) (#5174)

* Studio: stop default-capping responses at 4096 tokens

Follow-up to #5069. The 4096 default introduced for runaway-decode defense silently truncates any caller that omits max_tokens. The Studio chat UI sets params.maxTokens = loadResp.context_length after a GGUF load, so it is fine, but every other consumer is not:

- OpenAI-API direct callers (/v1/chat/completions, /v1/responses, /v1/messages, /v1/completions), where the OpenAI default is effectively unlimited per response. langchain, llama-index, raw curl, and the openai SDK all rely on that.
- Reasoning models. Qwen3 / gpt-oss reasoning traces routinely exceed 4096 tokens before the model emits a single visible content token. The user sees the trace cut off mid-thought.
- Long-form generation ("write a chapter", "produce a full SVG").

Reproduced on this branch: gemma-4-E2B-it-GGUF Q8_0, prompt asking for a 10000-word story, no max_tokens in the request:

  finish_reason: stop (misleading -- should be 'length')
  content_chars: 19772
  content_tail: ...'a comforting, yet immense, pressure.\n\n*"'

The body ended mid-sentence on a stray opening quote, right at the 4096-token mark. After this patch the same request returns 38357 chars ending with '...held in a perfect, dynamic equilibrium.' -- a natural stop, not a truncation.

Implementation: rename the constant to _DEFAULT_MAX_TOKENS_FLOOR and set it to 32768. Each call site now uses the model's effective context length when known, falling back to the floor:

  default_cap = self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR

The 10-minute t_max_predict_ms wall-clock backstop from #5069 is preserved as the second line of defense. Plumbed _build_passthrough_payload + _build_openai_passthrough_body through the routes layer so the Anthropic and OpenAI passthrough paths also respect the model's context length.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Studio: cancel passthrough streams during llama-server prefill + route through apiUrl for Tauri

Three reviewer-flagged correctness gaps in the stop-button mechanism.

1) `_openai_passthrough_stream` could not honor cancel during prefill. The cancel check ran inside the `async for raw_line in lines_iter` body, so a cancel POST that arrived before llama-server emitted the first SSE line was unobservable until prefill completed. With a long prompt under proxy/Colab conditions -- the exact target scenario for this PR -- that left the model decoding for a long time after the user clicked Stop. Add an asyncio watcher task that closes `resp` as soon as `cancel_event` is set, raising in `aiter_lines` so the generator can exit (a sketch of the watcher pattern follows the commit message below).
The watcher polls a threading.Event because the cancel registry is keyed by threading.Event for the synchronous /cancel handler.

2) `_anthropic_passthrough_stream` had the same blocking-prefill pattern. Same fix.

3) The frontend's stop-button cancel POST used a bare relative `fetch("/api/inference/cancel", ...)`, which targets the webview origin in Tauri production builds (where the backend is at `http://127.0.0.1:8888`). Route through the existing `apiUrl()` helper from `lib/api-base.ts` to match every other Studio call. Browser/dev builds get the empty base, so behavior is unchanged there.

Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: cancel during prefill terminates within ~250ms on both passthrough paths (was 145s+ on the Anthropic path before this change), and the standard non-passthrough chat path still cancels with no regression.

* Studio: log cancel-body parse errors instead of silently swallowing

Reviewer-flagged defensive logging gap. The bare `except Exception: pass` in `cancel_inference` would mask malformed payloads that hint at a buggy client or a transport issue. Log at debug so future investigation isn't left guessing whether `body={}` came from a missing body or a parse failure. Behavior is unchanged: an unparseable body still falls through to the empty-dict path and the cancel call returns `{"cancelled": 0}`.

* Studio: Anthropic passthrough cancel parity with OpenAI passthrough

Two reviewer-flagged consistency gaps in the cancel surface for /v1/messages.

1) The Anthropic passthrough did not register cancel_id, so a per-run cancel POST (the cleanest Studio-style cancel path) silently missed when the route hit `_anthropic_passthrough_stream`. The OpenAI passthrough has registered (cancel_id, session_id, completion_id) since this PR was first opened; mirror that here. Also add `cancel_id` to `AnthropicMessagesRequest` so the route handler can plumb it through.

2) The cancel handler's fallback key list checked only completion_id and session_id, never message_id. Anthropic clients that send their native `id` (returned in the SSE message_start event) for cancel had no way to hit the registry. Add message_id to the fallback list.

Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: P2 now cancels by cancel_id in 137ms (was hanging pre-fix), and the new P2b case cancels by message_id in 77ms. P1 (OpenAI) and P3 (standard chat) still pass with no regression.

---------

Co-authored-by: danielhanchen <michaelhan2050@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
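The prefill-watcher described above can be pictured as a small polling task that forces the upstream HTTP response closed once the threading-based cancel event fires, so the `async for` over SSE lines unblocks and the generator unwinds. This is a hedged sketch only, not the literal implementation: the function names here are illustrative (the real ones are _openai_passthrough_stream / _anthropic_passthrough_stream), and the exact exception raised by iterating a concurrently closed httpx response depends on the transport.

```python
import asyncio
import threading

import httpx


async def _watch_cancel(resp: httpx.Response, cancel_event: threading.Event):
    # Poll the threading.Event (it is set from the synchronous /cancel handler)
    # and close the upstream response, which breaks aiter_lines() mid-prefill.
    while not cancel_event.is_set():
        await asyncio.sleep(0.05)
    await resp.aclose()


async def _passthrough_stream(client: httpx.AsyncClient, url: str,
                              payload: dict, cancel_event: threading.Event):
    req = client.build_request("POST", url, json=payload)
    resp = await client.send(req, stream=True)
    watcher = asyncio.create_task(_watch_cancel(resp, cancel_event))
    try:
        async for line in resp.aiter_lines():
            if cancel_event.is_set():
                break
            yield line
    except (httpx.HTTPError, httpx.StreamError):
        # Expected once the watcher closes resp; swallow and finish the SSE cleanly.
        pass
    finally:
        watcher.cancel()
        await resp.aclose()
```

The key point is that cancellation no longer waits for the first SSE line: even while llama-server is still prefilling, closing the connection is enough to stop the decode.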
718 lines
24 KiB
Python
"""
|
|
Tests that the cancel tracker is registered BEFORE StreamingResponse is
|
|
returned, and that cleanup runs via a `finally` block inside each
|
|
async generator.
|
|
|
|
The zombie-generation scenario is: user clicks Stop during prefill /
|
|
warmup / proxy buffering, before the first SSE chunk. If _tracker
|
|
__enter__ lives inside the async generator body, the registry is empty
|
|
at the moment /api/inference/cancel lands -- so cancel returns 0 and
|
|
the decode runs to completion.
|
|
|
|
The fix moves _tracker = _TrackedCancel(...) and _tracker.__enter__()
|
|
to the synchronous body of openai_chat_completions (before the
|
|
StreamingResponse is returned) and places _tracker.__exit__ inside
|
|
each generator's `finally` block. Using a generator `finally` (rather
|
|
than a Starlette BackgroundTask) guarantees cleanup on every
|
|
termination path -- normal exhaustion, CancelledError from
|
|
ClientDisconnect, and OSError / BrokenPipeError during send() --
|
|
because Starlette skips `background` callbacks when stream_response
|
|
raises.
|
|
|
|
Structural verifies:
|
|
- No `async def ...:` body contains `_tracker.__enter__()` in
|
|
routes/inference.py (registration moved to sync body).
|
|
- Each of the four async generators (gguf_tool_stream,
|
|
gguf_stream_chunks, stream_chunks, audio_input_stream) contains
|
|
`_tracker.__exit__(None, None, None)` inside a try/finally block.
|
|
- No StreamingResponse in openai_chat_completions passes
|
|
`background=` (cleanup now lives in the generator finally).
|
|
|
|
Behavioral verifies (extracting `_TrackedCancel` from source and
|
|
exercising the actual runtime semantics):
|
|
- `finally: _tracker.__exit__(...)` runs on normal completion,
|
|
mid-stream exception (OSError / BrokenPipeError from send()),
|
|
and aclose() from Starlette ClientDisconnect.
|
|
- A pre-set cancel_event (from `_TrackedCancel.__enter__` replaying
|
|
a pending cancel POST) lets the GGUF while-loop break cleanly
|
|
and emit final_chunk + [DONE] instead of propagating
|
|
`GeneratorExit` out of `_stream_with_retry` into the async
|
|
generator's `except Exception` (which would not catch it).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import asyncio
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
|
|
|
|
SOURCE_PATH = (
|
|
Path(__file__).resolve().parents[2]
|
|
/ "studio"
|
|
/ "backend"
|
|
/ "routes"
|
|
/ "inference.py"
|
|
)
|
|
SRC = SOURCE_PATH.read_text()
|
|
_TREE = ast.parse(SRC)
|
|
|
|
|
|
# ── Structural (AST) helpers ─────────────────────────────────
|
|
|
|
|
|
def _collect_async_functions(tree: ast.AST):
|
|
return [n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef)]
|
|
|
|
|
|
def _has_tracker_enter_call(node: ast.AST) -> bool:
|
|
for sub in ast.walk(node):
|
|
if not isinstance(sub, ast.Call):
|
|
continue
|
|
fn = sub.func
|
|
if (
|
|
isinstance(fn, ast.Attribute)
|
|
and fn.attr == "__enter__"
|
|
and isinstance(fn.value, ast.Name)
|
|
and fn.value.id.startswith("_tracker")
|
|
):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _finalbody_has_tracker_exit(finalbody) -> bool:
|
|
for stmt in finalbody:
|
|
if not isinstance(stmt, ast.Expr):
|
|
continue
|
|
call = stmt.value
|
|
if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
|
|
continue
|
|
fn = call.func
|
|
if (
|
|
fn.attr == "__exit__"
|
|
and isinstance(fn.value, ast.Name)
|
|
and fn.value.id.startswith("_tracker")
|
|
):
|
|
return True
|
|
return False
|
|
|
|
|
|
# ── Structural tests ─────────────────────────────────────────
|
|
|
|
|
|
def test_no_tracker_enter_inside_async_generators():
|
|
offenders = []
|
|
for fn in _collect_async_functions(_TREE):
|
|
if fn.name in {
|
|
"gguf_tool_stream",
|
|
"gguf_stream_chunks",
|
|
"stream_chunks",
|
|
"audio_input_stream",
|
|
}:
|
|
if _has_tracker_enter_call(fn):
|
|
offenders.append(fn.name)
|
|
assert not offenders, (
|
|
f"Cancel tracker registration must live OUTSIDE the async generator "
|
|
f"body so a stop POST can find the registry entry before the first "
|
|
f"SSE chunk. Offending generators: {offenders}"
|
|
)
|
|
|
|
|
|
def test_tracker_enter_exists_in_sync_body_of_chat_completions():
|
|
top = None
|
|
for n in ast.walk(_TREE):
|
|
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
|
|
top = n
|
|
break
|
|
assert top is not None, "openai_chat_completions handler missing"
|
|
count = 0
|
|
for sub in ast.walk(top):
|
|
if not isinstance(sub, ast.Call):
|
|
continue
|
|
fn = sub.func
|
|
if (
|
|
isinstance(fn, ast.Attribute)
|
|
and fn.attr == "__enter__"
|
|
and isinstance(fn.value, ast.Name)
|
|
and fn.value.id.startswith("_tracker")
|
|
):
|
|
count += 1
|
|
assert count >= 3, (
|
|
f"expected >=3 _tracker.__enter__() calls in openai_chat_completions "
|
|
f"(one per streaming path), got {count}"
|
|
)
|
|
|
|
|
|
def test_async_generators_cleanup_tracker_in_finally():
|
|
required = {
|
|
"gguf_tool_stream",
|
|
"gguf_stream_chunks",
|
|
"stream_chunks",
|
|
"audio_input_stream",
|
|
}
|
|
found: set[str] = set()
|
|
for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
|
|
if fn.name not in required:
|
|
continue
|
|
for sub in ast.walk(fn):
|
|
if isinstance(sub, ast.Try) and sub.finalbody:
|
|
if _finalbody_has_tracker_exit(sub.finalbody):
|
|
found.add(fn.name)
|
|
break
|
|
missing = required - found
|
|
assert not missing, (
|
|
f"Cleanup must run via `finally: _tracker.__exit__(None, None, None)` "
|
|
f"inside each streaming generator so ClientDisconnect / OSError paths "
|
|
f"also release registry entries (Starlette skips `background` callbacks "
|
|
f"when stream_response raises). Missing in: {sorted(missing)}"
|
|
)
|
|
|
|
|
|
def test_streaming_responses_have_no_background_task():
|
|
top = None
|
|
for n in ast.walk(_TREE):
|
|
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
|
|
top = n
|
|
break
|
|
assert top is not None
|
|
for sub in ast.walk(top):
|
|
if not (isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)):
|
|
continue
|
|
if sub.func.id != "StreamingResponse":
|
|
continue
|
|
kwargs = {kw.arg for kw in sub.keywords if kw.arg}
|
|
assert "background" not in kwargs, (
|
|
"StreamingResponse in openai_chat_completions must not pass "
|
|
"`background=` -- cleanup now lives in the generator's finally "
|
|
"block; a BackgroundTask would be skipped on abrupt disconnect"
|
|
)
|
|
|
|
|
|
# ── Behavioral helpers ───────────────────────────────────────
|
|
|
|
_WANTED = {
|
|
"_CANCEL_REGISTRY",
|
|
"_CANCEL_LOCK",
|
|
"_PENDING_CANCELS",
|
|
"_PENDING_CANCEL_TTL_S",
|
|
"_prune_pending",
|
|
"_TrackedCancel",
|
|
"_cancel_by_keys",
|
|
"_cancel_by_cancel_id_or_stash",
|
|
}
|
|
|
|
|
|
def _load_registry_module():
|
|
chunks = []
|
|
for n in _TREE.body:
|
|
seg = ast.get_source_segment(SRC, n)
|
|
if seg is None:
|
|
continue
|
|
if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED:
|
|
chunks.append(seg)
|
|
elif isinstance(n, ast.Assign):
|
|
names = [t.id for t in n.targets if isinstance(t, ast.Name)]
|
|
if any(name in _WANTED for name in names):
|
|
chunks.append(seg)
|
|
elif (
|
|
isinstance(n, ast.AnnAssign)
|
|
and isinstance(n.target, ast.Name)
|
|
and n.target.id in _WANTED
|
|
):
|
|
chunks.append(seg)
|
|
mod = {}
|
|
exec("import threading, time\n" + "\n\n".join(chunks), mod)
|
|
return mod
|
|
|
|
|
|
def _make_stream(tracker, raise_exc):
|
|
async def gen():
|
|
try:
|
|
try:
|
|
yield "data: first\n\n"
|
|
if raise_exc is not None:
|
|
raise raise_exc
|
|
yield "data: [DONE]\n\n"
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception:
|
|
yield "data: error\n\n"
|
|
finally:
|
|
tracker.__exit__(None, None, None)
|
|
except BaseException:
|
|
raise
|
|
|
|
return gen()
|
|
|
|
|
|
async def _consume(agen):
|
|
out = []
|
|
try:
|
|
async for ch in agen:
|
|
out.append(ch)
|
|
except BaseException as e:
|
|
out.append(type(e).__name__)
|
|
return out
|
|
|
|
|
|
def _llama_stub_raises_on_preset_cancel(cancel_event):
|
|
# Reproduces llama_cpp.py _stream_with_retry:2240 `raise GeneratorExit`
|
|
# when cancel_event is already set at entry.
|
|
if cancel_event.is_set():
|
|
raise GeneratorExit
|
|
yield "cumulative-1"
|
|
yield "cumulative-2"
|
|
|
|
|
|
async def _post_fix_gguf_loop(cancel_event):
|
|
yield "first_chunk"
|
|
gen = _llama_stub_raises_on_preset_cancel(cancel_event)
|
|
sentinel = object()
|
|
while True:
|
|
if cancel_event.is_set():
|
|
break
|
|
cumulative = await asyncio.to_thread(next, gen, sentinel)
|
|
if cumulative is sentinel:
|
|
break
|
|
yield cumulative
|
|
yield "final_chunk"
|
|
yield "[DONE]"
|
|
|
|
|
|
# ── Behavioral tests ─────────────────────────────────────────
|
|
|
|
|
|
def test_finally_cleanup_on_normal_completion():
|
|
m = _load_registry_module()
|
|
m["_CANCEL_REGISTRY"].clear()
|
|
ev = threading.Event()
|
|
tr = m["_TrackedCancel"](ev, "cid-ok", "sid-ok")
|
|
tr.__enter__()
|
|
assert "cid-ok" in m["_CANCEL_REGISTRY"]
|
|
chunks = asyncio.run(_consume(_make_stream(tr, None)))
|
|
assert chunks == ["data: first\n\n", "data: [DONE]\n\n"]
|
|
assert "cid-ok" not in m["_CANCEL_REGISTRY"]
|
|
assert "sid-ok" not in m["_CANCEL_REGISTRY"]
|
|
|
|
|
|
def test_finally_cleanup_on_mid_stream_exception():
|
|
# Simulates OSError / BrokenPipeError from Starlette send() mid-stream --
|
|
# the exact case where pre-fix `background = BackgroundTask(...)` was
|
|
# skipped and leaked the registry entry.
|
|
m = _load_registry_module()
|
|
m["_CANCEL_REGISTRY"].clear()
|
|
ev = threading.Event()
|
|
tr = m["_TrackedCancel"](ev, "cid-err", "sid-err")
|
|
tr.__enter__()
|
|
assert "cid-err" in m["_CANCEL_REGISTRY"]
|
|
asyncio.run(_consume(_make_stream(tr, OSError("disconnect"))))
|
|
assert "cid-err" not in m["_CANCEL_REGISTRY"]
|
|
assert "sid-err" not in m["_CANCEL_REGISTRY"]
|
|
|
|
|
|
def test_finally_cleanup_on_aclose():
|
|
# Starlette calls aclose() on the async generator when the client
|
|
# disconnects mid-stream. The generator's finally block must run.
|
|
m = _load_registry_module()
|
|
m["_CANCEL_REGISTRY"].clear()
|
|
ev = threading.Event()
|
|
tr = m["_TrackedCancel"](ev, "cid-abort", "sid-abort")
|
|
tr.__enter__()
|
|
assert "cid-abort" in m["_CANCEL_REGISTRY"]
|
|
|
|
async def run():
|
|
gen = _make_stream(tr, None)
|
|
it = gen.__aiter__()
|
|
await it.__anext__()
|
|
await gen.aclose()
|
|
|
|
asyncio.run(run())
|
|
assert "cid-abort" not in m["_CANCEL_REGISTRY"]
|
|
assert "sid-abort" not in m["_CANCEL_REGISTRY"]
|
|
|
|
|
|
def test_preset_cancel_event_exits_cleanly_with_done():
|
|
# Pending-replay: POST /cancel arrived before the stream registered,
|
|
# was stashed, then consumed by _TrackedCancel.__enter__ which set
|
|
# cancel_event. The generator must break out of the loop cleanly
|
|
# and emit final_chunk + [DONE] rather than calling next(gen) and
|
|
# propagating `GeneratorExit` out of the GGUF stream wrapper.
|
|
ev = threading.Event()
|
|
ev.set()
|
|
chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
|
|
assert "first_chunk" in chunks
|
|
assert "final_chunk" in chunks
|
|
assert "[DONE]" in chunks
|
|
assert "GeneratorExit" not in chunks
|
|
assert "cumulative-1" not in chunks
|
|
assert "cumulative-2" not in chunks
|
|
|
|
|
|
def test_normal_path_streams_all_tokens():
|
|
# Regression: the top-of-loop cancel_event check must not short-circuit
|
|
# when cancel_event is unset.
|
|
ev = threading.Event()
|
|
chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
|
|
assert chunks == [
|
|
"first_chunk",
|
|
"cumulative-1",
|
|
"cumulative-2",
|
|
"final_chunk",
|
|
"[DONE]",
|
|
]
|
|
|
|
|
|
def test_cancel_during_streaming_stops_iteration_promptly():
|
|
# Setting cancel_event between yields breaks out on the next iteration
|
|
# rather than draining the stub generator.
|
|
ev = threading.Event()
|
|
|
|
async def _run():
|
|
gen = _post_fix_gguf_loop(ev)
|
|
seen = []
|
|
async for ch in gen:
|
|
seen.append(ch)
|
|
if ch == "cumulative-1":
|
|
ev.set()
|
|
return seen
|
|
|
|
seen = asyncio.run(_run())
|
|
assert "first_chunk" in seen
|
|
assert "cumulative-1" in seen
|
|
assert "cumulative-2" not in seen
|
|
assert "final_chunk" in seen
|
|
assert "[DONE]" in seen
|
|
|
|
|
|
# ── Cancel-event responsiveness in the streaming loops ───────
|
|
|
|
|
|
def _loop_has_cancel_event_check(fn) -> bool:
|
|
# An `if cancel_event.is_set():` statement anywhere inside a
|
|
# `while`/`for` loop body is sufficient -- without it, a cancel POST
|
|
# cannot interrupt the loop because Colab-style proxies do not
|
|
# propagate request.is_disconnected().
|
|
for sub in ast.walk(fn):
|
|
if not isinstance(sub, (ast.While, ast.For, ast.AsyncFor)):
|
|
continue
|
|
for stmt in ast.walk(sub):
|
|
if not isinstance(stmt, ast.If):
|
|
continue
|
|
t = stmt.test
|
|
if (
|
|
isinstance(t, ast.Call)
|
|
and isinstance(t.func, ast.Attribute)
|
|
and t.func.attr == "is_set"
|
|
and isinstance(t.func.value, ast.Name)
|
|
and t.func.value.id == "cancel_event"
|
|
):
|
|
return True
|
|
return False
|
|
|
|
|
|
def test_streaming_generators_check_cancel_event_in_loop():
|
|
required = {
|
|
"gguf_tool_stream",
|
|
"gguf_stream_chunks",
|
|
"stream_chunks",
|
|
"audio_input_stream",
|
|
}
|
|
missing = []
|
|
for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
|
|
if fn.name not in required:
|
|
continue
|
|
if not _loop_has_cancel_event_check(fn):
|
|
missing.append(fn.name)
|
|
assert not missing, (
|
|
f"Each streaming generator must check `cancel_event.is_set()` inside "
|
|
f"its main loop so `POST /api/inference/cancel` can interrupt the "
|
|
f"stream through proxies that do not forward fetch aborts. "
|
|
f"Missing in: {sorted(missing)}"
|
|
)
|
|
|
|
|
|
def test_audio_input_stream_offloads_blocking_next_to_thread():
|
|
# Guards against regression back to `for chunk_text in
|
|
# audio_input_generate():` -- which blocks the event loop on each
|
|
# whisper chunk and prevents POST /api/inference/cancel from being
|
|
# serviced until the chunk yields.
|
|
audio = None
|
|
for fn in ast.walk(_TREE):
|
|
if isinstance(fn, ast.AsyncFunctionDef) and fn.name == "audio_input_stream":
|
|
audio = fn
|
|
break
|
|
assert audio is not None, "audio_input_stream generator missing"
|
|
|
|
for sub in ast.walk(audio):
|
|
if isinstance(sub, (ast.For, ast.AsyncFor)):
|
|
it_src = ast.unparse(sub.iter)
|
|
assert "audio_input_generate" not in it_src, (
|
|
"audio_input_stream must not iterate audio_input_generate() "
|
|
"directly -- that blocks the event loop. Use "
|
|
"`await asyncio.to_thread(next, gen, _DONE)` inside a "
|
|
"`while True` loop instead"
|
|
)
|
|
|
|
found_to_thread_next = False
|
|
for sub in ast.walk(audio):
|
|
if not isinstance(sub, ast.Call):
|
|
continue
|
|
fn_expr = sub.func
|
|
if not (
|
|
isinstance(fn_expr, ast.Attribute)
|
|
and fn_expr.attr == "to_thread"
|
|
and isinstance(fn_expr.value, ast.Name)
|
|
and fn_expr.value.id == "asyncio"
|
|
):
|
|
continue
|
|
if sub.args and isinstance(sub.args[0], ast.Name) and sub.args[0].id == "next":
|
|
found_to_thread_next = True
|
|
break
|
|
assert found_to_thread_next, (
|
|
"audio_input_stream must call `asyncio.to_thread(next, gen, ...)` "
|
|
"to keep the event loop free while whisper yields the next chunk"
|
|
)
|
|
|
|
|
|
def test_stream_chunks_cancel_branch_resets_backend_state():
|
|
# The Unsloth path's cancel branch must flush GPU / KV-cache state
|
|
# via `backend.reset_generation_state()` -- the orchestrator's
|
|
# internal cancel path does not do this, so a cancel-via-POST that
|
|
# only broke the loop would leave the subprocess in a dirty state
|
|
# for the next request.
|
|
fn = None
|
|
top = None
|
|
for n in ast.walk(_TREE):
|
|
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
|
|
top = n
|
|
break
|
|
assert top is not None
|
|
for n in ast.walk(top):
|
|
if isinstance(n, ast.AsyncFunctionDef) and n.name == "stream_chunks":
|
|
fn = n
|
|
break
|
|
assert fn is not None, "stream_chunks generator missing"
|
|
|
|
for sub in ast.walk(fn):
|
|
if not isinstance(sub, ast.If):
|
|
continue
|
|
t = sub.test
|
|
if not (
|
|
isinstance(t, ast.Call)
|
|
and isinstance(t.func, ast.Attribute)
|
|
and t.func.attr == "is_set"
|
|
and isinstance(t.func.value, ast.Name)
|
|
and t.func.value.id == "cancel_event"
|
|
):
|
|
continue
|
|
body_src = "\n".join(ast.unparse(s) for s in sub.body)
|
|
if "backend.reset_generation_state()" in body_src:
|
|
return
|
|
raise AssertionError(
|
|
"stream_chunks `if cancel_event.is_set():` branch must call "
|
|
"backend.reset_generation_state() -- matches the existing "
|
|
"request.is_disconnected() / CancelledError cleanup paths and "
|
|
"prevents KV-cache drift after cancel-via-POST"
|
|
)
|
|
|
|
|
|
# ── Behavioral simulations for the iter-1 fixes ──────────────
|
|
|
|
|
|
def test_unsloth_stream_loop_breaks_on_external_cancel_event():
|
|
cancel_event = threading.Event()
|
|
reset_calls = [0]
|
|
|
|
class _Backend:
|
|
def reset_generation_state(self):
|
|
reset_calls[0] += 1
|
|
|
|
backend = _Backend()
|
|
|
|
def _generate():
|
|
for i in range(200):
|
|
time.sleep(0.005)
|
|
yield f"cum-{i}"
|
|
|
|
async def _loop():
|
|
_DONE = object()
|
|
loop = asyncio.get_event_loop()
|
|
gen = _generate()
|
|
seen = []
|
|
while True:
|
|
if cancel_event.is_set():
|
|
backend.reset_generation_state()
|
|
break
|
|
cumulative = await loop.run_in_executor(None, next, gen, _DONE)
|
|
if cumulative is _DONE:
|
|
break
|
|
seen.append(cumulative)
|
|
return seen
|
|
|
|
async def _fire():
|
|
await asyncio.sleep(0.05)
|
|
cancel_event.set()
|
|
|
|
async def _main():
|
|
return await asyncio.gather(_loop(), _fire())
|
|
|
|
seen, _ = asyncio.run(_main())
|
|
assert (
|
|
len(seen) < 200
|
|
), f"loop must not drain the generator after cancel; got {len(seen)} tokens"
|
|
assert reset_calls[0] == 1, (
|
|
f"backend.reset_generation_state() must be called exactly once on "
|
|
f"cancel-via-POST, got {reset_calls[0]}"
|
|
)
|
|
|
|
|
|
def test_audio_stream_stays_responsive_under_blocking_next():
|
|
# Regression guard: replace the post-fix loop with the pre-fix
|
|
# `for chunk in audio_input_generate()` pattern and assert it blocks
|
|
# the event loop; then confirm the post-fix pattern exits promptly.
|
|
cancel_event = threading.Event()
|
|
|
|
def _audio_gen():
|
|
for i in range(8):
|
|
time.sleep(0.15)
|
|
yield f"chunk-{i}"
|
|
|
|
async def _prefix_loop():
|
|
seen = []
|
|
for chunk_text in _audio_gen():
|
|
if cancel_event.is_set():
|
|
break
|
|
seen.append(chunk_text)
|
|
return seen
|
|
|
|
async def _postfix_loop():
|
|
_DONE = object()
|
|
gen = _audio_gen()
|
|
seen = []
|
|
while True:
|
|
if cancel_event.is_set():
|
|
break
|
|
chunk_text = await asyncio.to_thread(next, gen, _DONE)
|
|
if chunk_text is _DONE:
|
|
break
|
|
seen.append(chunk_text)
|
|
return seen
|
|
|
|
async def _fire_early():
|
|
await asyncio.sleep(0.05)
|
|
cancel_event.set()
|
|
|
|
async def _run(loop_coro):
|
|
return await asyncio.gather(loop_coro, _fire_early())
|
|
|
|
cancel_event.clear()
|
|
t0 = time.monotonic()
|
|
prefix_seen, _ = asyncio.run(_run(_prefix_loop()))
|
|
prefix_elapsed = time.monotonic() - t0
|
|
assert prefix_elapsed >= 0.13, (
|
|
f"pre-fix pattern should block event loop for >=1 chunk time "
|
|
f"(~150ms); got {prefix_elapsed:.3f}s, {len(prefix_seen)} chunks"
|
|
)
|
|
|
|
cancel_event.clear()
|
|
t0 = time.monotonic()
|
|
postfix_seen, _ = asyncio.run(_run(_postfix_loop()))
|
|
postfix_elapsed = time.monotonic() - t0
|
|
assert postfix_elapsed < prefix_elapsed, (
|
|
f"post-fix pattern must exit faster than pre-fix (blocking) "
|
|
f"pattern; post={postfix_elapsed:.3f}s vs pre={prefix_elapsed:.3f}s"
|
|
)
|
|
assert (
|
|
len(postfix_seen) < 8
|
|
), f"post-fix loop must not drain all chunks; got {len(postfix_seen)}"
|
|
|
|
|
|
def test_unsloth_stream_loop_emits_zero_tokens_on_preset_cancel():
|
|
# Pending-cancel replay: _TrackedCancel.__enter__ already set
|
|
# cancel_event before the generator body starts iterating. The
|
|
# top-of-loop check must short-circuit the very first iteration so
|
|
# no token is emitted. Catches a regression that moves the check
|
|
# below `next()` -- the mid-loop test would still pass but this
|
|
# test would observe one extra token leak.
|
|
cancel_event = threading.Event()
|
|
cancel_event.set()
|
|
reset_calls = [0]
|
|
|
|
class _Backend:
|
|
def reset_generation_state(self):
|
|
reset_calls[0] += 1
|
|
|
|
backend = _Backend()
|
|
|
|
next_calls = [0]
|
|
|
|
def _generate():
|
|
while True:
|
|
next_calls[0] += 1
|
|
yield f"cum-{next_calls[0]}"
|
|
|
|
async def _loop():
|
|
_DONE = object()
|
|
loop = asyncio.get_event_loop()
|
|
gen = _generate()
|
|
seen = []
|
|
while True:
|
|
if cancel_event.is_set():
|
|
backend.reset_generation_state()
|
|
break
|
|
cumulative = await loop.run_in_executor(None, next, gen, _DONE)
|
|
if cumulative is _DONE:
|
|
break
|
|
seen.append(cumulative)
|
|
return seen
|
|
|
|
seen = asyncio.run(_loop())
|
|
assert seen == [], (
|
|
f"loop must emit zero tokens when cancel_event is pre-set "
|
|
f"(pending-replay path); got {seen}"
|
|
)
|
|
assert next_calls[0] == 0, (
|
|
f"loop must not call next() at all on pre-set cancel; got "
|
|
f"{next_calls[0]} calls"
|
|
)
|
|
assert reset_calls[0] == 1, (
|
|
f"backend.reset_generation_state() must still fire exactly once "
|
|
f"on pre-set cancel; got {reset_calls[0]}"
|
|
)
|
|
|
|
|
|
def test_audio_stream_emits_zero_chunks_on_preset_cancel():
|
|
# Symmetric to the Unsloth pre-set test: the audio loop's top-of-loop
|
|
# cancel check must skip the asyncio.to_thread(next, ...) call when
|
|
# cancel_event was already set via pending-replay.
|
|
cancel_event = threading.Event()
|
|
cancel_event.set()
|
|
|
|
next_calls = [0]
|
|
|
|
def _audio_gen():
|
|
while True:
|
|
next_calls[0] += 1
|
|
yield f"chunk-{next_calls[0]}"
|
|
|
|
async def _loop():
|
|
_DONE = object()
|
|
gen = _audio_gen()
|
|
seen = []
|
|
while True:
|
|
if cancel_event.is_set():
|
|
break
|
|
chunk_text = await asyncio.to_thread(next, gen, _DONE)
|
|
if chunk_text is _DONE:
|
|
break
|
|
seen.append(chunk_text)
|
|
return seen
|
|
|
|
seen = asyncio.run(_loop())
|
|
assert seen == [], f"audio loop must emit zero chunks on pre-set cancel; got {seen}"
|
|
assert next_calls[0] == 0, (
|
|
f"audio loop must not call next() on pre-set cancel; got "
|
|
f"{next_calls[0]} calls"
|
|
)
|