diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 35ef90ee3..81ff9ae4e 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -53,6 +53,21 @@ _INTENT_SIGNAL = re.compile( r")" ) _MAX_REPROMPTS = 3 + +# Without max_tokens, llama-server defaults to n_predict = n_ctx (up to +# 262144 for Qwen3.5), producing many-minute zombie decodes when cancel +# fails. t_max_predict_ms is a wall-clock backstop applied unconditionally, +# but the llama.cpp README notes it ONLY fires after a newline has been +# generated -- a model stuck in a long unbroken non-newline sequence is +# unbounded by it. So we still want a token cap as the front-line limiter. +# +# The cap is the model's effective context length when we know it, +# falling back to a generous floor when metadata is unavailable. 4096 was +# too low: Qwen3 / gpt-oss reasoning traces routinely exceed it, and any +# OpenAI-API caller that omits max_tokens (langchain, llama-index, raw +# curl) sees responses silently truncated mid-sentence. +_DEFAULT_MAX_TOKENS_FLOOR = 32768 +_DEFAULT_T_MAX_PREDICT_MS = 600_000 # 10 min _REPROMPT_MAX_CHARS = 2000 # ── Pre-compiled patterns for GGUF shard detection ─────────── @@ -1636,7 +1651,7 @@ class LlamaCppBackend: # existing text (code refactoring, summarization, reasoning). # For general chat with low repetition, overhead is ~5 ms. # - # Benchmarks from llama.cpp PRs #18471, #19164: + # Benchmarks from upstream llama.cpp speculative-decoding PRs: # Scenario | Without | With | Speedup # gpt-oss-120b code refactor | 181 t/s | 446 t/s | 2.5x # Qwen3-235B offloaded | 12 t/s | 21 t/s | 1.8x @@ -2549,8 +2564,15 @@ class LlamaCppBackend: ) if _reasoning_kw is not None: payload["chat_template_kwargs"] = _reasoning_kw - if max_tokens is not None: - payload["max_tokens"] = max_tokens + # Default cap to the model's effective context length when known, + # otherwise the conservative floor. 
The wall-clock backstop below + # keeps a stuck model from running indefinitely either way. + payload["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) + ) + payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: payload["stop"] = stop payload["stream_options"] = {"include_usage": True} @@ -2570,7 +2592,9 @@ class LlamaCppBackend: _auth_headers = ( {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None ) - with httpx.Client(timeout = stream_timeout) as client: + with httpx.Client( + timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0) + ) as client: with self._stream_with_retry( client, url, @@ -2769,8 +2793,12 @@ class LlamaCppBackend: ) if _reasoning_kw is not None: payload["chat_template_kwargs"] = _reasoning_kw - if max_tokens is not None: - payload["max_tokens"] = max_tokens + payload["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) + ) + payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: payload["stop"] = stop @@ -2809,7 +2837,10 @@ class LlamaCppBackend: write = 10, pool = 10, ) - with httpx.Client(timeout = stream_timeout) as client: + with httpx.Client( + timeout = stream_timeout, + limits = httpx.Limits(max_keepalive_connections = 0), + ) as client: with self._stream_with_retry( client, url, @@ -3422,8 +3453,12 @@ class LlamaCppBackend: ) if _reasoning_kw is not None: stream_payload["chat_template_kwargs"] = _reasoning_kw - if max_tokens is not None: - stream_payload["max_tokens"] = max_tokens + stream_payload["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (self._effective_context_length or _DEFAULT_MAX_TOKENS_FLOOR) + ) + stream_payload["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: stream_payload["stop"] = stop stream_payload["stream_options"] = {"include_usage": True} @@ -3442,7 +3477,9 @@ class 
LlamaCppBackend: _auth_headers = ( {"Authorization": f"Bearer {self._api_key}"} if self._api_key else None ) - with httpx.Client(timeout = stream_timeout) as client: + with httpx.Client( + timeout = stream_timeout, limits = httpx.Limits(max_keepalive_connections = 0) + ) as client: with self._stream_with_retry( client, url, diff --git a/studio/backend/main.py b/studio/backend/main.py index 05adcaa2e..9212404b3 100644 --- a/studio/backend/main.py +++ b/studio/backend/main.py @@ -62,6 +62,7 @@ from routes import ( datasets_router, export_router, inference_router, + inference_studio_router, models_router, training_history_router, training_router, @@ -207,6 +208,9 @@ app.include_router(auth_router, prefix = "/api/auth", tags = ["auth"]) app.include_router(training_router, prefix = "/api/train", tags = ["training"]) app.include_router(models_router, prefix = "/api/models", tags = ["models"]) app.include_router(inference_router, prefix = "/api/inference", tags = ["inference"]) +# Studio-only inference endpoints (cancel, etc.) are intentionally NOT +# exposed on the /v1 OpenAI-compat prefix below. +app.include_router(inference_studio_router, prefix = "/api/inference", tags = ["inference"]) # OpenAI-compatible endpoints: mount the same inference router at /v1 # so external tools (Open WebUI, SillyTavern, etc.) can use the diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index e5b037755..bf0177efb 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -531,6 +531,10 @@ class ChatCompletionRequest(BaseModel): None, description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.", ) + cancel_id: Optional[str] = Field( + None, + description = "[x-unsloth] Per-request cancellation token. 
Frontend sends a fresh UUID per run so /inference/cancel matches one specific generation.", + ) # ── Streaming response chunks ──────────────────────────────────── @@ -992,6 +996,7 @@ class AnthropicMessagesRequest(BaseModel): enable_tools: Optional[bool] = None enabled_tools: Optional[list[str]] = None session_id: Optional[str] = None + cancel_id: Optional[str] = None model_config = {"extra": "allow"} diff --git a/studio/backend/routes/__init__.py b/studio/backend/routes/__init__.py index e79f6553f..cf4586281 100644 --- a/studio/backend/routes/__init__.py +++ b/studio/backend/routes/__init__.py @@ -8,6 +8,7 @@ API Routes from routes.training import router as training_router from routes.models import router as models_router from routes.inference import router as inference_router +from routes.inference import studio_router as inference_studio_router from routes.datasets import router as datasets_router from routes.auth import router as auth_router from routes.data_recipe import router as data_recipe_router @@ -18,6 +19,7 @@ __all__ = [ "training_router", "models_router", "inference_router", + "inference_studio_router", "datasets_router", "auth_router", "data_recipe_router", diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index ed331a566..cf3b37a2f 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -113,7 +113,12 @@ if str(backend_path) not in sys.path: # Import backend functions try: from core.inference import get_inference_backend - from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags + from core.inference.llama_cpp import ( + LlamaCppBackend, + _DEFAULT_MAX_TOKENS_FLOOR, + _DEFAULT_T_MAX_PREDICT_MS, + detect_reasoning_flags, + ) from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults @@ -122,7 +127,12 @@ except ImportError: if str(parent_backend) not in sys.path: 
sys.path.insert(0, str(parent_backend)) from core.inference import get_inference_backend - from core.inference.llama_cpp import LlamaCppBackend, detect_reasoning_flags + from core.inference.llama_cpp import ( + LlamaCppBackend, + _DEFAULT_MAX_TOKENS_FLOOR, + _DEFAULT_T_MAX_PREDICT_MS, + detect_reasoning_flags, + ) from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults @@ -185,6 +195,126 @@ import numpy as np from datetime import date as _date router = APIRouter() +# Studio-only router (not mounted on /v1 OpenAI-compat). +studio_router = APIRouter() + + +# Cancel registry. Proxies (e.g. Colab) can swallow client fetch aborts +# so is_disconnected() never fires. POST /inference/cancel looks up +# in-flight cancel_events here by cancel_id (per-run) or session_id / +# completion_id (fallbacks). +_CANCEL_REGISTRY: dict[str, set[threading.Event]] = {} +_CANCEL_LOCK = threading.Lock() + +# Cancel POSTs that arrive before registration are stashed; the next +# matching __enter__ replays set() within the TTL. +_PENDING_CANCELS: dict[str, float] = {} +_PENDING_CANCEL_TTL_S = 30.0 + + +def _prune_pending(now: float) -> None: + for k in [ + k for k, ts in _PENDING_CANCELS.items() if now - ts > _PENDING_CANCEL_TTL_S + ]: + _PENDING_CANCELS.pop(k, None) + + +class _TrackedCancel: + """Register cancel_event in _CANCEL_REGISTRY for the block's duration.""" + + def __init__(self, event: threading.Event, *keys): + self.event = event + self.keys = tuple(k for k in keys if k) + + def __enter__(self): + # Register + consume-pending must be one critical section to close + # the TOCTOU race against a concurrent cancel POST. 
+ should_cancel = False + with _CANCEL_LOCK: + for k in self.keys: + _CANCEL_REGISTRY.setdefault(k, set()).add(self.event) + now = time.monotonic() + _prune_pending(now) + for k in self.keys: + if k and _PENDING_CANCELS.pop(k, None) is not None: + should_cancel = True + if should_cancel: + self.event.set() + return self.event + + def __exit__(self, *exc): + with _CANCEL_LOCK: + for k in self.keys: + bucket = _CANCEL_REGISTRY.get(k) + if bucket is None: + continue + bucket.discard(self.event) + if not bucket: + _CANCEL_REGISTRY.pop(k, None) + return False + + +def _cancel_by_keys(keys) -> int: + """Set cancel_event for matching registry entries; no stash. + session_id/completion_id are shared across runs on the same thread, + so stashing them would ghost-cancel the user's next request. Only + cancel_id is per-run unique (see _cancel_by_cancel_id_or_stash).""" + if not keys: + return 0 + events: set[threading.Event] = set() + with _CANCEL_LOCK: + _prune_pending(time.monotonic()) + for k in keys: + bucket = _CANCEL_REGISTRY.get(k) + if bucket: + events.update(bucket) + for ev in events: + ev.set() + return len(events) + + +def _cancel_by_cancel_id_or_stash(cancel_id: str) -> int: + """Atomic lookup-or-stash; pairs with _TrackedCancel.__enter__ to + close the TOCTOU race.""" + now = time.monotonic() + events: set[threading.Event] = set() + with _CANCEL_LOCK: + _prune_pending(now) + bucket = _CANCEL_REGISTRY.get(cancel_id) + if bucket: + events.update(bucket) + else: + _PENDING_CANCELS[cancel_id] = now + for ev in events: + ev.set() + return len(events) + + +async def _await_cancel_then_close(cancel_event, resp) -> None: + """Watch a threading.Event from asyncio and close ``resp`` when it fires. + + Used by the passthrough streamers so a /cancel POST can interrupt + while the async iterator is blocked waiting for llama-server prefill. 
+ Without this watcher the in-loop ``cancel_event.is_set()`` check is + unreachable until the first SSE chunk arrives, which is exactly the + proxy/Colab scenario the cancel POST exists to handle. + + Polls a threading.Event because the cancel registry is keyed by + threading.Event so the synchronous /cancel handler can call .set(). + 50ms cadence adds at most that much latency to a prefill cancel; the + common-case streaming cancel path still observes the event in the + iterator's first iteration after the next chunk. + """ + try: + while not cancel_event.is_set(): + await asyncio.sleep(0.05) + try: + await resp.aclose() + except Exception: + pass + except asyncio.CancelledError: + return + # Appended to tool-use nudge to discourage plan-without-action _TOOL_ACTION_NUDGE = ( @@ -706,6 +836,48 @@ async def unload_model( raise HTTPException(status_code = 500, detail = f"Failed to unload model: {str(e)}") +@studio_router.post("/cancel") +async def cancel_inference( + request: Request, + current_subject: str = Depends(get_current_subject), +): + """Cancel in-flight inference requests. + + Body (JSON, at least one key required): + cancel_id - preferred: per-run UUID, matched exclusively. + session_id - fallback when cancel_id is absent. + completion_id - fallback when cancel_id is absent. + + A cancel_id arriving before its stream registers is stashed briefly + and replayed on registration. Returns {"cancelled": N}. + """ + try: + body = await request.json() + if not isinstance(body, dict): + body = {} + except Exception as e: + logger.debug("Failed to parse cancel request body: %s", e) + body = {} + + cancel_id = body.get("cancel_id") + if isinstance(cancel_id, str) and cancel_id: + return {"cancelled": _cancel_by_cancel_id_or_stash(cancel_id)} + + keys = [] + # `message_id` is the Anthropic passthrough's per-run identifier -- + # included so /v1/messages clients can cancel by their native id. 
+ for k in ("completion_id", "session_id", "message_id"): + v = body.get(k) + if isinstance(v, str) and v: + keys.append(v) + + if not keys: + return {"cancelled": 0} + + n = _cancel_by_keys(keys) + return {"cancelled": n} + + @router.post("/generate/stream") async def generate_stream( request: GenerateRequest, @@ -1169,6 +1341,9 @@ async def openai_chat_completions( ) if payload.stream: + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() async def audio_input_stream(): try: @@ -1185,10 +1360,17 @@ async def openai_chat_completions( ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" - for chunk_text in audio_input_generate(): + gen = audio_input_generate() + _DONE = object() + while True: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() return + chunk_text = await asyncio.to_thread(next, gen, _DONE) + if chunk_text is _DONE: + break if chunk_text: chunk = ChatCompletionChunk( id = completion_id, @@ -1221,6 +1403,8 @@ async def openai_chat_completions( f"Error during audio input streaming: {e}", exc_info = True ) yield f"data: {json.dumps({'error': {'message': _friendly_error(e), 'type': 'server_error'}})}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( audio_input_stream(), @@ -1466,6 +1650,10 @@ async def openai_chat_completions( _tool_sentinel = object() + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() + async def gguf_tool_stream(): try: first_chunk = ChatCompletionChunk( @@ -1488,6 +1676,8 @@ async def openai_chat_completions( _stream_usage = None _stream_timings = None while True: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() return @@ -1595,6 +1785,8 @@ async def openai_chat_completions( }, } yield f"data: 
{json.dumps(error_chunk)}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( gguf_tool_stream(), @@ -1628,6 +1820,9 @@ async def openai_chat_completions( _gguf_sentinel = object() if payload.stream: + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() async def gguf_stream_chunks(): try: @@ -1652,6 +1847,8 @@ async def openai_chat_completions( _stream_usage = None _stream_timings = None while True: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() return @@ -1735,6 +1932,8 @@ async def openai_chat_completions( }, } yield f"data: {json.dumps(error_chunk)}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( gguf_stream_chunks(), @@ -1834,6 +2033,9 @@ async def openai_chat_completions( # ── Streaming response ──────────────────────────────────────── if payload.stream: + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() async def stream_chunks(): try: @@ -1861,6 +2063,9 @@ async def openai_chat_completions( loop = asyncio.get_event_loop() gen = generate() while True: + if cancel_event.is_set(): + backend.reset_generation_state() + break # next(gen, _DONE) returns _DONE instead of raising # StopIteration — StopIteration cannot propagate # through asyncio futures (Python limitation). 
@@ -1916,6 +2121,8 @@ async def openai_chat_completions( }, } yield f"data: {json.dumps(error_chunk)}\n\n" + finally: + _tracker.__exit__(None, None, None) return StreamingResponse( stream_chunks(), @@ -2596,7 +2803,9 @@ async def _responses_stream( ), ) - body = _build_openai_passthrough_body(chat_req) + body = _build_openai_passthrough_body( + chat_req, backend_ctx = llama_backend.context_length + ) target_url = f"{llama_backend.base_url}/v1/chat/completions" async def event_generator(): @@ -3081,6 +3290,8 @@ async def anthropic_messages( repetition_penalty = repetition_penalty, presence_penalty = presence_penalty, tool_choice = openai_tool_choice, + session_id = payload.session_id, + cancel_id = payload.cancel_id, ) return await _anthropic_passthrough_non_streaming( llama_backend, @@ -3441,6 +3652,7 @@ def _build_passthrough_payload( repetition_penalty = None, presence_penalty = None, tool_choice = "auto", + backend_ctx = None, ): body = { "messages": openai_messages, @@ -3453,8 +3665,12 @@ def _build_passthrough_payload( } if stream: body["stream_options"] = {"include_usage": True} - if max_tokens is not None: - body["max_tokens"] = max_tokens + body["max_tokens"] = ( + max_tokens + if max_tokens is not None + else (backend_ctx or _DEFAULT_MAX_TOKENS_FLOOR) + ) + body["t_max_predict_ms"] = _DEFAULT_T_MAX_PREDICT_MS if stop: body["stop"] = stop if min_p is not None: @@ -3484,6 +3700,8 @@ async def _anthropic_passthrough_stream( repetition_penalty = None, presence_penalty = None, tool_choice = "auto", + session_id = None, + cancel_id = None, ): """Streaming client-side pass-through: forward tools to llama-server and translate its streaming response to Anthropic SSE without executing anything.""" @@ -3501,8 +3719,14 @@ async def _anthropic_passthrough_stream( repetition_penalty = repetition_penalty, presence_penalty = presence_penalty, tool_choice = tool_choice, + backend_ctx = llama_backend.context_length, ) + # cancel_id mirrors the OpenAI passthrough so a 
per-run cancel POST + # works without the caller having to know the local message_id. + _tracker = _TrackedCancel(cancel_event, cancel_id, session_id, message_id) + _tracker.__enter__() + async def _stream(): emitter = AnthropicPassthroughEmitter() for line in emitter.start(message_id, model_name): @@ -3535,15 +3759,28 @@ async def _anthropic_passthrough_stream( # has anything orphaned to finalize. Each aclose is wrapped in # `try: ... except Exception: pass` so anyio cleanup noise from # nested aclose paths can't bubble out. - client = httpx.AsyncClient(timeout = 600) + client = httpx.AsyncClient( + timeout = 600, + limits = httpx.Limits(max_keepalive_connections = 0), + ) resp = None lines_iter = None + cancel_watcher = None try: req = client.build_request("POST", target_url, json = body) resp = await client.send(req, stream = True) + # See _openai_passthrough_stream for rationale: aiter_lines() + # blocks during llama-server prefill, so the in-loop cancel + # check is unreachable until the first SSE chunk arrives. + # The watcher closes `resp` on cancel, raising in aiter_lines. 
+ cancel_watcher = asyncio.create_task( + _await_cancel_then_close(cancel_event, resp) + ) lines_iter = resp.aiter_lines() async for raw_line in lines_iter: + if cancel_event.is_set(): + break if await request.is_disconnected(): cancel_event.set() break @@ -3558,9 +3795,18 @@ async def _anthropic_passthrough_stream( continue for line in emitter.feed_chunk(chunk): yield line + except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError): + if not cancel_event.is_set(): + raise except Exception as e: logger.error("anthropic_messages passthrough stream error: %s", e) finally: + if cancel_watcher is not None: + cancel_watcher.cancel() + try: + await cancel_watcher + except (asyncio.CancelledError, Exception): + pass if lines_iter is not None: try: await lines_iter.aclose() @@ -3575,6 +3821,7 @@ async def _anthropic_passthrough_stream( await client.aclose() except Exception: pass + _tracker.__exit__(None, None, None) for line in emitter.finish(): yield line @@ -3621,6 +3868,7 @@ async def _anthropic_passthrough_non_streaming( repetition_penalty = repetition_penalty, presence_penalty = presence_penalty, tool_choice = tool_choice, + backend_ctx = llama_backend.context_length, ) async with httpx.AsyncClient() as client: @@ -3742,7 +3990,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]: return messages -def _build_openai_passthrough_body(payload) -> dict: +def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict: """Assemble the llama-server request body from a ChatCompletionRequest. Only explicitly-known OpenAI / llama-server fields are forwarded so that @@ -3764,6 +4012,7 @@ def _build_openai_passthrough_body(payload) -> dict: repetition_penalty = payload.repetition_penalty, presence_penalty = payload.presence_penalty, tool_choice = tool_choice, + backend_ctx = backend_ctx, ) @@ -3784,103 +4033,56 @@ async def _openai_passthrough_stream( observes a standard OpenAI response. 
""" target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_openai_passthrough_body(payload) + body = _build_openai_passthrough_body( + payload, backend_ctx = llama_backend.context_length + ) - # Dispatch the upstream request BEFORE returning StreamingResponse so - # transport errors and non-200 upstream statuses surface as real HTTP - # errors to the client. OpenAI SDKs rely on status codes to raise - # ``APIError``/``BadRequestError``/...; burying the failure inside a - # 200 SSE ``error`` frame silently breaks their error handling. - client = httpx.AsyncClient(timeout = 600) - resp = None + _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) + _tracker = _TrackedCancel(cancel_event, *_cancel_keys) + _tracker.__enter__() + + # Outer guard: asyncio.CancelledError at `await client.send(...)` is + # a BaseException that bypasses `except httpx.RequestError`; without + # this the tracker leaks. The generator's finally only runs once + # iteration starts. try: - req = client.build_request("POST", target_url, json = body) - resp = await client.send(req, stream = True) - except httpx.RequestError as e: - # llama-server subprocess crashed / still starting / unreachable. - logger.error("openai passthrough stream: upstream unreachable: %s", e) - if resp is not None: - try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - raise HTTPException( - status_code = 502, - detail = _friendly_error(e), + # Dispatch BEFORE returning StreamingResponse so transport errors + # and non-200 upstream statuses surface as real HTTP errors -- + # OpenAI SDKs rely on status codes to raise APIError/BadRequestError. 
+ client = httpx.AsyncClient( + timeout = 600, + limits = httpx.Limits(max_keepalive_connections = 0), ) - - if resp.status_code != 200: - err_bytes = await resp.aread() - err_text = err_bytes.decode("utf-8", errors = "replace") - logger.error( - "openai passthrough upstream error: status=%s body=%s", - resp.status_code, - err_text[:500], - ) - upstream_status = resp.status_code + resp = None try: - await resp.aclose() - except Exception: - pass - try: - await client.aclose() - except Exception: - pass - raise HTTPException( - status_code = upstream_status, - detail = f"llama-server error: {err_text[:500]}", - ) - - async def _stream(): - # Same httpx lifecycle pattern as _anthropic_passthrough_stream: - # avoid `async with` on the client/response AND explicitly save - # resp.aiter_lines() so we can close it ourselves in the finally - # block. See the long comment there for the full rationale on - # why the anonymous `async for raw_line in resp.aiter_lines():` - # pattern leaks an unclosed async generator that Python's - # asyncgen GC hook then finalizes in a different asyncio task, - # producing "Exception ignored in:" / "async generator ignored - # GeneratorExit" / anyio cancel-scope traces on Python 3.13 + - # httpcore 1.0.x. - lines_iter = None - try: - lines_iter = resp.aiter_lines() - async for raw_line in lines_iter: - if await request.is_disconnected(): - cancel_event.set() - break - if not raw_line: - continue - if not raw_line.startswith("data: "): - continue - # Relay the llama-server SSE chunk verbatim so the client - # sees its native `id`, `finish_reason`, `delta.tool_calls`, - # and final `usage` unchanged. - yield raw_line + "\n\n" - if raw_line[6:].strip() == "[DONE]": - break - except Exception as e: - # Mid-stream failures still have to be reported inside the SSE - # body because the 200 response headers have already been - # committed by the time the first chunk flushes. 
- logger.error("openai passthrough stream error: %s", e) - err = { - "error": { - "message": _friendly_error(e), - "type": "server_error", - }, - } - yield f"data: {json.dumps(err)}\n\n" - finally: - if lines_iter is not None: + req = client.build_request("POST", target_url, json = body) + resp = await client.send(req, stream = True) + except httpx.RequestError as e: + # llama-server subprocess crashed / still starting / unreachable. + logger.error("openai passthrough stream: upstream unreachable: %s", e) + if resp is not None: try: - await lines_iter.aclose() + await resp.aclose() except Exception: pass + try: + await client.aclose() + except Exception: + pass + raise HTTPException( + status_code = 502, + detail = _friendly_error(e), + ) + + if resp.status_code != 200: + err_bytes = await resp.aread() + err_text = err_bytes.decode("utf-8", errors = "replace") + logger.error( + "openai passthrough upstream error: status=%s body=%s", + resp.status_code, + err_text[:500], + ) + upstream_status = resp.status_code try: await resp.aclose() except Exception: @@ -3889,16 +4091,91 @@ async def _openai_passthrough_stream( await client.aclose() except Exception: pass + raise HTTPException( + status_code = upstream_status, + detail = f"llama-server error: {err_text[:500]}", + ) - return StreamingResponse( - _stream(), - media_type = "text/event-stream", - headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) + async def _stream(): + # Same httpx lifecycle pattern as _anthropic_passthrough_stream: + # save resp.aiter_lines() so the finally block can aclose() it + # on our task. See that function for full rationale. + lines_iter = None + # During llama-server prefill, `aiter_lines()` blocks until the + # first SSE chunk arrives. The in-loop `cancel_event` check + # cannot fire until then, which is the exact proxy/Colab + # scenario the cancel POST is meant to recover from. 
Run a + # tiny watcher that closes `resp` as soon as cancel fires, + # unblocking the iterator with a RemoteProtocolError caught + # in the except clause below. + cancel_watcher = asyncio.create_task( + _await_cancel_then_close(cancel_event, resp) + ) + try: + lines_iter = resp.aiter_lines() + async for raw_line in lines_iter: + if cancel_event.is_set(): + break + if await request.is_disconnected(): + cancel_event.set() + break + if not raw_line: + continue + if not raw_line.startswith("data: "): + continue + # Relay verbatim to preserve llama-server's native id, + # finish_reason, delta.tool_calls, and usage chunks. + yield raw_line + "\n\n" + if raw_line[6:].strip() == "[DONE]": + break + except (httpx.RemoteProtocolError, httpx.ReadError, httpx.CloseError): + # Watcher closed resp on cancel. Emit nothing extra; the + # client either initiated the cancel or already disconnected. + if not cancel_event.is_set(): + raise + except Exception as e: + # 200 headers are already flushed; errors must be in the SSE body. + logger.error("openai passthrough stream error: %s", e) + err = { + "error": { + "message": _friendly_error(e), + "type": "server_error", + }, + } + yield f"data: {json.dumps(err)}\n\n" + finally: + cancel_watcher.cancel() + try: + await cancel_watcher + except (asyncio.CancelledError, Exception): + pass + if lines_iter is not None: + try: + await lines_iter.aclose() + except Exception: + pass + try: + await resp.aclose() + except Exception: + pass + try: + await client.aclose() + except Exception: + pass + _tracker.__exit__(None, None, None) + + return StreamingResponse( + _stream(), + media_type = "text/event-stream", + headers = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + except BaseException: + _tracker.__exit__(None, None, None) + raise async def _openai_passthrough_non_streaming( @@ -3914,7 +4191,9 @@ async def _openai_passthrough_non_streaming( token counts. 
""" target_url = f"{llama_backend.base_url}/v1/chat/completions" - body = _build_openai_passthrough_body(payload) + body = _build_openai_passthrough_body( + payload, backend_ctx = llama_backend.context_length + ) try: async with httpx.AsyncClient() as client: diff --git a/studio/frontend/src/features/chat/api/chat-adapter.ts b/studio/frontend/src/features/chat/api/chat-adapter.ts index a93120397..dde597ca1 100644 --- a/studio/frontend/src/features/chat/api/chat-adapter.ts +++ b/studio/frontend/src/features/chat/api/chat-adapter.ts @@ -4,6 +4,8 @@ import type { ChatModelAdapter } from "@assistant-ui/react"; import type { MessageTiming, ToolCallMessagePart } from "@assistant-ui/core"; import { toast } from "sonner"; +import { getAuthToken } from "@/features/auth/session"; +import { apiUrl } from "@/lib/api-base"; import { generateAudio, listCachedGguf, @@ -707,7 +709,42 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { const toolCallParts: ToolCallMessagePart[] = []; let serverMetadata: { usage?: ServerUsage; timings?: ServerTimings } | null = null; + // Per-run cancellation token so a delayed stop POST cannot match + // the next run on the same thread. + const cancelId = + typeof crypto !== "undefined" && "randomUUID" in crypto + ? crypto.randomUUID() + : `${Date.now()}-${Math.random().toString(36).slice(2)}`; + + // Colab-style proxies can swallow fetch aborts, so also POST + // /inference/cancel explicitly on abort. + const onAbortCancel = () => { + const body: Record = { cancel_id: cancelId }; + if (resolvedThreadId) body.session_id = resolvedThreadId; + // Plain fetch, not authFetch: authFetch redirects to login on + // 401, which would kick the user out mid-stop. + const token = getAuthToken(); + // Use apiUrl so the cancel POST reaches the right origin in + // Tauri production builds (where the webview origin is not the + // backend at 127.0.0.1:). Browser/dev builds get the empty + // base, so the path is unchanged there. 
+ void fetch(apiUrl("/api/inference/cancel"), { + method: "POST", + headers: { + "Content-Type": "application/json", + ...(token ? { Authorization: `Bearer ${token}` } : {}), + }, + body: JSON.stringify(body), + keepalive: true, + }).catch(() => {}); + }; try { + if (abortSignal.aborted) { + onAbortCancel(); + } else { + abortSignal.addEventListener("abort", onAbortCancel, { once: true }); + } + const { supportsReasoning, reasoningEnabled, @@ -730,6 +767,8 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { presence_penalty: params.presencePenalty, image_base64: imageBase64, audio_base64: audioBase64, + cancel_id: cancelId, + ...(resolvedThreadId ? { session_id: resolvedThreadId } : {}), ...(useAdapter === undefined ? {} : { use_adapter: useAdapter }), ...(supportsReasoning ? reasoningStyle === "reasoning_effort" @@ -750,7 +789,6 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { const mins = useChatRuntimeStore.getState().toolCallTimeout; return mins >= 9999 ? 
9999 : mins * 60; })(), - session_id: resolvedThreadId, } : {}), }, @@ -948,6 +986,7 @@ export function createOpenAIStreamAdapter(): ChatModelAdapter { } throw err; } finally { + abortSignal.removeEventListener("abort", onAbortCancel); runtime.setGeneratingStatus(null); runtime.setToolStatus(null); clearTimeout(warmupTimer); diff --git a/studio/frontend/src/features/chat/types/api.ts b/studio/frontend/src/features/chat/types/api.ts index ccfc8b2bf..25957f4a7 100644 --- a/studio/frontend/src/features/chat/types/api.ts +++ b/studio/frontend/src/features/chat/types/api.ts @@ -179,6 +179,7 @@ export interface OpenAIChatCompletionsRequest { max_tool_calls_per_message?: number; tool_call_timeout?: number; session_id?: string; + cancel_id?: string; } export interface OpenAIChatDelta { diff --git a/tests/studio/test_cancel_atomicity.py b/tests/studio/test_cancel_atomicity.py new file mode 100644 index 000000000..a8d483945 --- /dev/null +++ b/tests/studio/test_cancel_atomicity.py @@ -0,0 +1,289 @@ +""" +TOCTOU atomicity guards for the cancel path. + +Structural: cancel_inference, _cancel_by_cancel_id_or_stash, and +_TrackedCancel.__enter__ must each use a single _CANCEL_LOCK critical +section over lookup + stash / register + consume-pending. + +Behavioral: parallel cancel-POST vs __enter__ must never drop a cancel. 
+""" + +from __future__ import annotations + +import ast +import random +import threading +from pathlib import Path + + +SOURCE_PATH = ( + Path(__file__).resolve().parents[2] + / "studio" + / "backend" + / "routes" + / "inference.py" +) +_SRC = SOURCE_PATH.read_text() +_TREE = ast.parse(_SRC) + + +def _find_function(name: str) -> ast.FunctionDef | ast.AsyncFunctionDef: + for node in ast.walk(_TREE): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == name + ): + return node + raise AssertionError(f"function {name!r} not found") + + +def _find_class(name: str) -> ast.ClassDef: + for node in ast.walk(_TREE): + if isinstance(node, ast.ClassDef) and node.name == name: + return node + raise AssertionError(f"class {name!r} not found") + + +def _count_with_cancel_lock_blocks(node: ast.AST) -> int: + n = 0 + for sub in ast.walk(node): + if not isinstance(sub, ast.With): + continue + for item in sub.items: + ctx = item.context_expr + if isinstance(ctx, ast.Name) and ctx.id == "_CANCEL_LOCK": + n += 1 + break + return n + + +def test_cancel_by_cancel_id_or_stash_is_single_lock_critical_section(): + fn = _find_function("_cancel_by_cancel_id_or_stash") + assert _count_with_cancel_lock_blocks(fn) == 1, ( + "_cancel_by_cancel_id_or_stash must use exactly one `with " + "_CANCEL_LOCK:` block; splitting into two acquisitions reopens " + "the TOCTOU race with _TrackedCancel.__enter__" + ) + src = ast.unparse(fn) + assert "_CANCEL_REGISTRY.get(cancel_id)" in src + assert "_PENDING_CANCELS[cancel_id]" in src + + +def test_tracked_cancel_enter_registers_and_consumes_pending_under_one_lock(): + cls = _find_class("_TrackedCancel") + enter = None + for n in cls.body: + if isinstance(n, ast.FunctionDef) and n.name == "__enter__": + enter = n + break + assert enter is not None + assert _count_with_cancel_lock_blocks(enter) == 1, ( + "_TrackedCancel.__enter__ must acquire _CANCEL_LOCK exactly once. 
" + "A second acquisition for consume-pending lets a concurrent " + "cancel POST stash after consume sees an empty map, silently " + "dropping the cancel" + ) + with_block = None + for sub in ast.walk(enter): + if isinstance(sub, ast.With) and any( + isinstance(i.context_expr, ast.Name) and i.context_expr.id == "_CANCEL_LOCK" + for i in sub.items + ): + with_block = sub + break + assert with_block is not None + block_src = "\n".join(ast.unparse(s) for s in with_block.body) + assert "_CANCEL_REGISTRY.setdefault" in block_src + assert "_PENDING_CANCELS.pop" in block_src, ( + "__enter__ critical section must consume from _PENDING_CANCELS " + "inside the same lock, not a later re-acquisition" + ) + + +def test_cancel_inference_uses_atomic_helper_for_cancel_id_path(): + fn = _find_function("cancel_inference") + src = ast.unparse(fn) + assert "_cancel_by_cancel_id_or_stash" in src + # The pre-fix two-step idiom must be gone. + assert "_remember_pending_cancel(cancel_id)" not in src, ( + "two-step _cancel_by_keys + _remember_pending_cancel produced " + "the TOCTOU race and must not return" + ) + + +_WANTED = { + "_CANCEL_REGISTRY", + "_CANCEL_LOCK", + "_PENDING_CANCELS", + "_PENDING_CANCEL_TTL_S", + "_prune_pending", + "_remember_pending_cancel", + "_TrackedCancel", + "_cancel_by_keys", + "_cancel_by_cancel_id_or_stash", +} + + +def _load_registry_module(): + chunks = [] + for n in _TREE.body: + seg = ast.get_source_segment(_SRC, n) + if seg is None: + continue + if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED: + chunks.append(seg) + elif isinstance(n, ast.Assign): + names = [t.id for t in n.targets if isinstance(t, ast.Name)] + if any(name in _WANTED for name in names): + chunks.append(seg) + elif ( + isinstance(n, ast.AnnAssign) + and isinstance(n.target, ast.Name) + and n.target.id in _WANTED + ): + chunks.append(seg) + mod = {} + exec( + "import threading, time\nfrom typing import Optional\n" + "\n\n".join(chunks), + mod, + ) + return mod + + 
+def test_parallel_cancel_vs_register_never_drops(): + m = _load_registry_module() + trials = 500 + dropped = 0 + for i in range(trials): + m["_CANCEL_REGISTRY"].clear() + m["_PENDING_CANCELS"].clear() + cid = f"cid-{i}" + ev = threading.Event() + tracker = m["_TrackedCancel"](ev, cid, "thread") + start = threading.Event() + + def do_cancel(): + start.wait() + m["_cancel_by_cancel_id_or_stash"](cid) + + def do_enter(): + start.wait() + tracker.__enter__() + + threads = [ + threading.Thread(target = do_cancel), + threading.Thread(target = do_enter), + ] + random.shuffle(threads) + for t in threads: + t.start() + start.set() + for t in threads: + t.join(timeout = 5.0) + assert not t.is_alive() + + if not ev.is_set(): + dropped += 1 + tracker.__exit__(None, None, None) + + assert dropped == 0, ( + f"TOCTOU regression: {dropped}/{trials} parallel trials silently " + f"dropped the cancel" + ) + + +def test_cancel_before_register_replays_atomically(): + m = _load_registry_module() + cid = "early-cid" + ev = threading.Event() + tracker = m["_TrackedCancel"](ev, cid, "thread-x") + + assert m["_cancel_by_cancel_id_or_stash"](cid) == 0 + assert cid in m["_PENDING_CANCELS"] + + tracker.__enter__() + assert ev.is_set() + assert cid not in m["_PENDING_CANCELS"] + tracker.__exit__(None, None, None) + + +def test_cancel_after_register_signals_without_stash(): + m = _load_registry_module() + cid = "post-cid" + ev = threading.Event() + tracker = m["_TrackedCancel"](ev, cid, "thread-y") + tracker.__enter__() + + assert m["_cancel_by_cancel_id_or_stash"](cid) == 1 + assert ev.is_set() + assert cid not in m["_PENDING_CANCELS"] + tracker.__exit__(None, None, None) + + +def test_cancel_by_keys_tolerates_empty_and_falsy_keys(): + m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + m["_PENDING_CANCELS"].clear() + assert m["_cancel_by_keys"]([]) == 0 + assert m["_cancel_by_keys"](["", None, "unknown"]) == 0 + # Non-stashing fallback must never leak into _PENDING_CANCELS. 
+ assert m["_PENDING_CANCELS"] == {} + + +def test_cancel_by_keys_fans_out_to_all_streams_on_same_session(): + # Compare mode and other flows launch concurrent streams under a + # shared session_id; a single session cancel POST must hit all of them. + m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + m["_PENDING_CANCELS"].clear() + session = "shared-thread" + ev_a = threading.Event() + ev_b = threading.Event() + tracker_a = m["_TrackedCancel"](ev_a, "cancel-a", session, "chatcmpl-a") + tracker_b = m["_TrackedCancel"](ev_b, "cancel-b", session, "chatcmpl-b") + tracker_a.__enter__() + tracker_b.__enter__() + try: + assert m["_cancel_by_keys"]([session]) == 2 + assert ev_a.is_set() and ev_b.is_set() + finally: + tracker_a.__exit__(None, None, None) + tracker_b.__exit__(None, None, None) + assert session not in m["_CANCEL_REGISTRY"] + + +def test_cancel_by_cancel_id_is_exclusive_to_single_run(): + # cancel_id is per-run unique; cancelling run A must not touch run B + # even when both share a session_id. + m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + m["_PENDING_CANCELS"].clear() + session = "shared-thread-2" + ev_a = threading.Event() + ev_b = threading.Event() + tracker_a = m["_TrackedCancel"](ev_a, "cancel-only-a", session, "chatcmpl-a") + tracker_b = m["_TrackedCancel"](ev_b, "cancel-only-b", session, "chatcmpl-b") + tracker_a.__enter__() + tracker_b.__enter__() + try: + assert m["_cancel_by_cancel_id_or_stash"]("cancel-only-a") == 1 + assert ev_a.is_set() + assert not ev_b.is_set() + finally: + tracker_a.__exit__(None, None, None) + tracker_b.__exit__(None, None, None) + + +def test_tracked_cancel_exit_is_idempotent(): + # Outer except BaseException + the generator's finally may both call + # __exit__ under certain race combos; must not raise. 
+ m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + m["_PENDING_CANCELS"].clear() + ev = threading.Event() + tracker = m["_TrackedCancel"](ev, "cid", "sess", "chatcmpl-x") + tracker.__enter__() + tracker.__exit__(None, None, None) + tracker.__exit__(None, None, None) + tracker.__exit__(None, None, None) + assert not m["_CANCEL_REGISTRY"] diff --git a/tests/studio/test_cancel_id_wiring.py b/tests/studio/test_cancel_id_wiring.py new file mode 100644 index 000000000..5fd76ded9 --- /dev/null +++ b/tests/studio/test_cancel_id_wiring.py @@ -0,0 +1,169 @@ +""" +Wiring tests for the per-run cancel_id field. + +A chat-thread-scoped session_id is not safe as a cancel key because a +late stop POST can match a subsequent run on the same thread. The fix +adds cancel_id (a fresh UUID per generation) that is sent both in the +completion payload and in the /api/inference/cancel body. + +Verifies: + - ChatCompletionRequest exposes an Optional[str] `cancel_id` field. + - /api/inference/cancel accepts `cancel_id` as the first-preferred key. + - OpenAIChatCompletionsRequest (frontend type) includes cancel_id. + - chat-adapter.ts generates a per-run cancelId (crypto.randomUUID + with a Math.random fallback), sends it in the completion payload, + and includes it in the /inference/cancel body on abort. 
+""" + +from __future__ import annotations + +import ast +import re +from pathlib import Path + + +WORKSPACE = Path(__file__).resolve().parents[2] +MODELS_SRC = (WORKSPACE / "studio/backend/models/inference.py").read_text() +ROUTES_SRC = (WORKSPACE / "studio/backend/routes/inference.py").read_text() +ADAPTER_SRC = ( + WORKSPACE / "studio/frontend/src/features/chat/api/chat-adapter.ts" +).read_text() +API_TYPES_SRC = ( + WORKSPACE / "studio/frontend/src/features/chat/types/api.ts" +).read_text() + + +def _find_class(tree: ast.AST, name: str) -> ast.ClassDef | None: + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == name: + return node + return None + + +def test_chat_completion_request_has_cancel_id_field(): + tree = ast.parse(MODELS_SRC) + cls = _find_class(tree, "ChatCompletionRequest") + assert cls is not None + fields = { + n.target.id + for n in cls.body + if isinstance(n, ast.AnnAssign) and isinstance(n.target, ast.Name) + } + assert "cancel_id" in fields, ( + "ChatCompletionRequest must expose a cancel_id field for per-run " + "cancellation routing" + ) + + +def test_cancel_route_matches_cancel_id_exclusively_when_present(): + # A stale cancel POST carrying cancel_id AND session_id must not + # cancel a later run on the same thread via the shared session_id. + # Enforce this by requiring the handler to early-return through an + # exclusive-cancel_id path -- either an atomic helper or a keys + # list containing ONLY cancel_id (never session_id). 
+ for node in ast.walk(ast.parse(ROUTES_SRC)): + if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference": + break + else: + raise AssertionError("cancel_inference handler missing") + + cancel_id_exclusive_branch = False + for sub in ast.walk(node): + if not isinstance(sub, ast.If): + continue + test_src = ast.unparse(sub.test) + if "cancel_id" not in test_src or "isinstance" not in test_src: + continue + branch_src = "\n".join(ast.unparse(s) for s in sub.body) + before_return = branch_src.split("return", 1)[0] + matches_cancel_id_only = ( + "_cancel_by_cancel_id_or_stash(cancel_id)" in branch_src + or "_cancel_by_keys([cancel_id])" in branch_src + ) + if matches_cancel_id_only and "session_id" not in before_return: + cancel_id_exclusive_branch = True + break + assert cancel_id_exclusive_branch, ( + "cancel_inference must early-return with an exclusive cancel_id " + "match when a cancel_id is supplied, so a stale stop POST " + "cannot cancel a later run on the same thread via session_id" + ) + + +def test_cancel_route_falls_back_to_session_or_completion_when_no_cancel_id(): + for node in ast.walk(ast.parse(ROUTES_SRC)): + if isinstance(node, ast.AsyncFunctionDef) and node.name == "cancel_inference": + break + else: + raise AssertionError("cancel_inference handler missing") + + src = ast.unparse(node) + assert "session_id" in src and "completion_id" in src, ( + "cancel_inference must still accept session_id / completion_id as " + "fallback keys when cancel_id is absent" + ) + + +def test_frontend_request_type_has_cancel_id(): + assert re.search( + r"cancel_id\?\s*:\s*string\s*;", API_TYPES_SRC + ), "OpenAIChatCompletionsRequest must expose an optional cancel_id" + + +def test_chat_adapter_generates_cancel_id_per_run(): + m = re.search( + r"const\s+cancelId\s*=\s*([^;]+);", + ADAPTER_SRC, + ) + assert m, "chat-adapter.ts must declare a per-run `cancelId` constant" + rhs = m.group(1) + assert ( + "randomUUID" in rhs + ), "cancelId should prefer 
crypto.randomUUID() for uniqueness" + + +def test_chat_adapter_sends_cancel_id_in_completion_payload(): + assert "cancel_id: cancelId" in ADAPTER_SRC, ( + "chat-adapter.ts must include cancel_id in the streamChatCompletions " + "request payload so the backend registers under that key" + ) + + +def test_chat_adapter_sends_cancel_id_in_abort_cancel_post(): + m = re.search( + r"const\s+onAbortCancel\s*=\s*\(\)\s*=>\s*\{(.*?)\};", + ADAPTER_SRC, + flags = re.DOTALL, + ) + assert m, "onAbortCancel arrow function missing" + body = m.group(1) + assert re.search(r"cancel_id\s*:\s*cancelId", body), ( + "onAbortCancel must include cancel_id in the /inference/cancel body " + "so a stop POST matches the specific run, not the whole thread" + ) + + +def test_abort_cancel_post_uses_plain_fetch_with_manual_auth_header(): + # authFetch redirects to login on 401, which would kick the user to + # the login page mid-stop if the access token expired during a long + # stream. Use plain fetch + manual Authorization header for a + # best-effort cancel that never triggers the refresh/redirect flow. 
+ start = ADAPTER_SRC.find("const onAbortCancel") + assert start >= 0, "onAbortCancel handler missing" + rest = ADAPTER_SRC[start:] + end = rest.find("\n try {") + body = rest if end < 0 else rest[:end] + assert "/api/inference/cancel" in body + assert "authFetch(" not in body, ( + "onAbortCancel must NOT call authFetch; a 401 from it would " + "redirect the user to the login page during a stop click" + ) + assert "fetch(" in body, "onAbortCancel must use plain fetch(...)" + assert "getAuthToken" in body, ( + "onAbortCancel must read the bearer token via getAuthToken() " + "rather than relying on authFetch's 401 flow" + ) + assert "Authorization" in body + assert ( + "keepalive: true" in body + ), "keepalive is required so the fetch survives page unload during stop" diff --git a/tests/studio/test_llama_cpp_wall_clock_cap.py b/tests/studio/test_llama_cpp_wall_clock_cap.py new file mode 100644 index 000000000..671abea82 --- /dev/null +++ b/tests/studio/test_llama_cpp_wall_clock_cap.py @@ -0,0 +1,123 @@ +""" +Tests for the llama-server wall-clock cap (t_max_predict_ms). + +The UI always sends max_tokens = context_length, so gating +t_max_predict_ms on `max_tokens is None` makes the safety net dead +code. The fix applies the wall-clock cap unconditionally on all three +streaming payload sites and raises the default to 10 minutes so slow +CPU / macOS / Windows installs are not cut off mid-generation. + +Verifies: + - t_max_predict_ms is assigned unconditionally at the three + payload-builder sites (not inside an `if max_tokens is None` else + branch). + - _DEFAULT_T_MAX_PREDICT_MS is at least 10 minutes (previously + 120_000). + - The default max_tokens path still applies _DEFAULT_MAX_TOKENS. + - The three payload variable names (payload x2, stream_payload x1) + each get both `max_tokens` and `t_max_predict_ms`. 
+""" + +from __future__ import annotations + +import ast +from pathlib import Path + + +SOURCE_PATH = ( + Path(__file__).resolve().parents[2] + / "studio" + / "backend" + / "core" + / "inference" + / "llama_cpp.py" +) +SRC = SOURCE_PATH.read_text() +TREE = ast.parse(SRC) + + +def _is_subscript_assign(stmt: ast.stmt, target_name: str, key: str) -> bool: + if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1: + return False + t = stmt.targets[0] + if not isinstance(t, ast.Subscript): + return False + if not (isinstance(t.value, ast.Name) and t.value.id == target_name): + return False + slc = t.slice + return isinstance(slc, ast.Constant) and slc.value == key + + +def _collect_assignments(tree, target_name, key): + """Return list of (node, stack_of_enclosing_ifs) for each match.""" + hits = [] + + def visit(node, stack): + if _is_subscript_assign(node, target_name, key): + hits.append((node, stack)) + for child in ast.iter_child_nodes(node): + if isinstance(child, ast.If): + for sub in child.body: + visit(sub, stack + [(child, "body")]) + for sub in child.orelse: + visit(sub, stack + [(child, "orelse")]) + else: + visit(child, stack) + + visit(tree, []) + return hits + + +def test_default_t_max_predict_ms_is_at_least_ten_minutes(): + for node in TREE.body: + if isinstance(node, ast.Assign) and len(node.targets) == 1: + t = node.targets[0] + if isinstance(t, ast.Name) and t.id == "_DEFAULT_T_MAX_PREDICT_MS": + value = node.value + assert isinstance(value, ast.Constant) + assert value.value >= 600_000, ( + f"_DEFAULT_T_MAX_PREDICT_MS must be >= 10 minutes " + f"(600_000 ms) to avoid cutting off slow-CPU generations; " + f"got {value.value}" + ) + return + raise AssertionError("_DEFAULT_T_MAX_PREDICT_MS constant missing") + + +def test_t_max_predict_ms_set_unconditionally_at_three_sites(): + hits_payload = _collect_assignments(TREE, "payload", "t_max_predict_ms") + hits_stream = _collect_assignments(TREE, "stream_payload", "t_max_predict_ms") + total = 
len(hits_payload) + len(hits_stream) + assert total == 3, ( + f"expected 3 total t_max_predict_ms assignments " + f"(payload x2 + stream_payload x1), got {total}" + ) + for node, stack in hits_payload + hits_stream: + for parent_if, branch in stack: + # The assignment must not be gated by a test that checks + # `max_tokens is None` (which would make it dead code for + # the UI path where max_tokens is always set). + test_src = ast.unparse(parent_if.test) + assert "max_tokens" not in test_src, ( + f"t_max_predict_ms at line {node.lineno} is nested under " + f"`if {test_src}:` -- it must be applied unconditionally so " + f"the wall-clock cap is not dead code for callers that set " + f"max_tokens" + ) + + +def test_max_tokens_default_cap_still_applied(): + # _DEFAULT_MAX_TOKENS must still kick in when caller passes None. + # We check the conditional expression `max_tokens if max_tokens is not + # None else _DEFAULT_MAX_TOKENS` appears at each site. + matches = 0 + for node in ast.walk(TREE): + if not isinstance(node, ast.IfExp): + continue + src = ast.unparse(node) + if "max_tokens" in src and "_DEFAULT_MAX_TOKENS" in src: + matches += 1 + assert matches >= 3, ( + f"expected >=3 `max_tokens if max_tokens is not None else " + f"_DEFAULT_MAX_TOKENS` expressions; got {matches}" + ) diff --git a/tests/studio/test_stream_cancel_registration_timing.py b/tests/studio/test_stream_cancel_registration_timing.py new file mode 100644 index 000000000..40ec3d6e1 --- /dev/null +++ b/tests/studio/test_stream_cancel_registration_timing.py @@ -0,0 +1,718 @@ +""" +Tests that the cancel tracker is registered BEFORE StreamingResponse is +returned, and that cleanup runs via a `finally` block inside each +async generator. + +The zombie-generation scenario is: user clicks Stop during prefill / +warmup / proxy buffering, before the first SSE chunk. 
If _tracker +__enter__ lives inside the async generator body, the registry is empty +at the moment /api/inference/cancel lands -- so cancel returns 0 and +the decode runs to completion. + +The fix moves _tracker = _TrackedCancel(...) and _tracker.__enter__() +to the synchronous body of openai_chat_completions (before the +StreamingResponse is returned) and places _tracker.__exit__ inside +each generator's `finally` block. Using a generator `finally` (rather +than a Starlette BackgroundTask) guarantees cleanup on every +termination path -- normal exhaustion, CancelledError from +ClientDisconnect, and OSError / BrokenPipeError during send() -- +because Starlette skips `background` callbacks when stream_response +raises. + +Structural verifies: + - No `async def ...:` body contains `_tracker.__enter__()` in + routes/inference.py (registration moved to sync body). + - Each of the four async generators (gguf_tool_stream, + gguf_stream_chunks, stream_chunks, audio_input_stream) contains + `_tracker.__exit__(None, None, None)` inside a try/finally block. + - No StreamingResponse in openai_chat_completions passes + `background=` (cleanup now lives in the generator finally). + +Behavioral verifies (extracting `_TrackedCancel` from source and +exercising the actual runtime semantics): + - `finally: _tracker.__exit__(...)` runs on normal completion, + mid-stream exception (OSError / BrokenPipeError from send()), + and aclose() from Starlette ClientDisconnect. + - A pre-set cancel_event (from `_TrackedCancel.__enter__` replaying + a pending cancel POST) lets the GGUF while-loop break cleanly + and emit final_chunk + [DONE] instead of propagating + `GeneratorExit` out of `_stream_with_retry` into the async + generator's `except Exception` (which would not catch it). 
+""" + +from __future__ import annotations + +import ast +import asyncio +import threading +import time +from pathlib import Path + + +SOURCE_PATH = ( + Path(__file__).resolve().parents[2] + / "studio" + / "backend" + / "routes" + / "inference.py" +) +SRC = SOURCE_PATH.read_text() +_TREE = ast.parse(SRC) + + +# ── Structural (AST) helpers ───────────────────────────────── + + +def _collect_async_functions(tree: ast.AST): + return [n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef)] + + +def _has_tracker_enter_call(node: ast.AST) -> bool: + for sub in ast.walk(node): + if not isinstance(sub, ast.Call): + continue + fn = sub.func + if ( + isinstance(fn, ast.Attribute) + and fn.attr == "__enter__" + and isinstance(fn.value, ast.Name) + and fn.value.id.startswith("_tracker") + ): + return True + return False + + +def _finalbody_has_tracker_exit(finalbody) -> bool: + for stmt in finalbody: + if not isinstance(stmt, ast.Expr): + continue + call = stmt.value + if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)): + continue + fn = call.func + if ( + fn.attr == "__exit__" + and isinstance(fn.value, ast.Name) + and fn.value.id.startswith("_tracker") + ): + return True + return False + + +# ── Structural tests ───────────────────────────────────────── + + +def test_no_tracker_enter_inside_async_generators(): + offenders = [] + for fn in _collect_async_functions(_TREE): + if fn.name in { + "gguf_tool_stream", + "gguf_stream_chunks", + "stream_chunks", + "audio_input_stream", + }: + if _has_tracker_enter_call(fn): + offenders.append(fn.name) + assert not offenders, ( + f"Cancel tracker registration must live OUTSIDE the async generator " + f"body so a stop POST can find the registry entry before the first " + f"SSE chunk. 
Offending generators: {offenders}" + ) + + +def test_tracker_enter_exists_in_sync_body_of_chat_completions(): + top = None + for n in ast.walk(_TREE): + if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": + top = n + break + assert top is not None, "openai_chat_completions handler missing" + count = 0 + for sub in ast.walk(top): + if not isinstance(sub, ast.Call): + continue + fn = sub.func + if ( + isinstance(fn, ast.Attribute) + and fn.attr == "__enter__" + and isinstance(fn.value, ast.Name) + and fn.value.id.startswith("_tracker") + ): + count += 1 + assert count >= 3, ( + f"expected >=3 _tracker.__enter__() calls in openai_chat_completions " + f"(one per streaming path), got {count}" + ) + + +def test_async_generators_cleanup_tracker_in_finally(): + required = { + "gguf_tool_stream", + "gguf_stream_chunks", + "stream_chunks", + "audio_input_stream", + } + found: set[str] = set() + for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]: + if fn.name not in required: + continue + for sub in ast.walk(fn): + if isinstance(sub, ast.Try) and sub.finalbody: + if _finalbody_has_tracker_exit(sub.finalbody): + found.add(fn.name) + break + missing = required - found + assert not missing, ( + f"Cleanup must run via `finally: _tracker.__exit__(None, None, None)` " + f"inside each streaming generator so ClientDisconnect / OSError paths " + f"also release registry entries (Starlette skips `background` callbacks " + f"when stream_response raises). 
Missing in: {sorted(missing)}" + ) + + +def test_streaming_responses_have_no_background_task(): + top = None + for n in ast.walk(_TREE): + if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": + top = n + break + assert top is not None + for sub in ast.walk(top): + if not (isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)): + continue + if sub.func.id != "StreamingResponse": + continue + kwargs = {kw.arg for kw in sub.keywords if kw.arg} + assert "background" not in kwargs, ( + "StreamingResponse in openai_chat_completions must not pass " + "`background=` -- cleanup now lives in the generator's finally " + "block; a BackgroundTask would be skipped on abrupt disconnect" + ) + + +# ── Behavioral helpers ─────────────────────────────────────── + +_WANTED = { + "_CANCEL_REGISTRY", + "_CANCEL_LOCK", + "_PENDING_CANCELS", + "_PENDING_CANCEL_TTL_S", + "_prune_pending", + "_TrackedCancel", + "_cancel_by_keys", + "_cancel_by_cancel_id_or_stash", +} + + +def _load_registry_module(): + chunks = [] + for n in _TREE.body: + seg = ast.get_source_segment(SRC, n) + if seg is None: + continue + if isinstance(n, (ast.FunctionDef, ast.ClassDef)) and n.name in _WANTED: + chunks.append(seg) + elif isinstance(n, ast.Assign): + names = [t.id for t in n.targets if isinstance(t, ast.Name)] + if any(name in _WANTED for name in names): + chunks.append(seg) + elif ( + isinstance(n, ast.AnnAssign) + and isinstance(n.target, ast.Name) + and n.target.id in _WANTED + ): + chunks.append(seg) + mod = {} + exec("import threading, time\n" + "\n\n".join(chunks), mod) + return mod + + +def _make_stream(tracker, raise_exc): + async def gen(): + try: + try: + yield "data: first\n\n" + if raise_exc is not None: + raise raise_exc + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + raise + except Exception: + yield "data: error\n\n" + finally: + tracker.__exit__(None, None, None) + except BaseException: + raise + + return gen() + + +async def _consume(agen): + 
out = [] + try: + async for ch in agen: + out.append(ch) + except BaseException as e: + out.append(type(e).__name__) + return out + + +def _llama_stub_raises_on_preset_cancel(cancel_event): + # Reproduces llama_cpp.py _stream_with_retry:2240 `raise GeneratorExit` + # when cancel_event is already set at entry. + if cancel_event.is_set(): + raise GeneratorExit + yield "cumulative-1" + yield "cumulative-2" + + +async def _post_fix_gguf_loop(cancel_event): + yield "first_chunk" + gen = _llama_stub_raises_on_preset_cancel(cancel_event) + sentinel = object() + while True: + if cancel_event.is_set(): + break + cumulative = await asyncio.to_thread(next, gen, sentinel) + if cumulative is sentinel: + break + yield cumulative + yield "final_chunk" + yield "[DONE]" + + +# ── Behavioral tests ───────────────────────────────────────── + + +def test_finally_cleanup_on_normal_completion(): + m = _load_registry_module() + m["_CANCEL_REGISTRY"].clear() + ev = threading.Event() + tr = m["_TrackedCancel"](ev, "cid-ok", "sid-ok") + tr.__enter__() + assert "cid-ok" in m["_CANCEL_REGISTRY"] + chunks = asyncio.run(_consume(_make_stream(tr, None))) + assert chunks == ["data: first\n\n", "data: [DONE]\n\n"] + assert "cid-ok" not in m["_CANCEL_REGISTRY"] + assert "sid-ok" not in m["_CANCEL_REGISTRY"] + + +def test_finally_cleanup_on_mid_stream_exception(): + # Simulates OSError / BrokenPipeError from Starlette send() mid-stream -- + # the exact case where pre-fix `background = BackgroundTask(...)` was + # skipped and leaked the registry entry. 
+    m = _load_registry_module()
+    m["_CANCEL_REGISTRY"].clear()
+    ev = threading.Event()
+    tr = m["_TrackedCancel"](ev, "cid-err", "sid-err")
+    tr.__enter__()
+    assert "cid-err" in m["_CANCEL_REGISTRY"]
+    asyncio.run(_consume(_make_stream(tr, OSError("disconnect"))))
+    assert "cid-err" not in m["_CANCEL_REGISTRY"]
+    assert "sid-err" not in m["_CANCEL_REGISTRY"]
+
+
+def test_finally_cleanup_on_aclose():
+    # Starlette calls aclose() on the async generator when the client
+    # disconnects mid-stream. The generator's finally block must run.
+    m = _load_registry_module()
+    m["_CANCEL_REGISTRY"].clear()
+    ev = threading.Event()
+    tr = m["_TrackedCancel"](ev, "cid-abort", "sid-abort")
+    tr.__enter__()
+    assert "cid-abort" in m["_CANCEL_REGISTRY"]
+
+    async def run():
+        gen = _make_stream(tr, None)
+        it = gen.__aiter__()
+        await it.__anext__()
+        await gen.aclose()
+
+    asyncio.run(run())
+    assert "cid-abort" not in m["_CANCEL_REGISTRY"]
+    assert "sid-abort" not in m["_CANCEL_REGISTRY"]
+
+
+def test_preset_cancel_event_exits_cleanly_with_done():
+    # Pending-replay: POST /cancel arrived before the stream registered,
+    # was stashed, then consumed by _TrackedCancel.__enter__ which set
+    # cancel_event. The generator must break out of the loop cleanly
+    # and emit final_chunk + [DONE] rather than calling next(gen) and
+    # propagating `GeneratorExit` out of the GGUF stream wrapper.
+    ev = threading.Event()
+    ev.set()
+    chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
+    assert "first_chunk" in chunks
+    assert "final_chunk" in chunks
+    assert "[DONE]" in chunks
+    assert "GeneratorExit" not in chunks
+    assert "cumulative-1" not in chunks
+    assert "cumulative-2" not in chunks
+
+
+def test_normal_path_streams_all_tokens():
+    # Regression: the top-of-loop cancel_event check must not short-circuit
+    # when cancel_event is unset.
+    ev = threading.Event()
+    chunks = asyncio.run(_consume(_post_fix_gguf_loop(ev)))
+    assert chunks == [
+        "first_chunk",
+        "cumulative-1",
+        "cumulative-2",
+        "final_chunk",
+        "[DONE]",
+    ]
+
+
+def test_cancel_during_streaming_stops_iteration_promptly():
+    # Setting cancel_event between yields breaks out on the next iteration
+    # rather than draining the stub generator.
+    ev = threading.Event()
+
+    async def _run():
+        gen = _post_fix_gguf_loop(ev)
+        seen = []
+        async for ch in gen:
+            seen.append(ch)
+            if ch == "cumulative-1":
+                ev.set()
+        return seen
+
+    seen = asyncio.run(_run())
+    assert "first_chunk" in seen
+    assert "cumulative-1" in seen
+    assert "cumulative-2" not in seen
+    assert "final_chunk" in seen
+    assert "[DONE]" in seen
+
+
+# ── Cancel-event responsiveness in the streaming loops ───────
+
+
+def _loop_has_cancel_event_check(fn) -> bool:
+    # An `if cancel_event.is_set():` statement anywhere inside a
+    # `while`/`for` loop body is sufficient -- without it, a cancel POST
+    # cannot interrupt the loop because Colab-style proxies do not
+    # propagate request.is_disconnected().
+    for sub in ast.walk(fn):
+        if not isinstance(sub, (ast.While, ast.For, ast.AsyncFor)):
+            continue
+        for stmt in ast.walk(sub):
+            if not isinstance(stmt, ast.If):
+                continue
+            t = stmt.test
+            if (
+                isinstance(t, ast.Call)
+                and isinstance(t.func, ast.Attribute)
+                and t.func.attr == "is_set"
+                and isinstance(t.func.value, ast.Name)
+                and t.func.value.id == "cancel_event"
+            ):
+                return True
+    return False
+
+
+def test_streaming_generators_check_cancel_event_in_loop():
+    # Static AST check: every streaming generator named below must poll
+    # cancel_event inside its main loop (see _loop_has_cancel_event_check).
+    required = {
+        "gguf_tool_stream",
+        "gguf_stream_chunks",
+        "stream_chunks",
+        "audio_input_stream",
+    }
+    missing = []
+    for fn in [n for n in ast.walk(_TREE) if isinstance(n, ast.AsyncFunctionDef)]:
+        if fn.name not in required:
+            continue
+        if not _loop_has_cancel_event_check(fn):
+            missing.append(fn.name)
+    assert not missing, (
+        f"Each streaming generator must check `cancel_event.is_set()` inside "
+        f"its main loop so `POST /api/inference/cancel` can interrupt the "
+        f"stream through proxies that do not forward fetch aborts. "
+        f"Missing in: {sorted(missing)}"
+    )
+
+
+def test_audio_input_stream_offloads_blocking_next_to_thread():
+    # Guards against regression back to `for chunk_text in
+    # audio_input_generate():` -- which blocks the event loop on each
+    # whisper chunk and prevents POST /api/inference/cancel from being
+    # serviced until the chunk yields.
+    audio = None
+    for fn in ast.walk(_TREE):
+        if isinstance(fn, ast.AsyncFunctionDef) and fn.name == "audio_input_stream":
+            audio = fn
+            break
+    assert audio is not None, "audio_input_stream generator missing"
+
+    for sub in ast.walk(audio):
+        if isinstance(sub, (ast.For, ast.AsyncFor)):
+            it_src = ast.unparse(sub.iter)
+            assert "audio_input_generate" not in it_src, (
+                "audio_input_stream must not iterate audio_input_generate() "
+                "directly -- that blocks the event loop. Use "
+                "`await asyncio.to_thread(next, gen, _DONE)` inside a "
+                "`while True` loop instead"
+            )
+
+    found_to_thread_next = False
+    for sub in ast.walk(audio):
+        if not isinstance(sub, ast.Call):
+            continue
+        fn_expr = sub.func
+        if not (
+            isinstance(fn_expr, ast.Attribute)
+            and fn_expr.attr == "to_thread"
+            and isinstance(fn_expr.value, ast.Name)
+            and fn_expr.value.id == "asyncio"
+        ):
+            continue
+        if sub.args and isinstance(sub.args[0], ast.Name) and sub.args[0].id == "next":
+            found_to_thread_next = True
+            break
+    assert found_to_thread_next, (
+        "audio_input_stream must call `asyncio.to_thread(next, gen, ...)` "
+        "to keep the event loop free while whisper yields the next chunk"
+    )
+
+
+def test_stream_chunks_cancel_branch_resets_backend_state():
+    # The Unsloth path's cancel branch must flush GPU / KV-cache state
+    # via `backend.reset_generation_state()` -- the orchestrator's
+    # internal cancel path does not do this, so a cancel-via-POST that
+    # only broke the loop would leave the subprocess in a dirty state
+    # for the next request.
+    fn = None
+    top = None
+    for n in ast.walk(_TREE):
+        if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions":
+            top = n
+            break
+    assert top is not None
+    for n in ast.walk(top):
+        if isinstance(n, ast.AsyncFunctionDef) and n.name == "stream_chunks":
+            fn = n
+            break
+    assert fn is not None, "stream_chunks generator missing"
+
+    for sub in ast.walk(fn):
+        if not isinstance(sub, ast.If):
+            continue
+        t = sub.test
+        if not (
+            isinstance(t, ast.Call)
+            and isinstance(t.func, ast.Attribute)
+            and t.func.attr == "is_set"
+            and isinstance(t.func.value, ast.Name)
+            and t.func.value.id == "cancel_event"
+        ):
+            continue
+        body_src = "\n".join(ast.unparse(s) for s in sub.body)
+        if "backend.reset_generation_state()" in body_src:
+            return
+    raise AssertionError(
+        "stream_chunks `if cancel_event.is_set():` branch must call "
+        "backend.reset_generation_state() -- matches the existing "
+        "request.is_disconnected() / CancelledError cleanup paths and "
+        "prevents KV-cache drift after cancel-via-POST"
+    )
+
+
+# ── Behavioral simulations for the iter-1 fixes ──────────────
+
+
+def test_unsloth_stream_loop_breaks_on_external_cancel_event():
+    cancel_event = threading.Event()
+    reset_calls = [0]
+
+    class _Backend:
+        def reset_generation_state(self):
+            reset_calls[0] += 1
+
+    backend = _Backend()
+
+    def _generate():
+        for i in range(200):
+            time.sleep(0.005)
+            yield f"cum-{i}"
+
+    async def _loop():
+        _DONE = object()
+        loop = asyncio.get_event_loop()
+        gen = _generate()
+        seen = []
+        while True:
+            if cancel_event.is_set():
+                backend.reset_generation_state()
+                break
+            cumulative = await loop.run_in_executor(None, next, gen, _DONE)
+            if cumulative is _DONE:
+                break
+            seen.append(cumulative)
+        return seen
+
+    async def _fire():
+        await asyncio.sleep(0.05)
+        cancel_event.set()
+
+    async def _main():
+        return await asyncio.gather(_loop(), _fire())
+
+    seen, _ = asyncio.run(_main())
+    assert (
+        len(seen) < 200
+    ), f"loop must not drain the generator after cancel; got {len(seen)} tokens"
+    assert reset_calls[0] == 1, (
+        f"backend.reset_generation_state() must be called exactly once on "
+        f"cancel-via-POST, got {reset_calls[0]}"
+    )
+
+
+def test_audio_stream_stays_responsive_under_blocking_next():
+    # Regression guard: replace the post-fix loop with the pre-fix
+    # `for chunk in audio_input_generate()` pattern and assert it blocks
+    # the event loop; then confirm the post-fix pattern exits promptly.
+    cancel_event = threading.Event()
+
+    def _audio_gen():
+        for i in range(8):
+            time.sleep(0.15)
+            yield f"chunk-{i}"
+
+    async def _prefix_loop():
+        seen = []
+        for chunk_text in _audio_gen():
+            if cancel_event.is_set():
+                break
+            seen.append(chunk_text)
+        return seen
+
+    async def _postfix_loop():
+        _DONE = object()
+        gen = _audio_gen()
+        seen = []
+        while True:
+            if cancel_event.is_set():
+                break
+            chunk_text = await asyncio.to_thread(next, gen, _DONE)
+            if chunk_text is _DONE:
+                break
+            seen.append(chunk_text)
+        return seen
+
+    async def _fire_early():
+        await asyncio.sleep(0.05)
+        cancel_event.set()
+
+    async def _run(loop_coro):
+        return await asyncio.gather(loop_coro, _fire_early())
+
+    cancel_event.clear()
+    t0 = time.monotonic()
+    prefix_seen, _ = asyncio.run(_run(_prefix_loop()))
+    prefix_elapsed = time.monotonic() - t0
+    assert prefix_elapsed >= 0.13, (
+        f"pre-fix pattern should block event loop for >=1 chunk time "
+        f"(~150ms); got {prefix_elapsed:.3f}s, {len(prefix_seen)} chunks"
+    )
+
+    cancel_event.clear()
+    t0 = time.monotonic()
+    postfix_seen, _ = asyncio.run(_run(_postfix_loop()))
+    postfix_elapsed = time.monotonic() - t0
+    assert postfix_elapsed < prefix_elapsed, (
+        f"post-fix pattern must exit faster than pre-fix (blocking) "
+        f"pattern; post={postfix_elapsed:.3f}s vs pre={prefix_elapsed:.3f}s"
+    )
+    assert (
+        len(postfix_seen) < 8
+    ), f"post-fix loop must not drain all chunks; got {len(postfix_seen)}"
+
+
+def test_unsloth_stream_loop_emits_zero_tokens_on_preset_cancel():
+    # Pending-cancel replay: _TrackedCancel.__enter__ already set
+    # cancel_event before the generator body starts iterating. The
+    # top-of-loop check must short-circuit the very first iteration so
+    # no token is emitted. Catches a regression that moves the check
+    # below `next()` -- the mid-loop test would still pass but this
+    # test would observe one extra token leak.
+    cancel_event = threading.Event()
+    cancel_event.set()
+    reset_calls = [0]
+
+    class _Backend:
+        def reset_generation_state(self):
+            reset_calls[0] += 1
+
+    backend = _Backend()
+
+    next_calls = [0]
+
+    def _generate():
+        while True:
+            next_calls[0] += 1
+            yield f"cum-{next_calls[0]}"
+
+    async def _loop():
+        _DONE = object()
+        loop = asyncio.get_event_loop()
+        gen = _generate()
+        seen = []
+        while True:
+            if cancel_event.is_set():
+                backend.reset_generation_state()
+                break
+            cumulative = await loop.run_in_executor(None, next, gen, _DONE)
+            if cumulative is _DONE:
+                break
+            seen.append(cumulative)
+        return seen
+
+    seen = asyncio.run(_loop())
+    assert seen == [], (
+        f"loop must emit zero tokens when cancel_event is pre-set "
+        f"(pending-replay path); got {seen}"
+    )
+    assert next_calls[0] == 0, (
+        f"loop must not call next() at all on pre-set cancel; got "
+        f"{next_calls[0]} calls"
+    )
+    assert reset_calls[0] == 1, (
+        f"backend.reset_generation_state() must still fire exactly once "
+        f"on pre-set cancel; got {reset_calls[0]}"
+    )
+
+
+def test_audio_stream_emits_zero_chunks_on_preset_cancel():
+    # Symmetric to the Unsloth pre-set test: the audio loop's top-of-loop
+    # cancel check must skip the asyncio.to_thread(next, ...) call when
+    # cancel_event was already set via pending-replay.
+    cancel_event = threading.Event()
+    cancel_event.set()
+
+    next_calls = [0]
+
+    def _audio_gen():
+        while True:
+            next_calls[0] += 1
+            yield f"chunk-{next_calls[0]}"
+
+    async def _loop():
+        _DONE = object()
+        gen = _audio_gen()
+        seen = []
+        while True:
+            if cancel_event.is_set():
+                break
+            chunk_text = await asyncio.to_thread(next, gen, _DONE)
+            if chunk_text is _DONE:
+                break
+            seen.append(chunk_text)
+        return seen
+
+    seen = asyncio.run(_loop())
+    assert seen == [], f"audio loop must emit zero chunks on pre-set cancel; got {seen}"
+    assert next_calls[0] == 0, (
+        f"audio loop must not call next() on pre-set cancel; got "
+        f"{next_calls[0]} calls"
+    )