fix(SKY-8986): stop SSE disconnect from killing the copilot agent (#5560)

This commit is contained in:
Andrew Neilson 2026-04-20 11:54:32 -07:00 committed by GitHub
parent e63689c981
commit 2fa21f8799
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 334 additions and 140 deletions

View file

@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.copilot.exceptions import CopilotClientDisconnectedError
from skyvern.forge.sdk.copilot.streaming_adapter import _sanitize_input, stream_to_sse
@ -85,38 +84,50 @@ async def _stream_events_from(*events: Any) -> Any:
@pytest.mark.asyncio
async def test_stream_to_sse_raises_when_send_reports_disconnect() -> None:
async def test_stream_to_sse_keeps_running_after_client_disconnect() -> None:
"""SKY-8986 regression: a dropped SSE client must NOT cancel the agent run.
The handler task outlives the SSE response so the agent's reply can be
persisted to the chat history. stream_to_sse keeps draining the SDK's
event stream; emissions turn into no-ops when is_disconnected() returns
True, but result.cancel() is never called and no exception escapes.
"""
from agents.items import RunItem
from agents.stream_events import RunItemStreamEvent
raw = {"call_id": "c1", "name": "click", "arguments": "{}"}
item = MagicMock(spec=RunItem)
item.raw_item = raw
event = RunItemStreamEvent(name="tool_called", item=item)
raw_call = {"call_id": "c1", "name": "click", "arguments": "{}"}
call_item = MagicMock(spec=RunItem)
call_item.raw_item = raw_call
tool_call = RunItemStreamEvent(name="tool_called", item=call_item)
raw_output = {"call_id": "c1"}
output_item = MagicMock(spec=RunItem)
output_item.raw_item = raw_output
output_item.output = None
tool_output = RunItemStreamEvent(name="tool_output", item=output_item)
result = MagicMock()
result.stream_events = lambda: _stream_events_from(event)
result.stream_events = lambda: _stream_events_from(tool_call, tool_output)
result.cancel = MagicMock()
stream = MagicMock()
stream.is_disconnected = AsyncMock(return_value=False)
stream.send = AsyncMock(return_value=False)
stream.is_disconnected = AsyncMock(return_value=True)
stream.send = AsyncMock(return_value=True)
ctx = SimpleNamespace()
with pytest.raises(CopilotClientDisconnectedError):
await stream_to_sse(result, stream, ctx)
await stream_to_sse(result, stream, ctx)
result.cancel.assert_called_once()
result.cancel.assert_not_called()
stream.send.assert_not_called()
@pytest.mark.asyncio
async def test_stream_to_sse_propagates_cancelled_error() -> None:
"""A generic asyncio.CancelledError must not be relabeled as a client
disconnect. Relabeling would silence cancellation (CopilotClientDisconnectedError
is a plain Exception, CancelledError is BaseException) and break the event
loop's cancellation machinery for non-disconnect cancels such as task-group
cancel, upstream timeout, or parent abort.
"""A generic asyncio.CancelledError must propagate up from stream_to_sse so
the event loop's cancellation machinery still works for task-group cancel,
upstream timeout, or parent abort. The adapter must not catch it and turn
it into a normal return.
"""
async def _raises_cancelled() -> Any: