fix(SKY-8986): stop SSE disconnect from killing the copilot agent (#5560)

This commit is contained in:
Andrew Neilson 2026-04-20 11:54:32 -07:00 committed by GitHub
parent e63689c981
commit 2fa21f8799
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 334 additions and 140 deletions

View file

@ -281,7 +281,6 @@ async def run_copilot_agent(
from skyvern.cli.mcp_tools import mcp as skyvern_mcp from skyvern.cli.mcp_tools import mcp as skyvern_mcp
from skyvern.forge.sdk.copilot.enforcement import ( from skyvern.forge.sdk.copilot.enforcement import (
CopilotClientDisconnectedError,
CopilotTotalTimeoutError, CopilotTotalTimeoutError,
run_with_enforcement, run_with_enforcement,
) )
@ -392,8 +391,8 @@ async def run_copilot_agent(
chat_request, chat_request,
organization_id, organization_id,
) )
except (CopilotClientDisconnectedError, asyncio.CancelledError): except asyncio.CancelledError:
LOG.info("Copilot client disconnected") LOG.info("Copilot run cancelled")
return AgentResult( return AgentResult(
user_response="Request cancelled.", user_response="Request cancelled.",
updated_workflow=ctx.last_workflow, updated_workflow=ctx.last_workflow,

View file

@ -246,10 +246,6 @@ def _is_progress_narration(user_response: Any) -> bool:
return any(pattern.search(user_response) for pattern in _PROGRESS_NARRATION_PATTERNS) return any(pattern.search(user_response) for pattern in _PROGRESS_NARRATION_PATTERNS)
class CopilotClientDisconnectedError(Exception):
"""Raised when the client disconnects during agent execution."""
class CopilotTotalTimeoutError(Exception): class CopilotTotalTimeoutError(Exception):
"""Raised when the copilot agent exceeds the total allowed runtime.""" """Raised when the copilot agent exceeds the total allowed runtime."""
@ -853,9 +849,10 @@ async def run_with_enforcement(
pending_recovery_nudge: str | None = None pending_recovery_nudge: str | None = None
while True: while True:
if await stream.is_disconnected(): # Client disconnect is no longer treated as a stop signal. The
raise CopilotClientDisconnectedError() # SSE stream silently drops events once the browser is gone, but
# the agent keeps running so the reply can be persisted to the
# chat history on the server side (see SKY-8986).
elapsed = time.monotonic() - start_time elapsed = time.monotonic() - start_time
if elapsed > TOTAL_TIMEOUT_SECONDS: if elapsed > TOTAL_TIMEOUT_SECONDS:
raise CopilotTotalTimeoutError() raise CopilotTotalTimeoutError()
@ -924,9 +921,6 @@ async def run_with_enforcement(
) )
raise raise
if await stream.is_disconnected():
raise CopilotClientDisconnectedError()
# Inject pending screenshots as a follow-up user message because OpenAI # Inject pending screenshots as a follow-up user message because OpenAI
# rejects images in tool messages. # rejects images in tool messages.
screenshot_msg = _consume_pending_screenshots(ctx) screenshot_msg = _consume_pending_screenshots(ctx)

View file

@ -1,7 +0,0 @@
from __future__ import annotations
from skyvern.exceptions import SkyvernException
class CopilotClientDisconnectedError(SkyvernException):
    """Signals that the SSE client went away while a streaming copilot run was in progress."""

View file

@ -11,7 +11,6 @@ import structlog
# Reuse the HTTP-logging redactor so SSE tool inputs and request-body logs # Reuse the HTTP-logging redactor so SSE tool inputs and request-body logs
# share one exact-match sensitive-key policy. # share one exact-match sensitive-key policy.
from skyvern.forge.request_logging import redact_sensitive_fields from skyvern.forge.request_logging import redact_sensitive_fields
from skyvern.forge.sdk.copilot.exceptions import CopilotClientDisconnectedError
from skyvern.forge.sdk.copilot.output_utils import summarize_tool_result from skyvern.forge.sdk.copilot.output_utils import summarize_tool_result
from skyvern.forge.sdk.schemas.workflow_copilot import ( from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotStreamMessageType, WorkflowCopilotStreamMessageType,
@ -51,13 +50,15 @@ async def stream_to_sse(
``post_update_nudge_count``, ``navigate_called``, and ``post_update_nudge_count``, ``navigate_called``, and
``observation_after_navigate``. ``observation_after_navigate``.
A true client disconnect -- detected by ``stream.is_disconnected()`` or a A client disconnect does NOT cancel the agent run: we continue to iterate
failed ``stream.send()`` -- cancels the agent run and raises ``result.stream_events()`` so the agent completes whatever work it is
``CopilotClientDisconnectedError``. A plain ``asyncio.CancelledError`` in the middle of and the caller can persist the reply to the DB. Events
from some other source (task-group cancel, upstream timeout, parent sent through ``stream.send`` after disconnect are silently dropped by the
abort) is allowed to propagate unchanged so asyncio's cancellation stream, so the queue cannot grow unbounded.
machinery runs normally; callers that want the two to share a UX path
should catch both at the call site. Real asyncio cancellation (server shutdown, parent task cancelled for
reasons unrelated to a dropped client) is re-raised unchanged so
asyncio's cancellation machinery still runs normally.
""" """
from agents.stream_events import RunItemStreamEvent from agents.stream_events import RunItemStreamEvent
@ -69,18 +70,21 @@ async def stream_to_sse(
try: try:
async for event in result.stream_events(): async for event in result.stream_events():
if await stream.is_disconnected():
result.cancel()
raise CopilotClientDisconnectedError()
if not isinstance(event, RunItemStreamEvent): if not isinstance(event, RunItemStreamEvent):
continue continue
# Skip emission work (serialization, redaction) once the client
# is gone, but keep draining the SDK stream so the agent can
# finish. stream.send below would drop the payload anyway.
client_gone = await stream.is_disconnected()
if event.name == "tool_called": if event.name == "tool_called":
raw = event.item.raw_item raw = event.item.raw_item
call_id = _get_raw_field(raw, "call_id") or _get_raw_field(raw, "id") or "" call_id = _get_raw_field(raw, "call_id") or _get_raw_field(raw, "id") or ""
tool_name = _get_raw_field(raw, "name") or "unknown" tool_name = _get_raw_field(raw, "name") or "unknown"
call_id_to_name[call_id] = tool_name call_id_to_name[call_id] = tool_name
if not client_gone:
raw_args = _get_raw_field(raw, "arguments") raw_args = _get_raw_field(raw, "arguments")
tool_input: dict[str, Any] = {} tool_input: dict[str, Any] = {}
if isinstance(raw_args, str): if isinstance(raw_args, str):
@ -91,7 +95,7 @@ async def stream_to_sse(
elif isinstance(raw_args, dict): elif isinstance(raw_args, dict):
tool_input = raw_args tool_input = raw_args
if not await stream.send( await stream.send(
WorkflowCopilotToolCallUpdate( WorkflowCopilotToolCallUpdate(
type=WorkflowCopilotStreamMessageType.TOOL_CALL, type=WorkflowCopilotStreamMessageType.TOOL_CALL,
tool_name=tool_name, tool_name=tool_name,
@ -99,9 +103,7 @@ async def stream_to_sse(
iteration=iteration, iteration=iteration,
tool_call_id=call_id, tool_call_id=call_id,
) )
): )
result.cancel()
raise CopilotClientDisconnectedError()
elif event.name == "tool_output": elif event.name == "tool_output":
raw = event.item.raw_item raw = event.item.raw_item
@ -110,10 +112,11 @@ async def stream_to_sse(
output = getattr(event.item, "output", None) output = getattr(event.item, "output", None)
parsed = parse_tool_output(output) parsed = parse_tool_output(output)
if not client_gone:
summary = summarize_tool_result(tool_name, parsed) summary = summarize_tool_result(tool_name, parsed)
success = parsed.get("ok", True) success = parsed.get("ok", True)
await stream.send(
if not await stream.send(
WorkflowCopilotToolResultUpdate( WorkflowCopilotToolResultUpdate(
type=WorkflowCopilotStreamMessageType.TOOL_RESULT, type=WorkflowCopilotStreamMessageType.TOOL_RESULT,
tool_name=tool_name, tool_name=tool_name,
@ -122,19 +125,14 @@ async def stream_to_sse(
iteration=iteration, iteration=iteration,
tool_call_id=call_id, tool_call_id=call_id,
) )
): )
result.cancel()
raise CopilotClientDisconnectedError()
_update_enforcement_from_tool(ctx, tool_name, parsed) _update_enforcement_from_tool(ctx, tool_name, parsed)
iteration += 1 iteration += 1
except asyncio.CancelledError: except asyncio.CancelledError:
# Don't relabel generic cancellation as a client disconnect -- the # Real cancellation (server shutdown, upstream abort). Propagate so
# inline is_disconnected / send-failure branches above already raise # asyncio's task machinery sees the cancel; also cancel the SDK
# CopilotClientDisconnectedError when there is real evidence of a # run to free provider resources.
# dropped client. Preserve cancellation semantics by re-raising so
# asyncio's task machinery sees the cancel and any upstream
# except Exception does NOT swallow it.
result.cancel() result.cancel()
raise raise

View file

@ -826,9 +826,12 @@ async def _run_blocks_and_collect_debug(
final_status = run.status final_status = run.status
break break
if await ctx.stream.is_disconnected(): # Deliberately do NOT short-circuit on client disconnect here.
await _cancel_run_task_if_not_final(run_task, workflow_run.workflow_run_id) # The agent loop is allowed to run to completion after the SSE
return {"ok": False, "error": "Client disconnected during block execution."} # stream is gone (see SKY-8986) so its reply can be persisted;
# aborting an in-flight block execution would leave the
# workflow run in a limbo state and the agent would have no
# debug output to summarize in the final chat message.
run = await app.DATABASE.workflow_runs.get_workflow_run( run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,

View file

@ -1,12 +1,22 @@
import asyncio import asyncio
from typing import Any, Awaitable, Callable, Protocol from typing import Any, Awaitable, Callable, Protocol
import structlog
from fastapi import Request from fastapi import Request
from pydantic import BaseModel from pydantic import BaseModel
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
LOG = structlog.get_logger()
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10 DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10
# Strong references to handler tasks that outlive their SSE response. When a
# client disconnects mid-stream, we let the handler keep running so any
# in-flight work (agent loops, DB persistence) finishes cleanly. Without a
# strong reference, the event loop could garbage-collect the task while it
# is still running.
_BACKGROUND_HANDLER_TASKS: set["asyncio.Task[None]"] = set()
class EventSourceStream(Protocol): class EventSourceStream(Protocol):
"""Protocol for Server-Sent Events (SSE) streams.""" """Protocol for Server-Sent Events (SSE) streams."""
@ -16,7 +26,11 @@ class EventSourceStream(Protocol):
Send data as an SSE event. Send data as an SSE event.
Returns: Returns:
True if the event was queued successfully, False if disconnected or closed. True if the event was accepted (queued for delivery or dropped
because the client is gone). False only if the stream has been
explicitly closed. Callers should treat a False return as a
terminal state and stop emitting; they should NOT treat a
client disconnect as a reason to abort in-flight work.
""" """
... ...
@ -33,8 +47,12 @@ class FastAPIEventSourceStream:
""" """
FastAPI implementation of EventSourceStream. FastAPI implementation of EventSourceStream.
This class provides a cleaner interface for sending SSE updates from async functions Sending is decoupled from client presence. When the client disconnects,
instead of using yield-based generators directly. the handler task keeps running (so agent loops and DB persistence
complete) and subsequent send() calls silently drop their payload
instead of growing the queue. The generator stops yielding and
sse-starlette closes the HTTP response; the handler runs out in the
background.
Usage: Usage:
@app.post("/stream") @app.post("/stream")
@ -51,22 +69,52 @@ class FastAPIEventSourceStream:
self._request = request self._request = request
self._queue: asyncio.Queue[Any] = asyncio.Queue() self._queue: asyncio.Queue[Any] = asyncio.Queue()
self._closed = False self._closed = False
# Latches once the client goes away so repeated is_disconnected()
# calls don't hit the ASGI receive channel after the response has
# been torn down.
self._client_gone = False
async def send(self, data: Any) -> bool: async def send(self, data: Any) -> bool:
""" """Send data as an SSE event. Accepts Pydantic models or dicts.
Send data as an SSE event. Accepts Pydantic models or dicts.
Returns: When the client is still connected, the event is queued for
True if the event was queued successfully, False if disconnected or closed. delivery. When the client has disconnected, the event is dropped
silently so the handler can continue to completion without
growing memory. Returns False only if the stream has been
explicitly closed.
""" """
if self._closed or await self.is_disconnected(): if self._closed:
return False return False
if await self.is_disconnected():
return True
await self._queue.put(data) await self._queue.put(data)
return True return True
async def is_disconnected(self) -> bool: async def is_disconnected(self) -> bool:
"""Check if the client has disconnected.""" """Check if the client has disconnected.
return await self._request.is_disconnected()
Caches the first positive result so callers made after the ASGI
response has been torn down don't try to pull from a receive
channel that may no longer be live.
"""
if self._client_gone:
return True
try:
disconnected = await self._request.is_disconnected()
except Exception as exc:
# Starlette's is_disconnected can raise various errors once
# the ASGI receive channel is gone (RuntimeError on closed
# queues, anyio/asyncio cancellation oddities, etc.). Treat
# any such failure as "client gone" -- we're polling to
# decide whether to skip emitting events, and if we can't
# tell we'd rather stop emitting than spin. Log so a
# genuinely new failure mode shows up in telemetry instead
# of hiding behind the disconnect path.
LOG.debug("is_disconnected raised; treating as disconnect", error=str(exc))
disconnected = True
if disconnected:
self._client_gone = True
return disconnected
async def close(self) -> None: async def close(self) -> None:
"""Signal that the stream is complete.""" """Signal that the stream is complete."""
@ -111,19 +159,27 @@ class FastAPIEventSourceStream:
An EventSourceResponse that can be returned from a FastAPI endpoint An EventSourceResponse that can be returned from a FastAPI endpoint
""" """
stream = cls(request) stream = cls(request)
task = asyncio.create_task(cls._run_handler(stream, handler))
# Hold a strong reference so the event loop can't GC the task if the
# generator is torn down first (SSE client disconnect). The set is
# bounded in practice by the handler's own timeout — every copilot
# handler must eventually return (see TOTAL_TIMEOUT_SECONDS in
# skyvern/forge/sdk/copilot/enforcement.py) which removes the task
# from the set via the done_callback below. If you wire a new
# EventSource endpoint through here, give its handler a similar
# hard cap or this set becomes an unbounded leak.
_BACKGROUND_HANDLER_TASKS.add(task)
task.add_done_callback(_BACKGROUND_HANDLER_TASKS.discard)
async def event_generator() -> Any: async def event_generator() -> Any:
task = asyncio.create_task(cls._run_handler(stream, handler))
try:
async for event in stream._generate(): async for event in stream._generate():
yield event yield event
finally: # Intentionally do NOT cancel the handler task here. SSE
if not task.done(): # disconnect must not kill in-flight work (the copilot agent
task.cancel() # often needs tens of seconds to finish and then persist its
try: # reply to the chat history). The handler keeps running in
await task # the background; stream.send() silently drops events once
except asyncio.CancelledError: # the client is gone so the queue cannot grow unbounded.
pass
def ping_message_factory() -> ServerSentEvent: def ping_message_factory() -> ServerSentEvent:
return ServerSentEvent(comment="keep-alive") return ServerSentEvent(comment="keep-alive")
@ -142,5 +198,11 @@ class FastAPIEventSourceStream:
"""Run the handler and ensure the stream is closed when done.""" """Run the handler and ensure the stream is closed when done."""
try: try:
await handler(stream) await handler(stream)
except asyncio.CancelledError:
# Process/server shutdown — propagate. Client disconnect alone
# no longer cancels this task (see create()).
raise
except Exception:
LOG.exception("SSE handler failed")
finally: finally:
await stream.close() await stream.close()

View file

@ -795,12 +795,9 @@ async def _new_copilot_chat_post(
) )
) )
if await stream.is_disconnected(): # No early exit on disconnect (SKY-8986): the agent runs to
LOG.info( # completion even after the SSE stream drops so its reply is
"Workflow copilot v2 chat request is disconnected before agent loop", # persisted to the chat history and visible after reconnect.
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
return
original_workflow = await app.DATABASE.workflows.get_workflow_by_permanent_id( original_workflow = await app.DATABASE.workflows.get_workflow_by_permanent_id(
workflow_permanent_id=chat_request.workflow_permanent_id, workflow_permanent_id=chat_request.workflow_permanent_id,
@ -838,15 +835,11 @@ async def _new_copilot_chat_post(
updated_workflow = agent_result.updated_workflow updated_workflow = agent_result.updated_workflow
updated_global_llm_context = agent_result.global_llm_context updated_global_llm_context = agent_result.global_llm_context
if await stream.is_disconnected(): # Persist rollback / proposed-workflow state and the chat
LOG.info( # messages regardless of whether the SSE client is still
"Workflow copilot v2 chat request is disconnected after agent loop", # connected: the user needs to see the reply on reconnect.
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id, # SKY-8986: client disconnect used to short-circuit this block
) # and leave the chat history without the AI response.
if _should_restore_persisted_workflow(chat.auto_accept, agent_result):
await _restore_workflow_definition(original_workflow, organization.organization_id)
return
if chat.auto_accept is not True: if chat.auto_accept is not True:
if _should_restore_persisted_workflow(chat.auto_accept, agent_result): if _should_restore_persisted_workflow(chat.auto_accept, agent_result):
await _restore_workflow_definition(original_workflow, organization.organization_id) await _restore_workflow_definition(original_workflow, organization.organization_id)
@ -1017,13 +1010,9 @@ async def workflow_copilot_chat_post(
) )
) )
if await stream.is_disconnected(): # SKY-8986: do not short-circuit on client disconnect. The LLM
LOG.info( # call and the DB persistence below must complete so the reply
"Workflow copilot chat request is disconnected before LLM call", # is in the chat history when the user reconnects.
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
return
user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm( user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
stream, stream,
organization.organization_id, organization.organization_id,
@ -1033,13 +1022,6 @@ async def workflow_copilot_chat_post(
debug_run_info_text, debug_run_info_text,
) )
if await stream.is_disconnected():
LOG.info(
"Workflow copilot chat request is disconnected after LLM call",
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
return
if updated_workflow and chat.auto_accept is not True: if updated_workflow and chat.auto_accept is not True:
await app.DATABASE.workflow_params.update_workflow_copilot_chat( await app.DATABASE.workflow_params.update_workflow_copilot_chat(
organization_id=chat.organization_id, organization_id=chat.organization_id,

View file

@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from skyvern.forge.sdk.copilot.exceptions import CopilotClientDisconnectedError
from skyvern.forge.sdk.copilot.streaming_adapter import _sanitize_input, stream_to_sse from skyvern.forge.sdk.copilot.streaming_adapter import _sanitize_input, stream_to_sse
@ -85,38 +84,50 @@ async def _stream_events_from(*events: Any) -> Any:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_to_sse_raises_when_send_reports_disconnect() -> None: async def test_stream_to_sse_keeps_running_after_client_disconnect() -> None:
"""SKY-8986 regression: a dropped SSE client must NOT cancel the agent run.
The handler task outlives the SSE response so the agent's reply can be
persisted to the chat history. stream_to_sse keeps draining the SDK's
event stream; emissions turn into no-ops when is_disconnected() returns
True, but result.cancel() is never called and no exception escapes.
"""
from agents.items import RunItem from agents.items import RunItem
from agents.stream_events import RunItemStreamEvent from agents.stream_events import RunItemStreamEvent
raw = {"call_id": "c1", "name": "click", "arguments": "{}"} raw_call = {"call_id": "c1", "name": "click", "arguments": "{}"}
item = MagicMock(spec=RunItem) call_item = MagicMock(spec=RunItem)
item.raw_item = raw call_item.raw_item = raw_call
event = RunItemStreamEvent(name="tool_called", item=item) tool_call = RunItemStreamEvent(name="tool_called", item=call_item)
raw_output = {"call_id": "c1"}
output_item = MagicMock(spec=RunItem)
output_item.raw_item = raw_output
output_item.output = None
tool_output = RunItemStreamEvent(name="tool_output", item=output_item)
result = MagicMock() result = MagicMock()
result.stream_events = lambda: _stream_events_from(event) result.stream_events = lambda: _stream_events_from(tool_call, tool_output)
result.cancel = MagicMock() result.cancel = MagicMock()
stream = MagicMock() stream = MagicMock()
stream.is_disconnected = AsyncMock(return_value=False) stream.is_disconnected = AsyncMock(return_value=True)
stream.send = AsyncMock(return_value=False) stream.send = AsyncMock(return_value=True)
ctx = SimpleNamespace() ctx = SimpleNamespace()
with pytest.raises(CopilotClientDisconnectedError):
await stream_to_sse(result, stream, ctx) await stream_to_sse(result, stream, ctx)
result.cancel.assert_called_once() result.cancel.assert_not_called()
stream.send.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_to_sse_propagates_cancelled_error() -> None: async def test_stream_to_sse_propagates_cancelled_error() -> None:
"""A generic asyncio.CancelledError must not be relabeled as a client """A generic asyncio.CancelledError must propagate up from stream_to_sse so
disconnect. Relabeling would silence cancellation (CopilotClientDisconnectedError the event loop's cancellation machinery still works for task-group cancel,
is a plain Exception, CancelledError is BaseException) and break the event upstream timeout, or parent abort. The adapter must not catch it and turn
loop's cancellation machinery for non-disconnect cancels such as task-group it into a normal return.
cancel, upstream timeout, or parent abort.
""" """
async def _raises_cancelled() -> Any: async def _raises_cancelled() -> Any:

View file

@ -0,0 +1,151 @@
"""Regression tests for SKY-8986: SSE disconnect must not kill the handler.
The SSE stream is a view of work the backend is doing on the client's behalf
(e.g., the workflow copilot agent). Closing the browser tab or losing the TCP
connection mid-stream used to cancel the handler task, which in turn
cancelled the agent run and lost the unpersisted chat reply. The fix in
SKY-8986 decouples the handler from the SSE response lifecycle: the handler
runs to completion even after the client goes away, and subsequent send()
calls drop their payload silently instead of backing up the in-memory queue.
"""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.routes.event_source_stream import (
_BACKGROUND_HANDLER_TASKS,
FastAPIEventSourceStream,
)
def _make_request(is_disconnected_values: list[bool]) -> MagicMock:
    """Return a stand-in Starlette ``Request`` with a scripted ``is_disconnected()``.

    Args:
        is_disconnected_values: Results handed out, one per awaited call and
            in order, by the fake request's ``is_disconnected()``.

    Returns:
        A ``MagicMock`` whose ``is_disconnected`` attribute is an ``AsyncMock``
        driven by ``is_disconnected_values``.
    """
    scripted_probe = AsyncMock(side_effect=is_disconnected_values)
    return MagicMock(is_disconnected=scripted_probe)
@pytest.mark.asyncio
async def test_send_drops_events_silently_after_disconnect() -> None:
    """send() keeps reporting success after a disconnect but stops enqueueing.

    The scripted request answers "connected" to the first probe and
    "disconnected" to every later one, so only the first payload may reach
    the queue. If the later payloads were queued as well, a long-running
    handler would grow the queue without bound because nothing is left to
    drain it.
    """
    stream = FastAPIEventSourceStream(_make_request([False, True, True, True]))

    outcomes = [await stream.send({"n": seq}) for seq in (1, 2, 3)]

    # From the caller's point of view every send was accepted...
    assert all(outcome is True for outcome in outcomes)
    # ...but only the payload sent while still connected was enqueued.
    assert stream._queue.qsize() == 1
@pytest.mark.asyncio
async def test_is_disconnected_latches_after_first_positive() -> None:
    """A positive disconnect result is cached and the request is not probed again.

    ``Request.is_disconnected`` reads from the ASGI receive channel, which
    may no longer be live once the response has been torn down. After the
    first positive answer the stream must answer from its own cache rather
    than going back to the request.
    """
    # Script: the first probe reports a disconnect; a second probe of the
    # underlying request would raise inside the mock. The await-count check
    # below proves that second probe never happens.
    probe = AsyncMock(side_effect=[True, RuntimeError("channel closed")])
    request = MagicMock()
    request.is_disconnected = probe
    stream = FastAPIEventSourceStream(request)

    for _ in range(2):
        assert await stream.is_disconnected() is True

    # Exactly one underlying call: the second answer came from the cache.
    assert probe.await_count == 1
@pytest.mark.asyncio
async def test_is_disconnected_treats_exception_as_disconnect() -> None:
    """An error while probing the request is reported as "client gone".

    Handlers that outlive the torn-down response still poll
    ``is_disconnected``; a stale receive channel must yield ``True`` instead
    of an exception escaping into the handler.
    """
    failing_probe = AsyncMock(side_effect=RuntimeError("closed"))
    stream = FastAPIEventSourceStream(MagicMock(is_disconnected=failing_probe))

    verdict = await stream.is_disconnected()

    assert verdict is True
@pytest.mark.asyncio
async def test_handler_runs_to_completion_after_sse_generator_exits() -> None:
    """SKY-8986 regression: a client disconnect must not cancel the handler.

    The fake client is already disconnected when the response is created,
    so the SSE generator has nothing to deliver. The handler must still run
    all the way through in the background; the earlier implementation
    cancelled the handler task from the generator's ``finally`` block, which
    killed an in-flight copilot agent.
    """
    gone_request = MagicMock()
    gone_request.is_disconnected = AsyncMock(return_value=True)

    handler_finished = asyncio.Event()

    async def handler(stream: Any) -> None:
        # Stand-in for agent work that keeps going after the SSE generator
        # has decided the client is gone: pause, emit, pause, emit.
        for stage in ("halfway", "done"):
            await asyncio.sleep(0.01)
            await stream.send({"progress": stage})
        handler_finished.set()

    response = FastAPIEventSourceStream.create(gone_request, handler)

    # Drain the response body. With the client already gone it is expected
    # to end straight away while the handler carries on in the background.
    async for _ in response.body_iterator:
        pass

    # The only hard requirement: the handler eventually finishes.
    await asyncio.wait_for(handler_finished.wait(), timeout=2.0)

    # Yield to the loop once so the task's done-callback gets a chance to
    # run before the registry is inspected.
    await asyncio.sleep(0)

    # Nothing still running is left behind in the registry.
    assert all(task.done() for task in _BACKGROUND_HANDLER_TASKS)
@pytest.mark.asyncio
async def test_handler_exception_does_not_break_other_streams() -> None:
    """A handler that raises after disconnect must be contained, not propagated.

    After a disconnect the handler runs as a background task. The test
    checks three things: the handler really was invoked, draining the
    response body does not surface the handler's exception, and no
    still-running task is left in the background registry afterwards.
    """
    request = MagicMock()
    request.is_disconnected = AsyncMock(return_value=True)

    # Set by the handler just before it raises. Without this the test would
    # pass vacuously if the handler were never scheduled, because the
    # registry check below is trivially true for an empty registry.
    handler_invoked = asyncio.Event()

    async def handler(stream: Any) -> None:
        handler_invoked.set()
        raise RuntimeError("boom")

    response = FastAPIEventSourceStream.create(request, handler)

    # Reaching the end of this loop means no exception propagated out of
    # the response iteration.
    async for _ in response.body_iterator:
        pass

    # Bounded wait: fails with a timeout instead of silently passing when
    # the handler never ran.
    await asyncio.wait_for(handler_invoked.wait(), timeout=2.0)

    # Grace period for the handler task to unwind and for its done-callback
    # to fire.
    await asyncio.sleep(0.05)

    # Every task still referenced by the registry has finished.
    assert all(task.done() for task in _BACKGROUND_HANDLER_TASKS)

View file

@ -15,6 +15,7 @@ real database -- all DB / LLM / agent surfaces are patched.
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -144,7 +145,7 @@ def _setup_new_copilot_mocks(
get_workflow_copilot_chat_messages=AsyncMock(return_value=[]), get_workflow_copilot_chat_messages=AsyncMock(return_value=[]),
update_workflow_copilot_chat=AsyncMock(), update_workflow_copilot_chat=AsyncMock(),
create_workflow_copilot_chat_message=AsyncMock( create_workflow_copilot_chat_message=AsyncMock(
return_value=SimpleNamespace(created_at=SimpleNamespace(isoformat=lambda: "2026-04-14T00:00:00Z")) return_value=SimpleNamespace(created_at=datetime(2026, 4, 14, tzinfo=timezone.utc))
), ),
) )
app.DATABASE.workflows = SimpleNamespace( app.DATABASE.workflows = SimpleNamespace(