mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
copilot: guarantee a terminal SSE frame on every chat turn (#5645)
This commit is contained in:
parent
c8915d72c0
commit
2809885b53
2 changed files with 84 additions and 0 deletions
|
|
@ -82,6 +82,27 @@ class BlockRunInfo:
|
|||
output: str | None
|
||||
|
||||
|
||||
async def _ensure_terminal_frame(stream: EventSourceStream, already_emitted: bool) -> None:
|
||||
"""Emit a fallback ERROR frame if the turn hasn't sent a terminal one.
|
||||
|
||||
Shielded so cancellation on the outer scope doesn't abort the send;
|
||||
swallows BaseException so a failed cleanup never masks the original.
|
||||
"""
|
||||
if already_emitted:
|
||||
return
|
||||
try:
|
||||
await asyncio.shield(
|
||||
stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
error="The assistant didn't finish this turn. Please try again.",
|
||||
)
|
||||
)
|
||||
)
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
|
||||
def _should_restore_persisted_workflow(auto_accept: bool | None, agent_result: object | None) -> bool:
|
||||
"""Return True when a persisted draft should be rolled back.
|
||||
|
||||
|
|
@ -741,6 +762,7 @@ async def _new_copilot_chat_post(
|
|||
original_workflow: Workflow | None = None
|
||||
chat = None
|
||||
agent_result: Any = None
|
||||
terminal_frame_emitted = False
|
||||
|
||||
try:
|
||||
await stream.send(
|
||||
|
|
@ -903,6 +925,7 @@ async def _new_copilot_chat_post(
|
|||
global_llm_context=updated_global_llm_context,
|
||||
)
|
||||
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamResponseUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.RESPONSE,
|
||||
|
|
@ -917,6 +940,7 @@ async def _new_copilot_chat_post(
|
|||
except HTTPException as exc:
|
||||
if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
|
||||
await _restore_workflow_definition(original_workflow, organization.organization_id)
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
|
|
@ -932,6 +956,7 @@ async def _new_copilot_chat_post(
|
|||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
|
|
@ -954,12 +979,15 @@ async def _new_copilot_chat_post(
|
|||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
error="An error occurred. Please try again.",
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await _ensure_terminal_frame(stream, terminal_frame_emitted)
|
||||
|
||||
return FastAPIEventSourceStream.create(request, stream_handler)
|
||||
|
||||
|
|
@ -1008,6 +1036,7 @@ async def workflow_copilot_chat_post(
|
|||
organization_id=organization.organization_id,
|
||||
)
|
||||
|
||||
terminal_frame_emitted = False
|
||||
try:
|
||||
await stream.send(
|
||||
WorkflowCopilotProcessingUpdate(
|
||||
|
|
@ -1096,6 +1125,7 @@ async def workflow_copilot_chat_post(
|
|||
global_llm_context=updated_global_llm_context,
|
||||
)
|
||||
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamResponseUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.RESPONSE,
|
||||
|
|
@ -1106,6 +1136,7 @@ async def workflow_copilot_chat_post(
|
|||
)
|
||||
)
|
||||
except HTTPException as exc:
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
|
|
@ -1119,6 +1150,7 @@ async def workflow_copilot_chat_post(
|
|||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
|
|
@ -1132,12 +1164,15 @@ async def workflow_copilot_chat_post(
|
|||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
terminal_frame_emitted = True
|
||||
await stream.send(
|
||||
WorkflowCopilotStreamErrorUpdate(
|
||||
type=WorkflowCopilotStreamMessageType.ERROR,
|
||||
error="An error occurred. Please try again.",
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await _ensure_terminal_frame(stream, terminal_frame_emitted)
|
||||
|
||||
return FastAPIEventSourceStream.create(request, stream_handler)
|
||||
|
||||
|
|
|
|||
49
tests/unit/test_copilot_sse_terminal_frame.py
Normal file
49
tests/unit/test_copilot_sse_terminal_frame.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""Tests for the copilot SSE terminal-frame invariant (SKY-9232)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.routes.workflow_copilot import _ensure_terminal_frame
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
def __init__(self, raise_on_send: BaseException | None = None) -> None:
|
||||
self.sent: list[Any] = []
|
||||
self._raise_on_send = raise_on_send
|
||||
|
||||
async def send(self, message: Any) -> None:
|
||||
if self._raise_on_send is not None:
|
||||
raise self._raise_on_send
|
||||
self.sent.append(message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_terminal_frame_noop_when_already_emitted() -> None:
|
||||
stream = _FakeStream()
|
||||
await _ensure_terminal_frame(stream, already_emitted=True) # type: ignore[arg-type]
|
||||
assert stream.sent == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_terminal_frame_sends_fallback_error_when_missing() -> None:
|
||||
stream = _FakeStream()
|
||||
await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]
|
||||
assert len(stream.sent) == 1
|
||||
frame = stream.sent[0]
|
||||
assert getattr(frame, "error", "").startswith("The assistant didn't finish")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_terminal_frame_swallows_send_exception() -> None:
|
||||
stream = _FakeStream(raise_on_send=RuntimeError("client already gone"))
|
||||
await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_terminal_frame_swallows_send_cancellation() -> None:
|
||||
stream = _FakeStream(raise_on_send=asyncio.CancelledError())
|
||||
await _ensure_terminal_frame(stream, already_emitted=False) # type: ignore[arg-type]
|
||||
Loading…
Add table
Add a link
Reference in a new issue