From 2809885b53fb0b6b9af294bf261c1c6a39d1b0a5 Mon Sep 17 00:00:00 2001 From: Andrew Neilson Date: Fri, 24 Apr 2026 09:38:58 -0700 Subject: [PATCH] copilot: guarantee a terminal SSE frame on every chat turn (#5645) --- skyvern/forge/sdk/routes/workflow_copilot.py | 35 +++++++++++++ tests/unit/test_copilot_sse_terminal_frame.py | 49 +++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 tests/unit/test_copilot_sse_terminal_frame.py diff --git a/skyvern/forge/sdk/routes/workflow_copilot.py b/skyvern/forge/sdk/routes/workflow_copilot.py index 96500e739..cf51467c4 100644 --- a/skyvern/forge/sdk/routes/workflow_copilot.py +++ b/skyvern/forge/sdk/routes/workflow_copilot.py @@ -82,6 +82,27 @@ class BlockRunInfo: output: str | None +async def _ensure_terminal_frame(stream: EventSourceStream, already_emitted: bool) -> None: + """Emit a fallback ERROR frame if the turn hasn't sent a terminal one. + + Shielded so cancellation on the outer scope doesn't abort the send; + swallows BaseException so a failed cleanup never masks the original. + """ + if already_emitted: + return + try: + await asyncio.shield( + stream.send( + WorkflowCopilotStreamErrorUpdate( + type=WorkflowCopilotStreamMessageType.ERROR, + error="The assistant didn't finish this turn. Please try again.", + ) + ) + ) + except BaseException: + pass + + def _should_restore_persisted_workflow(auto_accept: bool | None, agent_result: object | None) -> bool: """Return True when a persisted draft should be rolled back. @@ -741,6 +762,7 @@ async def _new_copilot_chat_post( original_workflow: Workflow | None = None chat = None agent_result: Any = None + terminal_frame_emitted = False try: await stream.send( @@ -903,6 +925,7 @@ async def _new_copilot_chat_post( global_llm_context=updated_global_llm_context, ) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamResponseUpdate( type=WorkflowCopilotStreamMessageType.RESPONSE, @@ -917,6 +940,7 @@ async def _new_copilot_chat_post( except HTTPException as exc: if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result): await _restore_workflow_definition(original_workflow, organization.organization_id) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, @@ -932,6 +956,7 @@ async def _new_copilot_chat_post( error=str(exc), exc_info=True, ) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, @@ -954,12 +979,15 @@ async def _new_copilot_chat_post( error=str(exc), exc_info=True, ) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again.", ) ) + finally: + await _ensure_terminal_frame(stream, terminal_frame_emitted) return FastAPIEventSourceStream.create(request, stream_handler) @@ -1008,6 +1036,7 @@ async def workflow_copilot_chat_post( organization_id=organization.organization_id, ) + terminal_frame_emitted = False try: await stream.send( WorkflowCopilotProcessingUpdate( @@ -1096,6 +1125,7 @@ async def workflow_copilot_chat_post( global_llm_context=updated_global_llm_context, ) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamResponseUpdate( type=WorkflowCopilotStreamMessageType.RESPONSE, @@ -1106,6 +1136,7 @@ async def workflow_copilot_chat_post( ) ) except HTTPException as exc: + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, @@ -1119,6 +1150,7 @@ async def workflow_copilot_chat_post( error=str(exc), exc_info=True, ) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, @@ -1132,12 +1164,15 @@ async def workflow_copilot_chat_post( error=str(exc), exc_info=True, ) + terminal_frame_emitted = True await stream.send( WorkflowCopilotStreamErrorUpdate( type=WorkflowCopilotStreamMessageType.ERROR, error="An error occurred. Please try again.", ) ) + finally: + await _ensure_terminal_frame(stream, terminal_frame_emitted) return FastAPIEventSourceStream.create(request, stream_handler) diff --git a/tests/unit/test_copilot_sse_terminal_frame.py b/tests/unit/test_copilot_sse_terminal_frame.py new file mode 100644 index 000000000..809095f7b --- /dev/null +++ b/tests/unit/test_copilot_sse_terminal_frame.py @@ -0,0 +1,49 @@ +"""Tests for the copilot SSE terminal-frame invariant (SKY-9232).""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest + +from skyvern.forge.sdk.routes.workflow_copilot import _ensure_terminal_frame + + +class _FakeStream: + def __init__(self, raise_on_send: BaseException | None = None) -> None: + self.sent: list[Any] = [] + self._raise_on_send = raise_on_send + + async def send(self, message: Any) -> None: + if self._raise_on_send is not None: + raise self._raise_on_send + self.sent.append(message) + + +@pytest.mark.asyncio +async def test_ensure_terminal_frame_noop_when_already_emitted() -> None: + stream = _FakeStream() + await _ensure_terminal_frame(stream, already_emitted=True) # type: ignore[arg-type] + assert stream.sent == [] + + +@pytest.mark.asyncio +async def test_ensure_terminal_frame_sends_fallback_error_when_missing() -> None: + stream = _FakeStream() + await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type] + assert len(stream.sent) == 1 + frame = stream.sent[0] + assert getattr(frame, "error", "").startswith("The assistant didn't finish") + + +@pytest.mark.asyncio +async def test_ensure_terminal_frame_swallows_send_exception() -> None: + stream = _FakeStream(raise_on_send=RuntimeError("client already gone")) + await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_ensure_terminal_frame_swallows_send_cancellation() -> None: + stream = _FakeStream(raise_on_send=asyncio.CancelledError()) + await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]