copilot: guarantee a terminal SSE frame on every chat turn (#5645)

This commit is contained in:
Andrew Neilson 2026-04-24 09:38:58 -07:00 committed by GitHub
parent c8915d72c0
commit 2809885b53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 84 additions and 0 deletions

View file

@ -82,6 +82,27 @@ class BlockRunInfo:
output: str | None
async def _ensure_terminal_frame(stream: EventSourceStream, already_emitted: bool) -> None:
"""Emit a fallback ERROR frame if the turn hasn't sent a terminal one.
Shielded so cancellation on the outer scope doesn't abort the send;
swallows BaseException so a failed cleanup never masks the original.
"""
if already_emitted:
return
try:
await asyncio.shield(
stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error="The assistant didn't finish this turn. Please try again.",
)
)
)
except BaseException:
pass
def _should_restore_persisted_workflow(auto_accept: bool | None, agent_result: object | None) -> bool:
"""Return True when a persisted draft should be rolled back.
@ -741,6 +762,7 @@ async def _new_copilot_chat_post(
original_workflow: Workflow | None = None
chat = None
agent_result: Any = None
terminal_frame_emitted = False
try:
await stream.send(
@ -903,6 +925,7 @@ async def _new_copilot_chat_post(
global_llm_context=updated_global_llm_context,
)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamResponseUpdate(
type=WorkflowCopilotStreamMessageType.RESPONSE,
@ -917,6 +940,7 @@ async def _new_copilot_chat_post(
except HTTPException as exc:
if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
await _restore_workflow_definition(original_workflow, organization.organization_id)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
@ -932,6 +956,7 @@ async def _new_copilot_chat_post(
error=str(exc),
exc_info=True,
)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
@ -954,12 +979,15 @@ async def _new_copilot_chat_post(
error=str(exc),
exc_info=True,
)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error="An error occurred. Please try again.",
)
)
finally:
await _ensure_terminal_frame(stream, terminal_frame_emitted)
return FastAPIEventSourceStream.create(request, stream_handler)
@ -1008,6 +1036,7 @@ async def workflow_copilot_chat_post(
organization_id=organization.organization_id,
)
terminal_frame_emitted = False
try:
await stream.send(
WorkflowCopilotProcessingUpdate(
@ -1096,6 +1125,7 @@ async def workflow_copilot_chat_post(
global_llm_context=updated_global_llm_context,
)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamResponseUpdate(
type=WorkflowCopilotStreamMessageType.RESPONSE,
@ -1106,6 +1136,7 @@ async def workflow_copilot_chat_post(
)
)
except HTTPException as exc:
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
@ -1119,6 +1150,7 @@ async def workflow_copilot_chat_post(
error=str(exc),
exc_info=True,
)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
@ -1132,12 +1164,15 @@ async def workflow_copilot_chat_post(
error=str(exc),
exc_info=True,
)
terminal_frame_emitted = True
await stream.send(
WorkflowCopilotStreamErrorUpdate(
type=WorkflowCopilotStreamMessageType.ERROR,
error="An error occurred. Please try again.",
)
)
finally:
await _ensure_terminal_frame(stream, terminal_frame_emitted)
return FastAPIEventSourceStream.create(request, stream_handler)

View file

@ -0,0 +1,49 @@
"""Tests for the copilot SSE terminal-frame invariant (SKY-9232)."""
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from skyvern.forge.sdk.routes.workflow_copilot import _ensure_terminal_frame
class _FakeStream:
def __init__(self, raise_on_send: BaseException | None = None) -> None:
self.sent: list[Any] = []
self._raise_on_send = raise_on_send
async def send(self, message: Any) -> None:
if self._raise_on_send is not None:
raise self._raise_on_send
self.sent.append(message)
@pytest.mark.asyncio
async def test_ensure_terminal_frame_noop_when_already_emitted() -> None:
stream = _FakeStream()
await _ensure_terminal_frame(stream, already_emitted=True) # type: ignore[arg-type]
assert stream.sent == []
@pytest.mark.asyncio
async def test_ensure_terminal_frame_sends_fallback_error_when_missing() -> None:
stream = _FakeStream()
await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]
assert len(stream.sent) == 1
frame = stream.sent[0]
assert getattr(frame, "error", "").startswith("The assistant didn't finish")
@pytest.mark.asyncio
async def test_ensure_terminal_frame_swallows_send_exception() -> None:
stream = _FakeStream(raise_on_send=RuntimeError("client already gone"))
await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]
@pytest.mark.asyncio
async def test_ensure_terminal_frame_swallows_send_cancellation() -> None:
stream = _FakeStream(raise_on_send=asyncio.CancelledError())
await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]