Studio: make stop button actually stop generation (#5069)

* Studio: make stop button actually stop generation

The UI stop button routes through assistant-ui's cancelRun, which aborts
the frontend fetch. Four issues combined to let llama-server keep decoding
long after the user clicked stop:

1. request.is_disconnected() does not fire reliably behind proxies
   (e.g. Colab) that don't propagate fetch aborts.
2. llama-server defaults n_predict to n_ctx when max_tokens is not sent,
   so a cancelled request keeps producing tokens up to 262144.
3. The httpx.Client pool keeps TCP keep-alive, so even a cleanly closed
   stream reuses the same connection and llama-server's liveness poll
   never sees a disconnect.
4. No explicit backend route to cancel - every cancel path relied on
   is_disconnected.

Changes:
- Add POST /api/inference/cancel keyed by session_id/completion_id, with
  a registry populated for the lifetime of each streaming response.
- Have the frontend (chat-adapter.ts) POST /inference/cancel on
  AbortController abort, alongside the existing fetch teardown.
- Send max_tokens=4096 + t_max_predict_ms=120000 as defaults on every
  outbound chat completion to llama-server; explicit user values still
  override the defaults.
- Disable httpx keep-alive on the streaming client so connection close
  reaches llama-server and its 1s liveness check fires.

No behaviour changes for non-streaming paths or for existing callers
that already pass max_tokens/session_id.
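
A minimal sketch of the registry/route pairing described above
(simplified; the shipped code is in the routes/inference.py diff further
down, and later commits extend it with cancel_id and a pending-cancel
stash):

    import threading

    _CANCEL_REGISTRY: dict[str, set[threading.Event]] = {}
    _CANCEL_LOCK = threading.Lock()

    def register(event: threading.Event, *keys: str) -> None:
        # Held for the lifetime of a streaming response, keyed by
        # session_id / completion_id (later also cancel_id).
        with _CANCEL_LOCK:
            for k in keys:
                if k:
                    _CANCEL_REGISTRY.setdefault(k, set()).add(event)

    def cancel(*keys: str) -> int:
        # Body of POST /api/inference/cancel: signal every in-flight
        # generation registered under any of the supplied keys.
        hit: set[threading.Event] = set()
        with _CANCEL_LOCK:
            for k in keys:
                hit.update(_CANCEL_REGISTRY.get(k, set()))
        for ev in hit:
            ev.set()
        return len(hit)

    # Streaming loops poll the event instead of relying solely on
    # request.is_disconnected().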

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: harden stop-button cancel path and scope cancel route

- Require at least one identifier for /api/inference/cancel so a missing
  thread id cannot silently cancel every in-flight generation.
- Scope /cancel to a dedicated studio_router so it is not exposed under
  the /v1 OpenAI-compat prefix as a surprise endpoint.
- Store a set of cancel events per key in _CANCEL_REGISTRY so concurrent
  requests on the same session_id do not overwrite each other, and
  deduplicate in _cancel_by_keys so the cancelled count reflects unique
  requests.
- Always send session_id with chat completions (not only when tools are
  enabled) so non-tool GGUF streams register under it and are reachable
  from /cancel.
- Register the non-GGUF stream_chunks path in the cancel registry too,
  so transformers-based stop-button works behind proxies that swallow
  fetch aborts.
- Only apply the 2-minute t_max_predict_ms wall-clock cap when the
  caller did not pass max_tokens, so legitimate long generations on
  slow CPU, macOS, and Windows installs are not silently truncated.
- Remove the abort listener on normal stream completion so reused
  AbortSignals cannot fire a spurious cancel POST after the fact.

* studio: close cancel-race and stale-cancel gaps in stop path

- Register the cancel tracker before returning StreamingResponse so a
  stop POST that arrives during prefill / warmup / proxy buffering
  finds an entry in _CANCEL_REGISTRY. Cleanup now runs via a Starlette
  BackgroundTask instead of a finally inside the async generator body.
- Add a per-run cancel_id on the frontend (crypto.randomUUID) and in
  ChatCompletionRequest so /api/inference/cancel matches one specific
  generation. Removes the stale-cancel bug where pressing stop then
  starting a new run in the same thread would cancel the retry.
- Apply t_max_predict_ms unconditionally in all three llama-server
  payload builders (previously gated on max_tokens=None, which made it
  dead code for UI callers that always send params.maxTokens). Raise
  the default to 10 minutes so slow CPU / macOS / Windows installs are
  not cut off mid-generation.
- Make _cancel_by_keys refuse empty input (return 0) so a future
  internal caller cannot accidentally mass-cancel every in-flight
  request.
- Accept cancel_id (primary), session_id, and completion_id on the
  /api/inference/cancel route. Unify the three streaming sites on the
  same _cancel_keys / _tracker variable names.
- Annotate _CANCEL_REGISTRY as dict[str, set[threading.Event]].

* Add review tests for PR #5069

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: harden stop-button cancel semantics and wall-clock cap

- Make /inference/cancel match cancel_id EXCLUSIVELY when supplied.
  Previously the handler iterated ('cancel_id','session_id','completion_id')
  and unioned matches, so a stale cancel POST carrying {cancel_id:old,
  session_id:thr} would still cancel a later run on the same thread via
  the shared session_id. cancel_id is now a per-run exclusive key;
  session_id / completion_id are only used as fallbacks when cancel_id
  is absent.

- Close the early-cancel race. If /inference/cancel lands before the
  streaming handler reaches _TrackedCancel.__enter__() (stop clicked
  during prefill / warmup / proxy buffering), the cancel was silently
  dropped. Stash unmatched cancel_ids in _PENDING_CANCELS with a 30 s
  TTL; _TrackedCancel.__enter__() now replays any matching pending
  cancel by set()-ing the event immediately after registration.

- Make t_max_predict_ms = _DEFAULT_T_MAX_PREDICT_MS conditional on
  `max_tokens is None` at all three llama-server payload sites. The cap
  is a safety net for callers who leave max_tokens unset (otherwise
  llama-server defaults n_predict to n_ctx, up to 262144). Callers who
  set an explicit max_tokens are already self-limiting; their legitimate
  long generations on slow CPU / macOS / Windows must not be silently
  truncated at 10 minutes.

- Guard each StreamingResponse return with try/except BaseException so
  _tracker.__exit__ runs even if StreamingResponse construction or any
  preceding statement raises between _tracker.__enter__() and the
  BackgroundTask attachment. Prevents a registry leak on that narrow
  window.
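
Condensed sketch of the resulting /inference/cancel dispatch (helper
names follow the final diff further down; bodies simplified):

    def dispatch_cancel(body: dict) -> int:
        cancel_id = body.get("cancel_id")
        if isinstance(cancel_id, str) and cancel_id:
            # Exclusive match: a stale POST carrying {cancel_id: old,
            # session_id: thr} can no longer cancel a later run on the
            # same thread via the shared session_id.
            return _cancel_by_cancel_id_or_stash(cancel_id)
        keys = []
        for k in ("completion_id", "session_id"):  # fallbacks only
            v = body.get(k)
            if isinstance(v, str) and v:
                keys.append(v)
        return _cancel_by_keys(keys) if keys else 0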

* studio: close TOCTOU race and restore wall-clock backstop on UI path

- Close TOCTOU race in the pending-cancel mechanism. The previous fix
  split cancel_inference's (cancel_by_keys + remember_pending_cancel)
  and _TrackedCancel.__enter__'s (register + consume_pending) into
  four separate lock acquisitions. Under contention a cancel POST
  could acquire and release the lock, find the registry empty, and
  stash its pending entry only after __enter__ had already registered
  and consumed an empty pending map -- silently dropping the cancel.
  Both call sites
  now do their work inside a single _CANCEL_LOCK critical section, via
  the new atomic helper _cancel_by_cancel_id_or_stash() and an
  inlined consume-pending step in __enter__. Reproduced the race under
  forced interleaving pre-fix; 0/2000 drops post-fix under parallel
  stress.

- Apply t_max_predict_ms UNCONDITIONALLY at all three llama-server
  payload sites. The previous iteration gated the cap on
  `max_tokens is None`, which turned out to be dead code on the
  primary Studio UI path: chat-adapter.ts sets
  maxTokens=loadResp.context_length after every model load, so every
  chat request carries an explicit max_tokens and the wall-clock
  safety net never fired. The cap's original purpose is to bound
  stuck decodes regardless of the token budget; it must always apply.

- Raise _DEFAULT_T_MAX_PREDICT_MS from 10 minutes to 1 hour. 10
  minutes was too aggressive for legitimate slow-CPU chat responses
  (a 4096-token reply at 2 tok/s takes ~34 min); 1 hour accommodates
  that and still catches genuine zombie decodes.

- Prune _PENDING_CANCELS inside _cancel_by_keys as well, so stashed
  entries expire proportionally to overall cancel traffic rather than
  only to cancel_id-specific POSTs.
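
Condensed sketch of the single-critical-section pattern, with
illustrative names (the shipped helpers are _cancel_by_cancel_id_or_stash
and _TrackedCancel.__enter__ in the diff below):

    import threading, time

    _LOCK = threading.Lock()
    _REGISTRY: dict[str, set[threading.Event]] = {}
    _PENDING: dict[str, float] = {}  # cancel_id -> stash timestamp

    def cancel_or_stash(cancel_id: str) -> int:
        # One acquisition covers lookup AND stash, so a concurrent
        # registration can never slip between them.
        with _LOCK:
            events = set(_REGISTRY.get(cancel_id, ()))
            if not events:
                _PENDING[cancel_id] = time.monotonic()
        for ev in events:
            ev.set()
        return len(events)

    def register(cancel_id: str, event: threading.Event) -> None:
        # Register and consume any pending cancel in the same critical
        # section; a stashed early cancel is replayed immediately.
        with _LOCK:
            _REGISTRY.setdefault(cancel_id, set()).add(event)
            replay = _PENDING.pop(cancel_id, None) is not None
        if replay:
            event.set()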

* studio: trim verbose comments and docstrings in cancel path

* studio/llama_cpp: drop upstream PR hashes from benchmark comment

* Add review tests for Studio stop button

* Consolidate review tests for Studio stop button

* Align cancel-route test with exclusive cancel_id semantics

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: move cancel cleanup to generator finally; drop dead helper

- Move _tracker.__exit__ from Starlette BackgroundTask into each
  streaming generator's finally block. Starlette skips the background
  callback when stream_response raises (OSError / ClientDisconnect),
  which leaked _CANCEL_REGISTRY entries on abrupt disconnect.
- Check cancel_event.is_set() at the top of each GGUF while loop so a
  pending-replay cancel falls through to final_chunk + [DONE] instead
  of propagating GeneratorExit out of _stream_with_retry.
- Remove unused _remember_pending_cancel; _cancel_by_cancel_id_or_stash
  superseded it.
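
Schematic of the finally-based cleanup shape (illustrative names; the
shipped generators are gguf_stream_chunks / gguf_tool_stream /
stream_chunks):

    from fastapi.responses import StreamingResponse

    def make_stream(tracker, cancel_event, chunks):
        tracker.__enter__()  # register before any prefill/warmup work

        async def stream():
            try:
                for chunk in chunks:
                    if cancel_event.is_set():  # pending-replay cancels exit cleanly
                        break
                    yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"
            finally:
                # Runs on normal completion, exceptions, and aclose() /
                # GeneratorExit alike -- unlike a Starlette BackgroundTask,
                # which is skipped when the body raises mid-stream.
                tracker.__exit__(None, None, None)

        return StreamingResponse(stream(), media_type="text/event-stream")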

* Add review tests for Studio stop-button

* studio: wire audio-input stream into cancel registry

- Register cancel_event with _TrackedCancel on the audio-input streaming
  path so POST /api/inference/cancel can stop whisper / audio-input GGUF
  runs. Previously the registry stayed empty on this branch, so the stop
  button returned {"cancelled":0} and the decode ran to completion.
- Apply the same finally-based cleanup and pre-iteration cancel-event
  check used on the other three streaming paths.
- Update the _CANCEL_REGISTRY block comment to list cancel_id as the
  primary key (was stale "session_id preferred").

* Consolidate review tests for Studio stop-button cancel flow

- Merge the 6 behavioral tests from test_stream_cleanup_on_disconnect.py
  (finally cleanup on normal/exception/aclose, pre-set cancel_event
  pattern, and its regressions) into test_stream_cancel_registration_timing.py,
  which is the PR's existing file covering the same area.
- Extend structural invariants to include audio_input_stream alongside the
  three GGUF / Unsloth streaming generators: no _tracker.__enter__ inside
  the async gen body, cleanup via try/finally, no background= on
  StreamingResponse.
- Delete test_stream_cleanup_on_disconnect.py (now empty).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: make cancel-via-POST interrupt Unsloth and audio-input streams

Close two remaining gaps in the stop-button cancellation wiring:

- stream_chunks (Unsloth path): add a top-of-loop cancel_event check and
  call backend.reset_generation_state() so cancel POSTs flush GPU state
  and close the SSE cleanly instead of relying on request.is_disconnected
  (which does not fire through proxies like Colab's).
- audio_input_stream: run the synchronous audio_input_generate() via
  asyncio.to_thread so blocking whisper chunks do not freeze the event
  loop, matching the pattern already used by the GGUF streaming paths.
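
Sketch of the blocking-generator offload pattern (illustrative function
name; mirrors the audio_input_stream change in the diff below):

    import asyncio

    async def drain_blocking_generator(gen, cancel_event):
        _DONE = object()  # sentinel: StopIteration cannot cross an asyncio future
        while True:
            if cancel_event.is_set():
                break
            # next() may block for seconds on a whisper chunk; to_thread
            # keeps the event loop (and the /cancel route) responsive.
            chunk = await asyncio.to_thread(next, gen, _DONE)
            if chunk is _DONE:
                break
            yield chunk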

* Add review tests for Studio stop-button cancel flow

* Consolidate review tests for Studio stop-button cancel flow

- Delete standalone test_cancel_registry.py at repo root: its tests
  duplicated test_cancel_atomicity.py / test_cancel_id_wiring.py and
  re-implemented registry primitives inline (scaffolding).
- Extend tests/studio/test_stream_cancel_registration_timing.py with
  regression guards for the iter-1 cancel-loop fixes:
    structural: each streaming generator checks cancel_event in its loop;
                audio_input_stream offloads next() via asyncio.to_thread;
                stream_chunks cancel branch calls reset_generation_state().
    runtime:    Unsloth loop breaks on external cancel and resets state;
                audio loop stays responsive under blocking next();
                both loops emit zero tokens on pre-set cancel (replay path).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* studio: extend stop-path to passthrough streams; tighten wall-clock cap

- Lower _DEFAULT_T_MAX_PREDICT_MS from 1 hour to 10 minutes so the
  wall-clock backstop actually bounds runaway decodes when cancel
  signaling fails.
- Wire _TrackedCancel and cancel_event.is_set() into
  _openai_passthrough_stream and _anthropic_passthrough_stream and
  disable httpx keepalive so stop requests from /v1 and /v1/messages
  tool-calling clients reach llama-server.
- Apply t_max_predict_ms to the tool-passthrough request body so the
  backstop covers passthrough paths as well.
- Symmetric pre-registration stash for session_id/completion_id
  cancels (_cancel_by_keys_or_stash) so early cancels by those keys
  replay on later registration like cancel_id.
- Drop dead except BaseException guards around StreamingResponse()
  at four streaming sites; cleanup lives in the generator's finally.

* studio: harden cancel registry against ghost-cancel and leak paths

- Revert the session_id/completion_id stash in the fallback cancel
  helper. session_id is thread-scoped and reused across runs, so
  stashing it on an unmatched POST would fire cancel_event for the
  user's next unrelated request via _TrackedCancel.__enter__.
  cancel_id remains the only per-run unique key that gets stashed.
- Default max_tokens to _DEFAULT_MAX_TOKENS in the tool-passthrough
  body. Mirror the direct GGUF path so OpenAI/Anthropic passthrough
  callers who omit max_tokens get the same zombie-decode cap instead
  of relying on the wall-clock backstop alone.
- Wrap _openai_passthrough_stream setup with an outer try/except
  BaseException. The inner except httpx.RequestError does not catch
  asyncio.CancelledError at await client.send, which would otherwise
  leave _tracker registered in _CANCEL_REGISTRY indefinitely.
- Frontend stop POST uses plain fetch + manual Authorization header
  instead of authFetch. A 401 on the cancel POST no longer refreshes
  tokens or redirects the user to the login page mid-stop.

* Add review tests for Studio stop-button cancel flow

* studio: trim comments on stop-button review changes

Collapse multi-paragraph rationale blocks on the cancel registry,
_openai_passthrough_stream, and the frontend onAbortCancel handler
into one-line explanations of why the non-obvious behaviour exists.
Drop authFetch import that became unused when the cancel POST
switched to plain fetch.

* Consolidate review tests for Studio stop-button cancel flow

Move review-added tests out of test_cancel_dispatch_edges.py into the
existing PR test files that already cover the same areas:
- backend registry fan-out / exclusivity / idempotency / falsy-keys
  edge cases moved into tests/studio/test_cancel_atomicity.py
- frontend plain-fetch (not authFetch) + manual Authorization header
  moved into tests/studio/test_cancel_id_wiring.py
Delete the now-empty test_cancel_dispatch_edges.py.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Studio: stop default-capping responses at 4096 tokens (follow-up to #5069) (#5174)

* Studio: stop default-capping responses at 4096 tokens

Follow-up to #5069. The 4096 default introduced for runaway-decode
defense silently truncates any caller that omits max_tokens. The
Studio chat UI sets params.maxTokens = loadResp.context_length after
a GGUF load, so it's fine, but every other consumer is not:

- OpenAI-API direct callers (/v1/chat/completions, /v1/responses,
  /v1/messages, /v1/completions) where the OpenAI default is
  effectively unlimited per response. langchain, llama-index, raw
  curl, and the openai SDK all rely on that.
- Reasoning models. Qwen3 / gpt-oss reasoning traces routinely exceed
  4096 tokens before the model emits a single visible content token.
  The user sees the trace cut off mid-thought.
- Long-form generation ("write a chapter", "produce a full SVG").

Reproduced on this branch: gemma-4-E2B-it-GGUF Q8_0, prompt asking
for a 10000-word story, no max_tokens in the request:

    finish_reason: stop  (misleading -- should be 'length')
    content_chars: 19772
    content_tail: ...'a comforting, yet immense, pressure.\n\n*"'

Body ended mid-sentence on a stray opening quote, right at the 4096
token mark.

After this patch the same request returns 38357 chars ending with
'...held in a perfect, dynamic equilibrium.' -- a natural stop, not
a truncation.

Implementation: rename the constant to _DEFAULT_MAX_TOKENS_FLOOR and
set it to 32768. Each call site now uses the model's effective
context length when known, falling back to the floor:

    default_cap = self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR

The 10-minute t_max_predict_ms wall-clock backstop from #5069 is
preserved as the second line of defense.

Plumbed _build_passthrough_payload + _build_openai_passthrough_body
through the routes layer so the Anthropic and OpenAI passthrough
paths also respect the model's context length.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Studio: cancel passthrough streams during llama-server prefill + route through apiUrl for Tauri

Three reviewer-flagged correctness gaps in the stop-button mechanism.

1) `_openai_passthrough_stream` could not honor cancel during prefill.
   The cancel check ran inside the `async for raw_line in lines_iter`
   body, so a cancel POST that arrived before llama-server emitted the
   first SSE line was unobservable until prefill completed. With a long
   prompt under proxy/Colab conditions -- the exact target scenario for
   this PR -- that left the model decoding for a long time after the
   user clicked Stop. Add an asyncio watcher task that closes `resp` as
   soon as `cancel_event` is set, raising in `aiter_lines` so the
   generator can exit. The watcher polls a threading.Event because the
   cancel registry is keyed by threading.Event for the synchronous
   /cancel handler.

2) `_anthropic_passthrough_stream` had the same blocking-prefill pattern.
   Same fix.

3) The frontend's stop-button cancel POST used a bare relative
   `fetch("/api/inference/cancel", ...)`, which targets the webview
   origin in Tauri production builds (where the backend is at
   `http://127.0.0.1:8888`). Route through the existing `apiUrl()`
   helper from `lib/api-base.ts` to match every other Studio call.
   Browser/dev builds get the empty base, so behavior is unchanged
   there.

Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: cancel
during prefill terminates within ~250ms on both passthrough paths
(was 145s+ on the Anthropic path before this change), and the standard
non-passthrough chat path still cancels with no regression.

* Studio: log cancel-body parse errors instead of silently swallowing

Reviewer-flagged defensive logging gap. The bare `except Exception: pass`
in `cancel_inference` would mask malformed payloads that hint at a buggy
client or a transport issue. Log at debug so future investigation isn't
left guessing whether `body={}` came from a missing body or a parse
failure. Behavior is unchanged: an unparseable body still falls through
to the empty-dict path and the cancel call returns `{"cancelled": 0}`.

* Studio: Anthropic passthrough cancel parity with OpenAI passthrough

Two reviewer-flagged consistency gaps in the cancel surface for
/v1/messages.

1) Anthropic passthrough did not register cancel_id, so a per-run cancel
   POST (the cleanest Studio-style cancel path) silently missed when
   the route hit `_anthropic_passthrough_stream`. The OpenAI passthrough
   has registered (cancel_id, session_id, completion_id) since this PR
   was first opened; mirror that here. Also add `cancel_id` to
   `AnthropicMessagesRequest` so the route handler can plumb it through.

2) The cancel handler's fallback key list checked only completion_id
   and session_id, never message_id. Anthropic clients that send their
   native `id` (returned in the SSE message_start event) for cancel had
   no way to hit the registry. Add message_id to the fallback list.

Verified via temp/pr_simulation/sim_5069_prefill_cancel.py: P2 now
cancels by cancel_id in 137ms (was hanging pre-fix), and the new P2b
case cancels by message_id in 77ms. P1 (OpenAI) and P3 (standard chat)
still pass with no regression.

---------

Co-authored-by: danielhanchen <michaelhan2050@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Daniel Han 2026-04-24 10:09:25 -07:00 committed by GitHub
parent 8264e80dd9
commit eb8b0dee2e
11 changed files with 1785 additions and 119 deletions

View file

@ -53,6 +53,21 @@ _INTENT_SIGNAL = re.compile(
r")"
)
_MAX_REPROMPTS = 3
# Without max_tokens, llama-server defaults to n_predict = n_ctx (up to
# 262144 for Qwen3.5), producing many-minute zombie decodes when cancel
# fails. t_max_predict_ms is a wall-clock backstop applied unconditionally,
# but the llama.cpp README notes it ONLY fires after a newline has been
# generated -- a model stuck in a long unbroken non-newline sequence is
# unbounded by it. So we still want a token cap as the front-line limiter.
#
# The cap is the model's effective context length when we know it,
# falling back to a generous floor when metadata is unavailable. 4096 was
# too low: Qwen3 / gpt-oss reasoning traces routinely exceed it, and any
# OpenAI-API caller that omits max_tokens (langchain, llama-index, raw
# curl) sees responses silently truncated mid-sentence.
_DEFAULT_MAX_TOKENS_FLOOR = 32768
_DEFAULT_T_MAX_PREDICT_MS = 600_000 # 10 min
_REPROMPT_MAX_CHARS = 2000
# ── Pre-compiled patterns for GGUF shard detection ───────────
@ -1636,7 +1651,7 @@ class LlamaCppBackend:
# existing text (code refactoring, summarization, reasoning).
# For general chat with low repetition, overhead is ~5 ms.
#
# Benchmarks from llama.cpp PRs #18471, #19164:
# Benchmarks from upstream llama.cpp speculative-decoding PRs:
# Scenario | Without | With | Speedup
# gpt-oss-120b code refactor | 181 t/s | 446 t/s | 2.5x
# Qwen3-235B offloaded | 12 t/s | 21 t/s | 1.8x
@ -2549,8 +2564,15 @@ class LlamaCppBackend:
)
if _reasoning_kw is not None:
payload["chat_template_kwargs"] = _reasoning_kw
if max_tokens is not None:
payload["max_tokens"] = max_tokens
# Default cap to the model's effective context length when known,
# otherwise the conservative floor. The wall-clock backstop below
# keeps a stuck model from running indefinitely either way.
payload["max_tokens"] = (
max_tokens
if max_tokens is not None
else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR)
)
payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
if stop:
payload["stop"] = stop
payload["stream_options"] = {"include_usage": True}
@ -2570,7 +2592,9 @@ class LlamaCppBackend:
_auth_headers = (
{"Authorization": f"Bearer {self._api_key}"} if self._api_key else None
)
with httpx.Client(timeout = stream_timeout) as client:
with httpx.Client(
timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0)
) as client:
with self._stream_with_retry(
client,
url,
@ -2769,8 +2793,12 @@ class LlamaCppBackend:
)
if _reasoning_kw is not None:
payload["chat_template_kwargs"] = _reasoning_kw
if max_tokens is not None:
payload["max_tokens"] = max_tokens
payload["max_tokens"] = (
max_tokens
if max_tokens is not None
else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR)
)
payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
if stop:
payload["stop"] = stop
@ -2809,7 +2837,10 @@ class LlamaCppBackend:
write = 10,
pool = 10,
)
with httpx.Client(timeout = stream_timeout) as client:
with httpx.Client(
timeout = stream_timeout,
limits = httpx.Limits(max_keepalive_connections = 0),
) as client:
with self._stream_with_retry(
client,
url,
@ -3422,8 +3453,12 @@ class LlamaCppBackend:
)
if _reasoning_kw is not None:
stream_payload["chat_template_kwargs"] = _reasoning_kw
if max_tokens is not None:
stream_payload["max_tokens"] = max_tokens
stream_payload["max_tokens"] = (
max_tokens
if max_tokens is not None
else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR)
)
stream_payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
if stop:
stream_payload["stop"] = stop
stream_payload["stream_options"] = {"include_usage": True}
@ -3442,7 +3477,9 @@ class LlamaCppBackend:
_auth_headers = (
{"Authorization": f"Bearer {self._api_key}"} if self._api_key else None
)
with httpx.Client(timeout = stream_timeout) as client:
with httpx.Client(
timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0)
) as client:
with self._stream_with_retry(
client,
url,

View file

@ -62,6 +62,7 @@ from routes import (
datasets_router,
export_router,
inference_router,
inference_studio_router,
models_router,
training_history_router,
training_router,
@ -207,6 +208,9 @@ app.include_router(auth_router, prefix = "/api/auth", tags = ["auth"])
app.include_router(training_router, prefix = "/api/train", tags = ["training"])
app.include_router(models_router, prefix = "/api/models", tags = ["models"])
app.include_router(inference_router, prefix = "/api/inference", tags = ["inference"])
# Studio-only inference endpoints (cancel, etc.) are intentionally NOT
# exposed on the /v1 OpenAI-compat prefix below.
app.include_router(inference_studio_router, prefix = "/api/inference", tags = ["inference"])
# OpenAI-compatible endpoints: mount the same inference router at /v1
# so external tools (Open WebUI, SillyTavern, etc.) can use the

View file

@ -531,6 +531,10 @@ class ChatCompletionRequest(BaseModel):
None,
description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.",
)
cancel_id: Optional[str] = Field(
None,
description = "[x-unsloth] Per-request cancellation token. Frontend sends a fresh UUID per run so /inference/cancel matches one specific generation.",
)
# ── Streaming response chunks ────────────────────────────────────
@ -992,6 +996,7 @@ class AnthropicMessagesRequest(BaseModel):
enable_tools: Optional[bool] = None
enabled_tools: Optional[list[str]] = None
session_id: Optional[str] = None
cancel_id: Optional[str] = None
model_config = {"extra": "allow"}

View file

@ -8,6 +8,7 @@ API Routes
from routes.training import router as training_router
from routes.models import router as models_router
from routes.inference import router as inference_router
from routes.inference import studio_router as inference_studio_router
from routes.datasets import router as datasets_router
from routes.auth import router as auth_router
from routes.data_recipe import router as data_recipe_router
@ -18,6 +19,7 @@ __all__ = [
"training_router",
"models_router",
"inference_router",
"inference_studio_router",
"datasets_router",
"auth_router",
"data_recipe_router",

View file

@ -113,7 +113,12 @@ if str(backend_path) not in sys.path:
# Import backend functions
try:
from core.inference import get_inference_backend
from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags
from core.inference.llama_cpp import (
LlamaCppBackend,
_DEFAULT_MAX_TOKENS_FLOOR,
_DEFAULT_T_MAX_PREDICT_MS,
detect_reasoning_flags,
)
from utils.models import ModelConfig
from utils.inference import load_inference_config
from utils.models.model_config import load_model_defaults
@ -122,7 +127,12 @@ except ImportError:
if str(parent_backend) not in sys.path:
sys.path.insert(0, str(parent_backend))
from core.inference import get_inference_backend
from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags
from core.inference.llama_cpp import (
LlamaCppBackend,
_DEFAULT_MAX_TOKENS_FLOOR,
_DEFAULT_T_MAX_PREDICT_MS,
detect_reasoning_flags,
)
from utils.models import ModelConfig
from utils.inference import load_inference_config
from utils.models.model_config import load_model_defaults
@ -185,6 +195,126 @@ import numpy as np
from datetime import date as _date
router = APIRouter()
# Studio-only router (not mounted on /v1 OpenAI-compat).
studio_router = APIRouter()
# Cancel registry. Proxies (e.g. Colab) can swallow client fetch aborts
# so is_disconnected() never fires. POST /inference/cancel looks up
# in-flight cancel_events here by cancel_id (per-run) or session_id /
# completion_id (fallbacks).
_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {}
_CANCEL_LOCK = threading.Lock()
# Cancel POSTs that arrive before registration are stashed; the next
# matching __enter__ replays set() within the TTL.
_PENDING_CANCELS: dict[str, float] = {}
_PENDING_CANCEL_TTL_S = 30.0
def _prune_pending(now: float) -> None:
for k in [
k for k, ts in _PENDING_CANCELS.items() if now - ts > _PENDING_CANCEL_TTL_S
]:
_PENDING_CANCELS.pop(k, None)
class _TrackedCancel:
"""Register cancel_event in _CANCEL_REGISTRY for the block's duration."""
def __init__(self, event: threading.Event, *keys):
self.event = event
self.keys = tuple(k for k in keys if k)
def __enter__(self):
# Register + consume-pending must be one critical section to close
# the TOCTOU race against a concurrent cancel POST.
should_cancel = False
with _CANCEL_LOCK:
for k in self.keys:
_CANCEL_REGISTRY.setdefault(k, set()).add(self.event)
now = time.monotonic()
_prune_pending(now)
for k in self.keys:
if k and _PENDING_CANCELS.pop(k, None) is not None:
should_cancel = True
if should_cancel:
self.event.set()
return self.event
def __exit__(self, *exc):
with _CANCEL_LOCK:
for k in self.keys:
bucket = _CANCEL_REGISTRY.get(k)
if bucket is None:
continue
bucket.discard(self.event)
if not bucket:
_CANCEL_REGISTRY.pop(k, None)
return False
def _cancel_by_keys(keys) -> int:
"""Set cancel_event for matching registry entries; no stash.
session_id/completion_id are shared across runs on the same thread,
so stashing them would ghost-cancel the user's next request. Only
cancel_id is per-run unique (see _cancel_by_cancel_id_or_stash)."""
if not keys:
return 0
events: set[threading.Event] = set()
with _CANCEL_LOCK:
_prune_pending(time.monotonic())
for k in keys:
bucket = _CANCEL_REGISTRY.get(k)
if bucket:
events.update(bucket)
for ev in events:
ev.set()
return len(events)
def _cancel_by_cancel_id_or_stash(cancel_id: str) -> int:
"""Atomic lookup-or-stash; pairs with _TrackedCancel.__enter__ to
close the TOCTOU race."""
now = time.monotonic()
events: set[threading.Event] = set()
with _CANCEL_LOCK:
_prune_pending(now)
bucket = _CANCEL_REGISTRY.get(cancel_id)
if bucket:
events.update(bucket)
else:
_PENDING_CANCELS[cancel_id] = now
for ev in events:
ev.set()
return len(events)
async def _await_cancel_then_close(cancel_event, resp) -> None:
"""Watch a threading.Event from asyncio and close ``resp`` when it fires.
Used by the passthrough streamers so a /cancel POST can interrupt
while the async iterator is blocked waiting for llama-server prefill.
Without this watcher the in-loop ``cancel_event.is_set()`` check is
unreachable until the first SSE chunk arrives, which is exactly the
proxy/Colab scenario the cancel POST exists to handle.
Polls a threading.Event because the cancel registry is keyed by
threading.Event so the synchronous /cancel handler can call .set().
50ms cadence adds at most that much latency to a prefill cancel; the
common-case streaming cancel path still observes the event in the
iterator's first iteration after the next chunk.
"""
try:
while not cancel_event.is_set():
await asyncio.sleep(0.05)
try:
await resp.aclose()
except Exception:
pass
except asyncio.CancelledError:
return
# Appended to tool-use nudge to discourage plan-without-action
_TOOL_ACTION_NUDGE = (
@ -706,6 +836,48 @@ async def unload_model(
raise HTTPException(status_code = 500, detail = f"Failed to unload model: {str(e)}")
@studio_router.post("/cancel")
async def cancel_inference(
request: Request,
current_subject: str = Depends(get_current_subject),
):
"""Cancel in-flight inference requests.
Body (JSON, at least one key required):
cancel_id - preferred: per-run UUID, matched exclusively.
session_id - fallback when cancel_id is absent.
completion_id - fallback when cancel_id is absent.
A cancel_id arriving before its stream registers is stashed briefly
and replayed on registration. Returns {"cancelled": N}.
"""
try:
body = await request.json()
if not isinstance(body, dict):
body = {}
except Exception as e:
logger.debug("Failed to parse cancel request body: %s", e)
body = {}
cancel_id = body.get("cancel_id")
if isinstance(cancel_id, str) and cancel_id:
return {"cancelled": _cancel_by_cancel_id_or_stash(cancel_id)}
keys = []
# `message_id` is the Anthropic passthrough's per-run identifier --
# included so /v1/messages clients can cancel by their native id.
for k in ("completion_id", "session_id", "message_id"):
v = body.get(k)
if isinstance(v, str) and v:
keys.append(v)
if not keys:
return {"cancelled": 0}
n = _cancel_by_keys(keys)
return {"cancelled": n}
@router.post("/generate/stream")
async def generate_stream(
request: GenerateRequest,
@ -1169,6 +1341,9 @@ async def openai_chat_completions(
)
if payload.stream:
_cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
_tracker = _TrackedCancel(cancel_event, *_cancel_keys)
_tracker.__enter__()
async def audio_input_stream():
try:
@ -1185,10 +1360,17 @@ async def openai_chat_completions(
)
yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n"
for chunk_text in audio_input_generate():
gen = audio_input_generate()
_DONE = object()
while True:
if cancel_event.is_set():
break
if await request.is_disconnected():
cancel_event.set()
return
chunk_text = await asyncio.to_thread(next, gen, _DONE)
if chunk_text is _DONE:
break
if chunk_text:
chunk = ChatCompletionChunk(
id = completion_id,
@ -1221,6 +1403,8 @@ async def openai_chat_completions(
f"Error during audio input streaming: {e}", exc_info = True
)
yield f"data: {json.dumps({'error': {'message': _friendly_error(e), 'type': 'server_error'}})}\n\n"
finally:
_tracker.__exit__(None, None, None)
return StreamingResponse(
audio_input_stream(),
@ -1466,6 +1650,10 @@ async def openai_chat_completions(
_tool_sentinel = object()
_cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
_tracker = _TrackedCancel(cancel_event, *_cancel_keys)
_tracker.__enter__()
async def gguf_tool_stream():
try:
first_chunk = ChatCompletionChunk(
@ -1488,6 +1676,8 @@ async def openai_chat_completions(
_stream_usage = None
_stream_timings = None
while True:
if cancel_event.is_set():
break
if await request.is_disconnected():
cancel_event.set()
return
@ -1595,6 +1785,8 @@ async def openai_chat_completions(
},
}
yield f"data: {json.dumps(error_chunk)}\n\n"
finally:
_tracker.__exit__(None, None, None)
return StreamingResponse(
gguf_tool_stream(),
@ -1628,6 +1820,9 @@ async def openai_chat_completions(
_gguf_sentinel = object()
if payload.stream:
_cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
_tracker = _TrackedCancel(cancel_event, *_cancel_keys)
_tracker.__enter__()
async def gguf_stream_chunks():
try:
@ -1652,6 +1847,8 @@ async def openai_chat_completions(
_stream_usage = None
_stream_timings = None
while True:
if cancel_event.is_set():
break
if await request.is_disconnected():
cancel_event.set()
return
@ -1735,6 +1932,8 @@ async def openai_chat_completions(
},
}
yield f"data: {json.dumps(error_chunk)}\n\n"
finally:
_tracker.__exit__(None, None, None)
return StreamingResponse(
gguf_stream_chunks(),
@ -1834,6 +2033,9 @@ async def openai_chat_completions(
# ── Streaming response ────────────────────────────────────────
if payload.stream:
_cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
_tracker = _TrackedCancel(cancel_event, *_cancel_keys)
_tracker.__enter__()
async def stream_chunks():
try:
@ -1861,6 +2063,9 @@ async def openai_chat_completions(
loop = asyncio.get_event_loop()
gen = generate()
while True:
if cancel_event.is_set():
backend.reset_generation_state()
break
# next(gen, _DONE) returns _DONE instead of raising
# StopIteration — StopIteration cannot propagate
# through asyncio futures (Python limitation).
@ -1916,6 +2121,8 @@ async def openai_chat_completions(
},
}
yield f"data: {json.dumps(error_chunk)}\n\n"
finally:
_tracker.__exit__(None, None, None)
return StreamingResponse(
stream_chunks(),
@ -2596,7 +2803,9 @@ async def _responses_stream(
),
)
body = _build_openai_passthrough_body(chat_req)
body = _build_openai_passthrough_body(
chat_req, backend_ctx = llama_backend.context_length
)
target_url = f"{llama_backend.base_url}/v1/chat/completions"
async def event_generator():
@ -3081,6 +3290,8 @@ async def anthropic_messages(
repetition_penalty = repetition_penalty,
presence_penalty = presence_penalty,
tool_choice = openai_tool_choice,
session_id = payload.session_id,
cancel_id = payload.cancel_id,
)
return await _anthropic_passthrough_non_streaming(
llama_backend,
@ -3441,6 +3652,7 @@ def _build_passthrough_payload(
repetition_penalty = None,
presence_penalty = None,
tool_choice = "auto",
backend_ctx = None,
):
body = {
"messages": openai_messages,
@ -3453,8 +3665,12 @@ def _build_passthrough_payload(
}
if stream:
body["stream_options"] = {"include_usage": True}
if max_tokens is not None:
body["max_tokens"] = max_tokens
body["max_tokens"] = (
max_tokens
if max_tokens is not None
else (backend_ctx or _DEFAULT_MAX_TOKENS_FLOOR)
)
body["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS
if stop:
body["stop"] = stop
if min_p is not None:
@ -3484,6 +3700,8 @@ async def _anthropic_passthrough_stream(
repetition_penalty = None,
presence_penalty = None,
tool_choice = "auto",
session_id = None,
cancel_id = None,
):
"""Streaming client-side pass-through: forward tools to llama-server and
translate its streaming response to Anthropic SSE without executing anything."""
@ -3501,8 +3719,14 @@ async def _anthropic_passthrough_stream(
repetition_penalty = repetition_penalty,
presence_penalty = presence_penalty,
tool_choice = tool_choice,
backend_ctx = llama_backend.context_length,
)
# cancel_id mirrors the OpenAI passthrough so a per-run cancel POST
# works without the caller having to know the local message_id.
_tracker = _TrackedCancel(cancel_event, cancel_id, session_id, message_id)
_tracker.__enter__()
async def _stream():
emitter = AnthropicPassthroughEmitter()
for line in emitter.start(message_id, model_name):
@ -3535,15 +3759,28 @@ async def _anthropic_passthrough_stream(
# has anything orphaned to finalize. Each aclose is wrapped in
# `try: ... except Exception: pass` so anyio cleanup noise from
# nested aclose paths can't bubble out.
client = httpx.AsyncClient(timeout = 600)
client = httpx.AsyncClient(
timeout = 600,
limits = httpx.Limits(max_keepalive_connections = 0),
)
resp = None
lines_iter = None
cancel_watcher = None
try:
req = client.build_request("POST", target_url, json = body)
resp = await client.send(req, stream = True)
# See _openai_passthrough_stream for rationale: aiter_lines()
# blocks during llama-server prefill, so the in-loop cancel
# check is unreachable until the first SSE chunk arrives.
# The watcher closes `resp` on cancel, raising in aiter_lines.
cancel_watcher = asyncio.create_task(
_await_cancel_then_close(cancel_event, resp)
)
lines_iter = resp.aiter_lines()
async for raw_line in lines_iter:
if cancel_event.is_set():
break
if await request.is_disconnected():
cancel_event.set()
break
@ -3558,9 +3795,18 @@ async def _anthropic_passthrough_stream(
continue
for line in emitter.feed_chunk(chunk):
yield line
except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError):
if not cancel_event.is_set():
raise
except Exception as e:
logger.error("anthropic_messages passthrough stream error: %s", e)
finally:
if cancel_watcher is not None:
cancel_watcher.cancel()
try:
await cancel_watcher
except (asyncio.CancelledError, Exception):
pass
if lines_iter is not None:
try:
await lines_iter.aclose()
@ -3575,6 +3821,7 @@ async def _anthropic_passthrough_stream(
await client.aclose()
except Exception:
pass
_tracker.__exit__(None, None, None)
for line in emitter.finish():
yield line
@ -3621,6 +3868,7 @@ async def _anthropic_passthrough_non_streaming(
repetition_penalty = repetition_penalty,
presence_penalty = presence_penalty,
tool_choice = tool_choice,
backend_ctx = llama_backend.context_length,
)
async with httpx.AsyncClient() as client:
@ -3742,7 +3990,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
return messages
def _build_openai_passthrough_body(payload) -> dict:
def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict:
"""Assemble the llama-server request body from a ChatCompletionRequest.
Only explicitly-known OpenAI / llama-server fields are forwarded so that
@ -3764,6 +4012,7 @@ def _build_openai_passthrough_body(payload) -> dict:
repetition_penalty = payload.repetition_penalty,
presence_penalty = payload.presence_penalty,
tool_choice = tool_choice,
backend_ctx = backend_ctx,
)
@ -3784,14 +4033,26 @@ async def _openai_passthrough_stream(
observes a standard OpenAI response.
"""
target_url = f"{llama_backend.base_url}/v1/chat/completions"
body = _build_openai_passthrough_body(payload)
body = _build_openai_passthrough_body(
payload, backend_ctx = llama_backend.context_length
)
# Dispatch the upstream request BEFORE returning StreamingResponse so
# transport errors and non-200 upstream statuses surface as real HTTP
# errors to the client. OpenAI SDKs rely on status codes to raise
# ``APIError``/``BadRequestError``/...; burying the failure inside a
# 200 SSE ``error`` frame silently breaks their error handling.
client = httpx.AsyncClient(timeout = 600)
_cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
_tracker = _TrackedCancel(cancel_event, *_cancel_keys)
_tracker.__enter__()
# Outer guard: asyncio.CancelledError at `await client.send(...)` is
# a BaseException that bypasses `except httpx.RequestError`; without
# this the tracker leaks. The generator's finally only runs once
# iteration starts.
try:
# Dispatch BEFORE returning StreamingResponse so transport errors
# and non-200 upstream statuses surface as real HTTP errors --
# OpenAI SDKs rely on status codes to raise APIError/BadRequestError.
client = httpx.AsyncClient(
timeout = 600,
limits = httpx.Limits(max_keepalive_connections = 0),
)
resp = None
try:
req = client.build_request("POST", target_url, json = body)
@ -3837,19 +4098,24 @@ async def _openai_passthrough_stream(
async def _stream():
# Same httpx lifecycle pattern as _anthropic_passthrough_stream:
# avoid `async with` on the client/response AND explicitly save
# resp.aiter_lines() so we can close it ourselves in the finally
# block. See the long comment there for the full rationale on
# why the anonymous `async for raw_line in resp.aiter_lines():`
# pattern leaks an unclosed async generator that Python's
# asyncgen GC hook then finalizes in a different asyncio task,
# producing "Exception ignored in:" / "async generator ignored
# GeneratorExit" / anyio cancel-scope traces on Python 3.13 +
# httpcore 1.0.x.
# save resp.aiter_lines() so the finally block can aclose() it
# on our task. See that function for full rationale.
lines_iter = None
# During llama-server prefill, `aiter_lines()` blocks until the
# first SSE chunk arrives. The in-loop `cancel_event` check
# cannot fire until then, which is the exact proxy/Colab
# scenario the cancel POST is meant to recover from. Run a
# tiny watcher that closes `resp` as soon as cancel fires,
# unblocking the iterator with a RemoteProtocolError caught
# in the except clause below.
cancel_watcher = asyncio.create_task(
_await_cancel_then_close(cancel_event, resp)
)
try:
lines_iter = resp.aiter_lines()
async for raw_line in lines_iter:
if cancel_event.is_set():
break
if await request.is_disconnected():
cancel_event.set()
break
@ -3857,16 +4123,18 @@ async def _openai_passthrough_stream(
continue
if not raw_line.startswith("data: "):
continue
# Relay the llama-server SSE chunk verbatim so the client
# sees its native `id`, `finish_reason`, `delta.tool_calls`,
# and final `usage` unchanged.
# Relay verbatim to preserve llama-server's native id,
# finish_reason, delta.tool_calls, and usage chunks.
yield raw_line + "\n\n"
if raw_line[6:].strip() == "[DONE]":
break
except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError):
# Watcher closed resp on cancel. Emit nothing extra; the
# client either initiated the cancel or already disconnected.
if not cancel_event.is_set():
raise
except Exception as e:
# Mid-stream failures still have to be reported inside the SSE
# body because the 200 response headers have already been
# committed by the time the first chunk flushes.
# 200 headers are already flushed; errors must be in the SSE body.
logger.error("openai passthrough stream error: %s", e)
err = {
"error": {
@ -3876,6 +4144,11 @@ async def _openai_passthrough_stream(
}
yield f"data: {json.dumps(err)}\n\n"
finally:
cancel_watcher.cancel()
try:
await cancel_watcher
except (asyncio.CancelledError, Exception):
pass
if lines_iter is not None:
try:
await lines_iter.aclose()
@ -3889,6 +4162,7 @@ async def _openai_passthrough_stream(
await client.aclose()
except Exception:
pass
_tracker.__exit__(None, None, None)
return StreamingResponse(
_stream(),
@ -3899,6 +4173,9 @@ async def _openai_passthrough_stream(
"X-Accel-Buffering": "no",
},
)
except BaseException:
_tracker.__exit__(None, None, None)
raise
async def _openai_passthrough_non_streaming(
@ -3914,7 +4191,9 @@ async def _openai_passthrough_non_streaming(
token counts.
"""
target_url = f"{llama_backend.base_url}/v1/chat/completions"
body = _build_openai_passthrough_body(payload)
body = _build_openai_passthrough_body(
payload, backend_ctx = llama_backend.context_length
)
try:
async with httpx.AsyncClient() as client:

View file

@ -4,6 +4,8 @@
import type { ChatModelAdapter } from "@assistant-ui/react";
import type { MessageTiming, ToolCallMessagePart } from "@assistant-ui/core";
import { toast } from "sonner";
import { getAuthToken } from "@/features/auth/session";
import { apiUrl } from "@/lib/api-base";
import {
generateAudio,
listCachedGguf,
@ -707,7 +709,42 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
const toolCallParts: ToolCallMessagePart[] = [];
let serverMetadata: { usage?: ServerUsage; timings?: ServerTimings } | null = null;
// Per-run cancellation token so a delayed stop POST cannot match
// the next run on the same thread.
const cancelId =
typeof crypto !== "undefined" && "randomUUID" in crypto
? crypto.randomUUID()
: `${Date.now()}-${Math.random().toString(36).slice(2)}`;
// Colab-style proxies can swallow fetch aborts, so also POST
// /inference/cancel explicitly on abort.
const onAbortCancel = () => {
const body: Record<string, string> = { cancel_id: cancelId };
if (resolvedThreadId) body.session_id = resolvedThreadId;
// Plain fetch, not authFetch: authFetch redirects to login on
// 401, which would kick the user out mid-stop.
const token = getAuthToken();
// Use apiUrl so the cancel POST reaches the right origin in
// Tauri production builds (where the webview origin is not the
// backend at 127.0.0.1:<port>). Browser/dev builds get the empty
// base, so the path is unchanged there.
void fetch(apiUrl("/api/inference/cancel"), {
method: "POST",
headers: {
"Content-Type": "application/json",
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
body: JSON.stringify(body),
keepalive: true,
}).catch(() => {});
};
try {
if (abortSignal.aborted) {
onAbortCancel();
} else {
abortSignal.addEventListener("abort", onAbortCancel, { once: true });
}
const {
supportsReasoning,
reasoningEnabled,
@ -730,6 +767,8 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
presence_penalty: params.presencePenalty,
image_base64: imageBase64,
audio_base64: audioBase64,
cancel_id: cancelId,
...(resolvedThreadId ? { session_id: resolvedThreadId } : {}),
...(useAdapter === undefined ? {} : { use_adapter: useAdapter }),
...(supportsReasoning
? reasoningStyle === "reasoning_effort"
@ -750,7 +789,6 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
const mins = useChatRuntimeStore.getState().toolCallTimeout;
return mins >= 9999 ? 9999 : mins * 60;
})(),
session_id: resolvedThreadId,
}
: {}),
},
@ -948,6 +986,7 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter {
}
throw err;
} finally {
abortSignal.removeEventListener("abort", onAbortCancel);
runtime.setGeneratingStatus(null);
runtime.setToolStatus(null);
clearTimeout(warmupTimer);

View file

@ -179,6 +179,7 @@ export interface OpenAIChatCompletionsRequest {
max_tool_calls_per_message?: number;
tool_call_timeout?: number;
session_id?: string;
cancel_id?: string;
}
export interface OpenAIChatDelta {

View file

@ -0,0 +1,289 @@
"""
TOCTOU atomicity guards for the cancel path.
Structural: cancel_inference, _cancel_by_cancel_id_or_stash, and
_TrackedCancel.__enter__ must each use a single _CANCEL_LOCK critical
section over lookup + stash / register + consume-pending.
Behavioral: parallel cancel-POST vs __enter__ must never drop a cancel.
"""
from __future__ import annotations
import ast
import random
import threading
from pathlib import Path
SOURCE_PATH = (
Path(__file__).resolve().parents[2]
/ "studio"
/ "backend"
/ "routes"
/ "inference.py"
)
_SRC = SOURCE_PATH.read_text()
_TREE = ast.parse(_SRC)
def _find_function(name: str) -> ast.FunctionDef | ast.AsyncFunctionDef:
for node in ast.walk(_TREE):
if (
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == name
):
return node
raise AssertionError(f"function {name!r} not found")
def _find_class(name: str) -> ast.ClassDef:
for node in ast.walk(_TREE):
if isinstance(node, ast.ClassDef) and node.name == name:
return node
raise AssertionError(f"class {name!r} not found")
def _count_with_cancel_lock_blocks(node: ast.AST) -> int:
n = 0
for sub in ast.walk(node):
if not isinstance(sub, ast.With):
continue
for item in sub.items:
ctx = item.context_expr
if isinstance(ctx, ast.Name) and ctx.id == "_CANCEL_LOCK":
n += 1
break
return n
def test_cancel_by_cancel_id_or_stash_is_single_lock_critical_section():
fn = _find_function("_cancel_by_cancel_id_or_stash")
assert _count_with_cancel_lock_blocks(fn) == 1, (
"_cancel_by_cancel_id_or_stash must use exactly one `with "
"_CANCEL_LOCK:` block; splitting into two acquisitions reopens "
"the TOCTOU race with _TrackedCancel.__enter__"
)
src = ast.unparse(fn)
assert "_CANCEL_REGISTRY.get(cancel_id)" in src
assert "_PENDING_CANCELS[cancel_id]" in src
def test_tracked_cancel_enter_registers_and_consumes_pending_under_one_lock():
cls = _find_class("_TrackedCancel")
enter = None
for n in cls.body:
if isinstance(n, ast.FunctionDef) and n.name == "__enter__":
enter = n
break
assert enter is not None
assert _count_with_cancel_lock_blocks(enter) == 1, (
"_TrackedCancel.__enter__ must acquire _CANCEL_LOCK exactly once. "
"A second acquisition for consume-pending lets a concurrent "
"cancel POST stash after consume sees an empty map, silently "
"dropping the cancel"
)
with_block = None
for sub in ast.walk(enter):
if isinstance(sub, ast.With) and any(
isinstance(i.context_expr, ast.Name) and i.context_expr.id == "_CANCEL_LOCK"
for i in sub.items
):
with_block = sub
break
assert with_block is not None
block_src = "\n".join(ast.unparse(s) for s in with_block.body)
assert "_CANCEL_REGISTRY.setdefault" in block_src
assert "_PENDING_CANCELS.pop" in block_src, (
"__enter__ critical section must consume from _PENDING_CANCELS "
"inside the same lock, not a later re-acquisition"
)
def test_cancel_inference_uses_atomic_helper_for_cancel_id_path():
fn = _find_function("cancel_inference")
src = ast.unparse(fn)
assert "_cancel_by_cancel_id_or_stash" in src
# The pre-fix two-step idiom must be gone.
assert "_remember_pending_cancel(cancel_id)" not in src, (
"two-step _cancel_by_keys + _remember_pending_cancel produced "
"the TOCTOU race and must not return"
)
_WANTED = {
"_CANCEL_REGISTRY",
"_CANCEL_LOCK",
"_PENDING_CANCELS",
"_PENDING_CANCEL_TTL_S",
"_prune_pending",
"_remember_pending_cancel",
"_TrackedCancel",
"_cancel_by_keys",
"_cancel_by_cancel_id_or_stash",
}
def _load_registry_module():
chunks = []
for n in _TREE.body:
seg = ast.get_source_segment(_SRC, n)
if seg is None:
continue
if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED:
chunks.append(seg)
elif isinstance(n, ast.Assign):
names = [t.id for t in n.targets if isinstance(t, ast.Name)]
if any(name in _WANTED for name in names):
chunks.append(seg)
elif (
isinstance(n, ast.AnnAssign)
and isinstance(n.target, ast.Name)
and n.target.id in _WANTED
):
chunks.append(seg)
mod = {}
exec(
"import threading, time\nfrom typing import Optional\n" + "\n\n".join(chunks),
mod,
)
return mod
def test_parallel_cancel_vs_register_never_drops():
m = _load_registry_module()
trials = 500
dropped = 0
for i in range(trials):
m["_CANCEL_REGISTRY"].clear()
m["_PENDING_CANCELS"].clear()
cid = f"cid-{i}"
ev = threading.Event()
tracker = m["_TrackedCancel"](ev, cid, "thread")
start = threading.Event()
def do_cancel():
start.wait()
m["_cancel_by_cancel_id_or_stash"](cid)
def do_enter():
start.wait()
tracker.__enter__()
threads = [
threading.Thread(target = do_cancel),
threading.Thread(target = do_enter),
]
random.shuffle(threads)
for t in threads:
t.start()
start.set()
for t in threads:
t.join(timeout = 5.0)
assert not t.is_alive()
if not ev.is_set():
dropped += 1
tracker.__exit__(None, None, None)
assert dropped == 0, (
f"TOCTOU regression: {dropped}/{trials} parallel trials silently "
f"dropped the cancel"
)
def test_cancel_before_register_replays_atomically():
m = _load_registry_module()
cid = "early-cid"
ev = threading.Event()
tracker = m["_TrackedCancel"](ev, cid, "thread-x")
assert m["_cancel_by_cancel_id_or_stash"](cid) == 0
assert cid in m["_PENDING_CANCELS"]
tracker.__enter__()
assert ev.is_set()
assert cid not in m["_PENDING_CANCELS"]
tracker.__exit__(None, None, None)
def test_cancel_after_register_signals_without_stash():
m = _load_registry_module()
cid = "post-cid"
ev = threading.Event()
tracker = m["_TrackedCancel"](ev, cid, "thread-y")
tracker.__enter__()
assert m["_cancel_by_cancel_id_or_stash"](cid) == 1
assert ev.is_set()
assert cid not in m["_PENDING_CANCELS"]
tracker.__exit__(None, None, None)
def test_cancel_by_keys_tolerates_empty_and_falsy_keys():
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
m["_PENDING_CANCELS"].clear()
assert m["_cancel_by_keys"]([]) == 0
assert m["_cancel_by_keys"](["", None, "unknown"]) == 0
# Non-stashing fallback must never leak into _PENDING_CANCELS.
assert m["_PENDING_CANCELS"] == {}
def test_cancel_by_keys_fans_out_to_all_streams_on_same_session():
# Compare mode and other flows launch concurrent streams under a
# shared session_id; a single session cancel POST must hit all of them.
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
m["_PENDING_CANCELS"].clear()
session = "shared-thread"
ev_a = threading.Event()
ev_b = threading.Event()
tracker_a = m["_TrackedCancel"](ev_a, "cancel-a", session, "chatcmpl-a")
tracker_b = m["_TrackedCancel"](ev_b, "cancel-b", session, "chatcmpl-b")
tracker_a.__enter__()
tracker_b.__enter__()
try:
assert m["_cancel_by_keys"]([session]) == 2
assert ev_a.is_set() and ev_b.is_set()
finally:
tracker_a.__exit__(None, None, None)
tracker_b.__exit__(None, None, None)
assert session not in m["_CANCEL_REGISTRY"]
def test_cancel_by_cancel_id_is_exclusive_to_single_run():
# cancel_id is per-run unique; cancelling run A must not touch run B
# even when both share a session_id.
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
m["_PENDING_CANCELS"].clear()
session = "shared-thread-2"
ev_a = threading.Event()
ev_b = threading.Event()
tracker_a = m["_TrackedCancel"](ev_a, "cancel-only-a", session, "chatcmpl-a")
tracker_b = m["_TrackedCancel"](ev_b, "cancel-only-b", session, "chatcmpl-b")
tracker_a.__enter__()
tracker_b.__enter__()
try:
assert m["_cancel_by_cancel_id_or_stash"]("cancel-only-a") == 1
assert ev_a.is_set()
assert not ev_b.is_set()
finally:
tracker_a.__exit__(None, None, None)
tracker_b.__exit__(None, None, None)
def test_tracked_cancel_exit_is_idempotent():
# Outer except BaseException + the generator's finally may both call
# __exit__ under certain race combos; must not raise.
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
m["_PENDING_CANCELS"].clear()
ev = threading.Event()
tracker = m["_TrackedCancel"](ev, "cid", "sess", "chatcmpl-x")
tracker.__enter__()
tracker.__exit__(None, None, None)
tracker.__exit__(None, None, None)
tracker.__exit__(None, None, None)
assert not m["_CANCEL_REGISTRY"]

View file

@ -0,0 +1,169 @@
"""
Wiring tests for the per-run cancel_id field.
A chat-thread-scoped session_id is not safe as a cancel key because a
late stop POST can match a subsequent run on the same thread. The fix
adds cancel_id (a fresh UUID per generation) that is sent both in the
completion payload and in the /api/inference/cancel body.
Verifies:
- ChatCompletionRequest exposes an Optional[str] `cancel_id` field.
- /api/inference/cancel accepts `cancel_id` as the first-preferred key.
- OpenAIChatCompletionsRequest (frontend type) includes cancel_id.
- chat-adapter.ts generates a per-run cancelId (crypto.randomUUID
with a Math.random fallback), sends it in the completion payload,
and includes it in the /inference/cancel body on abort.
"""
from __future__ import annotations
import ast
import re
from pathlib import Path
WORKSPACE = Path(__file__).resolve().parents[2]
MODELS_SRC = (WORKSPACE / "studio/backend/models/inference.py").read_text()
ROUTES_SRC = (WORKSPACE / "studio/backend/routes/inference.py").read_text()
ADAPTER_SRC = (
WORKSPACE / "studio/frontend/src/features/chat/api/chat-adapter.ts"
).read_text()
API_TYPES_SRC = (
WORKSPACE / "studio/frontend/src/features/chat/types/api.ts"
).read_text()
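# Illustrative sketch (an assumption for readability, not the shipped Pydantic
# model or chat-adapter.ts): the wiring the tests below pin down is simply an
# optional cancel_id on the request plus one fresh UUID minted per run and
# echoed in the /api/inference/cancel body.
from dataclasses import dataclass
from uuid import uuid4
@dataclass
class _SketchCompletionRequest:
    session_id: str | None = None
    completion_id: str | None = None
    cancel_id: str | None = None  # per-run UUID, the preferred cancel key
def _sketch_start_run(session_id: str) -> _SketchCompletionRequest:
    # A stale stop POST carrying an old cancel_id can no longer match a retry
    # started later on the same thread, because each run gets its own UUID.
    return _SketchCompletionRequest(session_id=session_id, cancel_id=str(uuid4()))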
def _find_class(tree: ast.AST, name: str) -> ast.ClassDef | None:
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == name:
return node
return None
def test_chat_completion_request_has_cancel_id_field():
tree = ast.parse(MODELS_SRC)
cls = _find_class(tree, "ChatCompletionRequest")
assert cls is not None
fields = {
n.target.id
for n in cls.body
if isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name)
}
assert "cancel_id" in fields, (
"ChatCompletionRequest must expose a cancel_id field for per-run "
"cancellation routing"
)
def test_cancel_route_matches_cancel_id_exclusively_when_present():
# A stale cancel POST carrying cancel_id AND session_id must not
# cancel a later run on the same thread via the shared session_id.
# Enforce this by requiring the handler to early-return through an
# exclusive-cancel_id path -- either an atomic helper or a keys
# list containing ONLY cancel_id (never session_id).
for node in ast.walk(ast.parse(ROUTES_SRC)):
if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference":
break
else:
raise AssertionError("cancel_inference handler missing")
cancel_id_exclusive_branch = False
for sub in ast.walk(node):
if not isinstance(sub, ast.If):
continue
test_src = ast.unparse(sub.test)
if "cancel_id" not in test_src or "isinstance" not in test_src:
continue
branch_src = "\n".join(ast.unparse(s) for s in sub.body)
before_return = branch_src.split("return", 1)[0]
matches_cancel_id_only = (
"_cancel_by_cancel_id_or_stash(cancel_id)" in branch_src
or "_cancel_by_keys([cancel_id])" in branch_src
)
if matches_cancel_id_only and "session_id" not in before_return:
cancel_id_exclusive_branch = True
break
assert cancel_id_exclusive_branch, (
"cancel_inference must early-return with an exclusive cancel_id "
"match when a cancel_id is supplied, so a stale stop POST "
"cannot cancel a later run on the same thread via session_id"
)
def test_cancel_route_falls_back_to_session_or_completion_when_no_cancel_id():
for node in ast.walk(ast.parse(ROUTES_SRC)):
if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference":
break
else:
raise AssertionError("cancel_inference handler missing")
src = ast.unparse(node)
assert "session_id" in src and "completion_id" in src, (
"cancel_inference must still accept session_id / completion_id as "
"fallback keys when cancel_id is absent"
)
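# Illustrative sketch (an assumption, not the shipped handler) of the key
# selection the two tests above enforce: a supplied cancel_id is matched
# exclusively, and session_id / completion_id serve only as fallbacks when no
# cancel_id was sent.
def _sketch_cancel_keys(
    cancel_id: str | None, session_id: str | None, completion_id: str | None
) -> list[str]:
    if isinstance(cancel_id, str) and cancel_id:
        return [cancel_id]  # never widened with session_id
    return [k for k in (session_id, completion_id) if k]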
def test_frontend_request_type_has_cancel_id():
assert re.search(
r"cancel_id\?\s*:\s*string\s*;", API_TYPES_SRC
), "OpenAIChatCompletionsRequest must expose an optional cancel_id"
def test_chat_adapter_generates_cancel_id_per_run():
m = re.search(
r"const\s+cancelId\s*=\s*([^;]+);",
ADAPTER_SRC,
)
assert m, "chat-adapter.ts must declare a per-run `cancelId` constant"
rhs = m.group(1)
assert (
"randomUUID" in rhs
), "cancelId should prefer crypto.randomUUID() for uniqueness"
def test_chat_adapter_sends_cancel_id_in_completion_payload():
assert "cancel_id: cancelId" in ADAPTER_SRC, (
"chat-adapter.ts must include cancel_id in the streamChatCompletions "
"request payload so the backend registers under that key"
)
def test_chat_adapter_sends_cancel_id_in_abort_cancel_post():
m = re.search(
r"const\s+onAbortCancel\s*=\s*\(\)\s*=>\s*\{(.*?)\};",
ADAPTER_SRC,
flags=re.DOTALL,
)
assert m, "onAbortCancel arrow function missing"
body = m.group(1)
assert re.search(r"cancel_id\s*:\s*cancelId", body), (
"onAbortCancel must include cancel_id in the /inference/cancel body "
"so a stop POST matches the specific run, not the whole thread"
)
def test_abort_cancel_post_uses_plain_fetch_with_manual_auth_header():
# authFetch redirects to login on 401, which would kick the user to
# the login page mid-stop if the access token expired during a long
# stream. Use plain fetch + manual Authorization header for a
# best-effort cancel that never triggers the refresh/redirect flow.
start = ADAPTER_SRC.find("const onAbortCancel")
assert start >= 0, "onAbortCancel handler missing"
rest = ADAPTER_SRC[start:]
end = rest.find("\n try {")
body = rest if end < 0 else rest[:end]
assert "/api/inference/cancel" in body
assert "authFetch(" not in body, (
"onAbortCancel must NOT call authFetch; a 401 from it would "
"redirect the user to the login page during a stop click"
)
assert "fetch(" in body, "onAbortCancel must use plain fetch(...)"
assert "getAuthToken" in body, (
"onAbortCancel must read the bearer token via getAuthToken() "
"rather than relying on authFetch's 401 flow"
)
assert "Authorization" in body
assert (
"keepalive: true" in body
), "keepalive is required so the fetch survives page unload during stop"


@@ -0,0 +1,123 @@
"""
Tests for the llama-server wall-clock cap (t_max_predict_ms).
The UI always sends max_tokens = context_length, so gating
t_max_predict_ms on `max_tokens is None` makes the safety net dead
code. The fix applies the wall-clock cap unconditionally on all three
streaming payload sites and raises the default to 10 minutes so slow
CPU / macOS / Windows installs are not cut off mid-generation.
Verifies:
- t_max_predict_ms is assigned unconditionally at the three
payload-builder sites (not inside an `if max_tokens is None` else
branch).
- _DEFAULT_T_MAX_PREDICT_MS is at least 10 minutes (previously
120_000).
- The default max_tokens path still applies _DEFAULT_MAX_TOKENS.
- The three payload variable names (payload x2, stream_payload x1)
each get both `max_tokens` and `t_max_predict_ms`.
"""
from __future__ import annotations
import ast
from pathlib import Path
SOURCE_PATH = (
Path(__file__).resolve().parents[2]
/ "studio"
/ "backend"
/ "core"
/ "inference"
/ "llama_cpp.py"
)
SRC = SOURCE_PATH.read_text()
TREE = ast.parse(SRC)
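# Illustrative payload-builder shape (a simplified assumption, not the code in
# llama_cpp.py) of the invariant the tests below require: the token default is
# conditional, while the wall-clock cap is applied unconditionally so it is
# never dead code for UI callers that always pass max_tokens.
_SKETCH_DEFAULT_MAX_TOKENS = 4096
_SKETCH_DEFAULT_T_MAX_PREDICT_MS = 600_000  # 10 minutes
def _sketch_build_payload(max_tokens: int | None) -> dict:
    payload: dict = {}
    payload["max_tokens"] = (
        max_tokens if max_tokens is not None else _SKETCH_DEFAULT_MAX_TOKENS
    )
    # Unconditional: applies whether or not the caller set max_tokens.
    payload["t_max_predict_ms"] = _SKETCH_DEFAULT_T_MAX_PREDICT_MS
    return payload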
def _is_subscript_assign(stmt: ast.stmt, target_name: str, key: str) -> bool:
if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1:
return False
t = stmt.targets[0]
if not isinstance(t, ast.Subscript):
return False
if not (isinstance(t.value, ast.Name) and t.value.id == target_name):
return False
slc = t.slice
return isinstance(slc, ast.Constant) and slc.value == key
def _collect_assignments(tree, target_name, key):
"""Return list of (node, stack_of_enclosing_ifs) for each match."""
hits = []
def visit(node, stack):
if _is_subscript_assign(node, target_name, key):
hits.append((node, stack))
for child in ast.iter_child_nodes(node):
if isinstance(child, ast.If):
for sub in child.body:
visit(sub, stack + [(child, "body")])
for sub in child.orelse:
visit(sub, stack + [(child, "orelse")])
else:
visit(child, stack)
visit(tree, [])
return hits
def test_default_t_max_predict_ms_is_at_least_ten_minutes():
for node in TREE.body:
if isinstance(node, ast.Assign) and len(node.targets) == 1:
t = node.targets[0]
if isinstance(t, ast.Name) and t.id == "_DEFAULT_T_MAX_PREDICT_MS":
value = node.value
assert isinstance(value, ast.Constant)
assert value.value >= 600_000, (
f"_DEFAULT_T_MAX_PREDICT_MS must be >= 10 minutes "
f"(600_000 ms) to avoid cutting off slow-CPU generations; "
f"got {value.value}"
)
return
raise AssertionError("_DEFAULT_T_MAX_PREDICT_MS constant missing")
def test_t_max_predict_ms_set_unconditionally_at_three_sites():
hits_payload = _collect_assignments(TREE, "payload", "t_max_predict_ms")
hits_stream = _collect_assignments(TREE, "stream_payload", "t_max_predict_ms")
total = len(hits_payload) + len(hits_stream)
assert total == 3, (
f"expected 3 total t_max_predict_ms assignments "
f"(payload x2 + stream_payload x1), got {total}"
)
for node, stack in hits_payload + hits_stream:
for parent_if, branch in stack:
# The assignment must not be gated by a test that checks
# `max_tokens is None` (which would make it dead code for
# the UI path where max_tokens is always set).
test_src = ast.unparse(parent_if.test)
assert "max_tokens" not in test_src, (
f"t_max_predict_ms at line {node.lineno} is nested under "
f"`if {test_src}:` -- it must be applied unconditionally so "
f"the wall-clock cap is not dead code for callers that set "
f"max_tokens"
)
def test_max_tokens_default_cap_still_applied():
# _DEFAULT_MAX_TOKENS must still kick in when caller passes None.
# We check the conditional expression `max_tokens if max_tokens is not
# None else _DEFAULT_MAX_TOKENS` appears at each site.
matches = 0
for node in ast.walk(TREE):
if not isinstance(node, ast.IfExp):
continue
src = ast.unparse(node)
if "max_tokens" in src and "_DEFAULT_MAX_TOKENS" in src:
matches += 1
assert matches >= 3, (
f"expected >=3 `max_tokens if max_tokens is not None else "
f"_DEFAULT_MAX_TOKENS` expressions; got {matches}"
)


@@ -0,0 +1,718 @@
"""
Tests that the cancel tracker is registered BEFORE StreamingResponse is
returned, and that cleanup runs via a `finally` block inside each
async generator.
The zombie-generation scenario is: user clicks Stop during prefill /
warmup / proxy buffering, before the first SSE chunk. If _tracker.__enter__()
lives inside the async generator body, the registry is empty
at the moment /api/inference/cancel lands -- so cancel returns 0 and
the decode runs to completion.
The fix moves _tracker = _TrackedCancel(...) and _tracker.__enter__()
to the synchronous body of openai_chat_completions (before the
StreamingResponse is returned) and places _tracker.__exit__ inside
each generator's `finally` block. Using a generator `finally` (rather
than a Starlette BackgroundTask) guarantees cleanup on every
termination path -- normal exhaustion, CancelledError from
ClientDisconnect, and OSError / BrokenPipeError during send() --
because Starlette skips `background` callbacks when stream_response
raises.
Structural verifies:
- No `async def ...:` body contains `_tracker.__enter__()` in
routes/inference.py (registration moved to sync body).
- Each of the four async generators (gguf_tool_stream,
gguf_stream_chunks, stream_chunks, audio_input_stream) contains
`_tracker.__exit__(None, None, None)` inside a try/finally block.
- No StreamingResponse in openai_chat_completions passes
`background=` (cleanup now lives in the generator finally).
Behavioral verifies (extracting `_TrackedCancel` from source and
exercising the actual runtime semantics):
- `finally: _tracker.__exit__(...)` runs on normal completion,
mid-stream exception (OSError / BrokenPipeError from send()),
and aclose() from Starlette ClientDisconnect.
- A pre-set cancel_event (from `_TrackedCancel.__enter__` replaying
a pending cancel POST) lets the GGUF while-loop break cleanly
and emit final_chunk + [DONE] instead of propagating
`GeneratorExit` out of `_stream_with_retry` into the async
generator's `except Exception` (which would not catch it).
"""
from __future__ import annotations
import ast
import asyncio
import threading
import time
from pathlib import Path
SOURCE_PATH = (
Path(__file__).resolve().parents[2]
/ "studio"
/ "backend"
/ "routes"
/ "inference.py"
)
SRC = SOURCE_PATH.read_text()
_TREE = ast.parse(SRC)
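# Illustrative route shape (a simplified assumption, not the code in
# routes/inference.py) of the invariant the structural tests below pin down:
# enter the tracker in the synchronous handler body, before any bytes are
# streamed, and exit it in the generator's own finally block.
def _sketch_route(tracker, chunks):
    tracker.__enter__()  # a stop POST during prefill now finds a registry entry
    async def _stream():
        try:
            for chunk in chunks:
                yield chunk
        finally:
            # Runs on normal exhaustion, aclose() from ClientDisconnect, and
            # OSError / BrokenPipeError raised by send(); a `background=` task
            # would be skipped on those error paths.
            tracker.__exit__(None, None, None)
    return _stream()  # the real handler wraps this in StreamingResponse(...)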
# ── Structural (AST) helpers ─────────────────────────────────
def _collect_async_functions(tree: ast.AST):
return [n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef)]
def _has_tracker_enter_call(node: ast.AST) -> bool:
for sub in ast.walk(node):
if not isinstance(sub, ast.Call):
continue
fn = sub.func
if (
isinstance(fn, ast.Attribute)
and fn.attr == "__enter__"
and isinstance(fn.value, ast.Name)
and fn.value.id.startswith("_tracker")
):
return True
return False
def _finalbody_has_tracker_exit(finalbody) -> bool:
for stmt in finalbody:
if not isinstance(stmt, ast.Expr):
continue
call = stmt.value
if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
continue
fn = call.func
if (
fn.attr == "__exit__"
and isinstance(fn.value, ast.Name)
and fn.value.id.startswith("_tracker")
):
return True
return False
# ── Structural tests ─────────────────────────────────────────
def test_no_tracker_enter_inside_async_generators():
offenders = []
for fn in _collect_async_functions(_TREE):
if fn.name in {
"gguf_tool_stream",
"gguf_stream_chunks",
"stream_chunks",
"audio_input_stream",
}:
if _has_tracker_enter_call(fn):
offenders.append(fn.name)
assert not offenders, (
f"Cancel tracker registration must live OUTSIDE the async generator "
f"body so a stop POST can find the registry entry before the first "
f"SSE chunk. Offending generators: {offenders}"
)
def test_tracker_enter_exists_in_sync_body_of_chat_completions():
top = None
for n in ast.walk(_TREE):
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
top = n
break
assert top is not None, "openai_chat_completions handler missing"
count = 0
for sub in ast.walk(top):
if not isinstance(sub, ast.Call):
continue
fn = sub.func
if (
isinstance(fn, ast.Attribute)
and fn.attr == "__enter__"
and isinstance(fn.value, ast.Name)
and fn.value.id.startswith("_tracker")
):
count += 1
assert count >= 3, (
f"expected >=3 _tracker.__enter__() calls in openai_chat_completions "
f"(one per streaming path), got {count}"
)
def test_async_generators_cleanup_tracker_in_finally():
required = {
"gguf_tool_stream",
"gguf_stream_chunks",
"stream_chunks",
"audio_input_stream",
}
found: set[str] = set()
for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
if fn.name not in required:
continue
for sub in ast.walk(fn):
if isinstance(sub, ast.Try) and sub.finalbody:
if _finalbody_has_tracker_exit(sub.finalbody):
found.add(fn.name)
break
missing = required - found
assert not missing, (
f"Cleanup must run via `finally: _tracker.__exit__(None, None, None)` "
f"inside each streaming generator so ClientDisconnect / OSError paths "
f"also release registry entries (Starlette skips `background` callbacks "
f"when stream_response raises). Missing in: {sorted(missing)}"
)
def test_streaming_responses_have_no_background_task():
top = None
for n in ast.walk(_TREE):
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
top = n
break
assert top is not None
for sub in ast.walk(top):
if not (isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)):
continue
if sub.func.id != "StreamingResponse":
continue
kwargs = {kw.arg for kw in sub.keywords if kw.arg}
assert "background" not in kwargs, (
"StreamingResponse in openai_chat_completions must not pass "
"`background=` -- cleanup now lives in the generator's finally "
"block; a BackgroundTask would be skipped on abrupt disconnect"
)
# ── Behavioral helpers ───────────────────────────────────────
_WANTED = {
"_CANCEL_REGISTRY",
"_CANCEL_LOCK",
"_PENDING_CANCELS",
"_PENDING_CANCEL_TTL_S",
"_prune_pending",
"_TrackedCancel",
"_cancel_by_keys",
"_cancel_by_cancel_id_or_stash",
}
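# _load_registry_module re-executes only the helpers named in _WANTED from the
# routes/inference.py source text, so the behavioral tests below exercise the
# real _TrackedCancel / _cancel_by_* semantics without importing the FastAPI
# app or its inference backends.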
def _load_registry_module():
chunks = []
for n in _TREE.body:
seg = ast.get_source_segment(SRC, n)
if seg is None:
continue
if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED:
chunks.append(seg)
elif isinstance(n, ast.Assign):
names = [t.id for t in n.targets if isinstance(t, ast.Name)]
if any(name in _WANTED for name in names):
chunks.append(seg)
elif (
isinstance(n, ast.AnnAssign)
and isinstance(n.target, ast.Name)
and n.target.id in _WANTED
):
chunks.append(seg)
mod = {}
exec("import threading, time\n" + "\n\n".join(chunks), mod)
return mod
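# _make_stream mirrors the nested try/except/finally layout of the route's
# streaming generators: an inner handler that still yields an SSE error frame
# on non-cancellation exceptions, and a finally that releases the tracker on
# every termination path.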
def _make_stream(tracker, raise_exc):
async def gen():
try:
try:
yield "data: first\n\n"
if raise_exc is not None:
raise raise_exc
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
raise
except Exception:
yield "data: error\n\n"
finally:
tracker.__exit__(None, None, None)
except BaseException:
raise
return gen()
async def _consume(agen):
out = []
try:
async for ch in agen:
out.append(ch)
except BaseException as e:
out.append(type(e).__name__)
return out
def _llama_stub_raises_on_preset_cancel(cancel_event):
# Reproduces llama_cpp.py _stream_with_retry:2240 `raise GeneratorExit`
# when cancel_event is already set at entry.
if cancel_event.is_set():
raise GeneratorExit
yield "cumulative-1"
yield "cumulative-2"
async def _post_fix_gguf_loop(cancel_event):
yield "first_chunk"
gen = _llama_stub_raises_on_preset_cancel(cancel_event)
sentinel = object()
while True:
if cancel_event.is_set():
break
cumulative = await asyncio.to_thread(next, gen, sentinel)
if cumulative is sentinel:
break
yield cumulative
yield "final_chunk"
yield "[DONE]"
# ── Behavioral tests ─────────────────────────────────────────
def test_finally_cleanup_on_normal_completion():
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
ev = threading.Event()
tr = m["_TrackedCancel"](ev, "cid-ok", "sid-ok")
tr.__enter__()
assert "cid-ok" in m["_CANCEL_REGISTRY"]
chunks = asyncio.run(_consume(_make_stream(tr, None)))
assert chunks == ["data: first\n\n", "data: [DONE]\n\n"]
assert "cid-ok" not in m["_CANCEL_REGISTRY"]
assert "sid-ok" not in m["_CANCEL_REGISTRY"]
def test_finally_cleanup_on_mid_stream_exception():
# Simulates OSError / BrokenPipeError from Starlette send() mid-stream --
# the exact case where pre-fix `background = BackgroundTask(...)` was
# skipped and leaked the registry entry.
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
ev = threading.Event()
tr = m["_TrackedCancel"](ev, "cid-err", "sid-err")
tr.__enter__()
assert "cid-err" in m["_CANCEL_REGISTRY"]
asyncio.run(_consume(_make_stream(tr, OSError("disconnect"))))
assert "cid-err" not in m["_CANCEL_REGISTRY"]
assert "sid-err" not in m["_CANCEL_REGISTRY"]
def test_finally_cleanup_on_aclose():
# Starlette calls aclose() on the async generator when the client
# disconnects mid-stream. The generator's finally block must run.
m = _load_registry_module()
m["_CANCEL_REGISTRY"].clear()
ev = threading.Event()
tr = m["_TrackedCancel"](ev, "cid-abort", "sid-abort")
tr.__enter__()
assert "cid-abort" in m["_CANCEL_REGISTRY"]
async def run():
gen = _make_stream(tr, None)
it = gen.__aiter__()
await it.__anext__()
await gen.aclose()
asyncio.run(run())
assert "cid-abort" not in m["_CANCEL_REGISTRY"]
assert "sid-abort" not in m["_CANCEL_REGISTRY"]
def test_preset_cancel_event_exits_cleanly_with_done():
# Pending-replay: POST /cancel arrived before the stream registered,
# was stashed, then consumed by _TrackedCancel.__enter__ which set
# cancel_event. The generator must break out of the loop cleanly
# and emit final_chunk + [DONE] rather than calling next(gen) and
# propagating `GeneratorExit` out of the GGUF stream wrapper.
ev = threading.Event()
ev.set()
chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
assert "first_chunk" in chunks
assert "final_chunk" in chunks
assert "[DONE]" in chunks
assert "GeneratorExit" not in chunks
assert "cumulative-1" not in chunks
assert "cumulative-2" not in chunks
def test_normal_path_streams_all_tokens():
# Regression: the top-of-loop cancel_event check must not short-circuit
# when cancel_event is unset.
ev = threading.Event()
chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
assert chunks == [
"first_chunk",
"cumulative-1",
"cumulative-2",
"final_chunk",
"[DONE]",
]
def test_cancel_during_streaming_stops_iteration_promptly():
# Setting cancel_event between yields breaks out on the next iteration
# rather than draining the stub generator.
ev = threading.Event()
async def _run():
gen = _post_fix_gguf_loop(ev)
seen = []
async for ch in gen:
seen.append(ch)
if ch == "cumulative-1":
ev.set()
return seen
seen = asyncio.run(_run())
assert "first_chunk" in seen
assert "cumulative-1" in seen
assert "cumulative-2" not in seen
assert "final_chunk" in seen
assert "[DONE]" in seen
# ── Cancel-event responsiveness in the streaming loops ───────
def _loop_has_cancel_event_check(fn) -> bool:
# An `if cancel_event.is_set():` statement anywhere inside a
# `while`/`for` loop body is sufficient -- without it, a cancel POST
# cannot interrupt the loop because Colab-style proxies do not
# propagate request.is_disconnected().
for sub in ast.walk(fn):
if not isinstance(sub, (ast.While, ast.For, ast.AsyncFor)):
continue
for stmt in ast.walk(sub):
if not isinstance(stmt, ast.If):
continue
t = stmt.test
if (
isinstance(t, ast.Call)
and isinstance(t.func, ast.Attribute)
and t.func.attr == "is_set"
and isinstance(t.func.value, ast.Name)
and t.func.value.id == "cancel_event"
):
return True
return False
def test_streaming_generators_check_cancel_event_in_loop():
required = {
"gguf_tool_stream",
"gguf_stream_chunks",
"stream_chunks",
"audio_input_stream",
}
missing = []
for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
if fn.name not in required:
continue
if not _loop_has_cancel_event_check(fn):
missing.append(fn.name)
assert not missing, (
f"Each streaming generator must check `cancel_event.is_set()` inside "
f"its main loop so `POST /api/inference/cancel` can interrupt the "
f"stream through proxies that do not forward fetch aborts. "
f"Missing in: {sorted(missing)}"
)
def test_audio_input_stream_offloads_blocking_next_to_thread():
# Guards against regression back to `for chunk_text in
# audio_input_generate():` -- which blocks the event loop on each
# whisper chunk and prevents POST /api/inference/cancel from being
# serviced until the chunk yields.
audio = None
for fn in ast.walk(_TREE):
if isinstance(fn, ast.AsyncFunctionDef) and fn.name == "audio_input_stream":
audio = fn
break
assert audio is not None, "audio_input_stream generator missing"
for sub in ast.walk(audio):
if isinstance(sub, (ast.For, ast.AsyncFor)):
it_src = ast.unparse(sub.iter)
assert "audio_input_generate" not in it_src, (
"audio_input_stream must not iterate audio_input_generate() "
"directly -- that blocks the event loop. Use "
"`await asyncio.to_thread(next, gen, _DONE)` inside a "
"`while True` loop instead"
)
found_to_thread_next = False
for sub in ast.walk(audio):
if not isinstance(sub, ast.Call):
continue
fn_expr = sub.func
if not (
isinstance(fn_expr, ast.Attribute)
and fn_expr.attr == "to_thread"
and isinstance(fn_expr.value, ast.Name)
and fn_expr.value.id == "asyncio"
):
continue
if sub.args and isinstance(sub.args[0], ast.Name) and sub.args[0].id == "next":
found_to_thread_next = True
break
assert found_to_thread_next, (
"audio_input_stream must call `asyncio.to_thread(next, gen, ...)` "
"to keep the event loop free while whisper yields the next chunk"
)
def test_stream_chunks_cancel_branch_resets_backend_state():
# The Unsloth path's cancel branch must flush GPU / KV-cache state
# via `backend.reset_generation_state()` -- the orchestrator's
# internal cancel path does not do this, so a cancel-via-POST that
# only broke the loop would leave the subprocess in a dirty state
# for the next request.
fn = None
top = None
for n in ast.walk(_TREE):
if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
top = n
break
assert top is not None
for n in ast.walk(top):
if isinstance(n, ast.AsyncFunctionDef) and n.name == "stream_chunks":
fn = n
break
assert fn is not None, "stream_chunks generator missing"
for sub in ast.walk(fn):
if not isinstance(sub, ast.If):
continue
t = sub.test
if not (
isinstance(t, ast.Call)
and isinstance(t.func, ast.Attribute)
and t.func.attr == "is_set"
and isinstance(t.func.value, ast.Name)
and t.func.value.id == "cancel_event"
):
continue
body_src = "\n".join(ast.unparse(s) for s in sub.body)
if "backend.reset_generation_state()" in body_src:
return
raise AssertionError(
"stream_chunks `if cancel_event.is_set():` branch must call "
"backend.reset_generation_state() -- matches the existing "
"request.is_disconnected() / CancelledError cleanup paths and "
"prevents KV-cache drift after cancel-via-POST"
)
# ── Behavioral simulations for the iter-1 fixes ──────────────
def test_unsloth_stream_loop_breaks_on_external_cancel_event():
cancel_event = threading.Event()
reset_calls = [0]
class _Backend:
def reset_generation_state(self):
reset_calls[0] += 1
backend = _Backend()
def _generate():
for i in range(200):
time.sleep(0.005)
yield f"cum-{i}"
async def _loop():
_DONE = object()
loop = asyncio.get_running_loop()
gen = _generate()
seen = []
while True:
if cancel_event.is_set():
backend.reset_generation_state()
break
cumulative = await loop.run_in_executor(None, next, gen, _DONE)
if cumulative is _DONE:
break
seen.append(cumulative)
return seen
async def _fire():
await asyncio.sleep(0.05)
cancel_event.set()
async def _main():
return await asyncio.gather(_loop(), _fire())
seen, _ = asyncio.run(_main())
assert (
len(seen) < 200
), f"loop must not drain the generator after cancel; got {len(seen)} tokens"
assert reset_calls[0] == 1, (
f"backend.reset_generation_state() must be called exactly once on "
f"cancel-via-POST, got {reset_calls[0]}"
)
def test_audio_stream_stays_responsive_under_blocking_next():
# Regression guard: replace the post-fix loop with the pre-fix
# `for chunk in audio_input_generate()` pattern and assert it blocks
# the event loop; then confirm the post-fix pattern exits promptly.
cancel_event = threading.Event()
def _audio_gen():
for i in range(8):
time.sleep(0.15)
yield f"chunk-{i}"
async def _prefix_loop():
seen = []
for chunk_text in _audio_gen():
if cancel_event.is_set():
break
seen.append(chunk_text)
return seen
async def _postfix_loop():
_DONE = object()
gen = _audio_gen()
seen = []
while True:
if cancel_event.is_set():
break
chunk_text = await asyncio.to_thread(next, gen, _DONE)
if chunk_text is _DONE:
break
seen.append(chunk_text)
return seen
async def _fire_early():
await asyncio.sleep(0.05)
cancel_event.set()
async def _run(loop_coro):
return await asyncio.gather(loop_coro, _fire_early())
cancel_event.clear()
t0 = time.monotonic()
prefix_seen, _ = asyncio.run(_run(_prefix_loop()))
prefix_elapsed = time.monotonic() - t0
assert prefix_elapsed >= 0.13, (
f"pre-fix pattern should block event loop for >=1 chunk time "
f"(~150ms); got {prefix_elapsed:.3f}s, {len(prefix_seen)} chunks"
)
cancel_event.clear()
t0 = time.monotonic()
postfix_seen, _ = asyncio.run(_run(_postfix_loop()))
postfix_elapsed = time.monotonic() - t0
assert postfix_elapsed < prefix_elapsed, (
f"post-fix pattern must exit faster than pre-fix (blocking) "
f"pattern; post={postfix_elapsed:.3f}s vs pre={prefix_elapsed:.3f}s"
)
assert (
len(postfix_seen) < 8
), f"post-fix loop must not drain all chunks; got {len(postfix_seen)}"
def test_unsloth_stream_loop_emits_zero_tokens_on_preset_cancel():
# Pending-cancel replay: _TrackedCancel.__enter__ already set
# cancel_event before the generator body starts iterating. The
# top-of-loop check must short-circuit the very first iteration so
# no token is emitted. Catches a regression that moves the check
# below `next()` -- the mid-loop test would still pass but this
# test would observe one extra token leak.
cancel_event = threading.Event()
cancel_event.set()
reset_calls = [0]
class _Backend:
def reset_generation_state(self):
reset_calls[0] += 1
backend = _Backend()
next_calls = [0]
def _generate():
while True:
next_calls[0] += 1
yield f"cum-{next_calls[0]}"
async def _loop():
_DONE = object()
loop = asyncio.get_running_loop()
gen = _generate()
seen = []
while True:
if cancel_event.is_set():
backend.reset_generation_state()
break
cumulative = await loop.run_in_executor(None, next, gen, _DONE)
if cumulative is _DONE:
break
seen.append(cumulative)
return seen
seen = asyncio.run(_loop())
assert seen == [], (
f"loop must emit zero tokens when cancel_event is pre-set "
f"(pending-replay path); got {seen}"
)
assert next_calls[0] == 0, (
f"loop must not call next() at all on pre-set cancel; got "
f"{next_calls[0]} calls"
)
assert reset_calls[0] == 1, (
f"backend.reset_generation_state() must still fire exactly once "
f"on pre-set cancel; got {reset_calls[0]}"
)
def test_audio_stream_emits_zero_chunks_on_preset_cancel():
# Symmetric to the Unsloth pre-set test: the audio loop's top-of-loop
# cancel check must skip the asyncio.to_thread(next, ...) call when
# cancel_event was already set via pending-replay.
cancel_event = threading.Event()
cancel_event.set()
next_calls = [0]
def _audio_gen():
while True:
next_calls[0] += 1
yield f"chunk-{next_calls[0]}"
async def _loop():
_DONE = object()
gen = _audio_gen()
seen = []
while True:
if cancel_event.is_set():
break
chunk_text = await asyncio.to_thread(next, gen, _DONE)
if chunk_text is _DONE:
break
seen.append(chunk_text)
return seen
seen = asyncio.run(_loop())
assert seen == [], f"audio loop must emit zero chunks on pre-set cancel; got {seen}"
assert next_calls[0] == 0, (
f"audio loop must not call next() on pre-set cancel; got "
f"{next_calls[0]} calls"
)