mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
fix(SKY-8986): stop SSE disconnect from killing the copilot agent (#5560)
This commit is contained in:
parent
e63689c981
commit
2fa21f8799
10 changed files with 334 additions and 140 deletions
|
|
@ -281,7 +281,6 @@ async def run_copilot_agent(
|
||||||
|
|
||||||
from skyvern.cli.mcp_tools import mcp as skyvern_mcp
|
from skyvern.cli.mcp_tools import mcp as skyvern_mcp
|
||||||
from skyvern.forge.sdk.copilot.enforcement import (
|
from skyvern.forge.sdk.copilot.enforcement import (
|
||||||
CopilotClientDisconnectedError,
|
|
||||||
CopilotTotalTimeoutError,
|
CopilotTotalTimeoutError,
|
||||||
run_with_enforcement,
|
run_with_enforcement,
|
||||||
)
|
)
|
||||||
|
|
@ -392,8 +391,8 @@ async def run_copilot_agent(
|
||||||
chat_request,
|
chat_request,
|
||||||
organization_id,
|
organization_id,
|
||||||
)
|
)
|
||||||
except (CopilotClientDisconnectedError, asyncio.CancelledError):
|
except asyncio.CancelledError:
|
||||||
LOG.info("Copilot client disconnected")
|
LOG.info("Copilot run cancelled")
|
||||||
return AgentResult(
|
return AgentResult(
|
||||||
user_response="Request cancelled.",
|
user_response="Request cancelled.",
|
||||||
updated_workflow=ctx.last_workflow,
|
updated_workflow=ctx.last_workflow,
|
||||||
|
|
|
||||||
|
|
@ -246,10 +246,6 @@ def _is_progress_narration(user_response: Any) -> bool:
|
||||||
return any(pattern.search(user_response) for pattern in _PROGRESS_NARRATION_PATTERNS)
|
return any(pattern.search(user_response) for pattern in _PROGRESS_NARRATION_PATTERNS)
|
||||||
|
|
||||||
|
|
||||||
class CopilotClientDisconnectedError(Exception):
|
|
||||||
"""Raised when the client disconnects during agent execution."""
|
|
||||||
|
|
||||||
|
|
||||||
class CopilotTotalTimeoutError(Exception):
|
class CopilotTotalTimeoutError(Exception):
|
||||||
"""Raised when the copilot agent exceeds the total allowed runtime."""
|
"""Raised when the copilot agent exceeds the total allowed runtime."""
|
||||||
|
|
||||||
|
|
@ -853,9 +849,10 @@ async def run_with_enforcement(
|
||||||
pending_recovery_nudge: str | None = None
|
pending_recovery_nudge: str | None = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if await stream.is_disconnected():
|
# Client disconnect is no longer treated as a stop signal. The
|
||||||
raise CopilotClientDisconnectedError()
|
# SSE stream silently drops events once the browser is gone, but
|
||||||
|
# the agent keeps running so the reply can be persisted to the
|
||||||
|
# chat history on the server side (see SKY-8986).
|
||||||
elapsed = time.monotonic() - start_time
|
elapsed = time.monotonic() - start_time
|
||||||
if elapsed > TOTAL_TIMEOUT_SECONDS:
|
if elapsed > TOTAL_TIMEOUT_SECONDS:
|
||||||
raise CopilotTotalTimeoutError()
|
raise CopilotTotalTimeoutError()
|
||||||
|
|
@ -924,9 +921,6 @@ async def run_with_enforcement(
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if await stream.is_disconnected():
|
|
||||||
raise CopilotClientDisconnectedError()
|
|
||||||
|
|
||||||
# Inject pending screenshots as a follow-up user message because OpenAI
|
# Inject pending screenshots as a follow-up user message because OpenAI
|
||||||
# rejects images in tool messages.
|
# rejects images in tool messages.
|
||||||
screenshot_msg = _consume_pending_screenshots(ctx)
|
screenshot_msg = _consume_pending_screenshots(ctx)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from skyvern.exceptions import SkyvernException
|
|
||||||
|
|
||||||
|
|
||||||
class CopilotClientDisconnectedError(SkyvernException):
|
|
||||||
"""Raised when the SSE client disconnects during a streaming copilot run."""
|
|
||||||
|
|
@ -11,7 +11,6 @@ import structlog
|
||||||
# Reuse the HTTP-logging redactor so SSE tool inputs and request-body logs
|
# Reuse the HTTP-logging redactor so SSE tool inputs and request-body logs
|
||||||
# share one exact-match sensitive-key policy.
|
# share one exact-match sensitive-key policy.
|
||||||
from skyvern.forge.request_logging import redact_sensitive_fields
|
from skyvern.forge.request_logging import redact_sensitive_fields
|
||||||
from skyvern.forge.sdk.copilot.exceptions import CopilotClientDisconnectedError
|
|
||||||
from skyvern.forge.sdk.copilot.output_utils import summarize_tool_result
|
from skyvern.forge.sdk.copilot.output_utils import summarize_tool_result
|
||||||
from skyvern.forge.sdk.schemas.workflow_copilot import (
|
from skyvern.forge.sdk.schemas.workflow_copilot import (
|
||||||
WorkflowCopilotStreamMessageType,
|
WorkflowCopilotStreamMessageType,
|
||||||
|
|
@ -51,13 +50,15 @@ async def stream_to_sse(
|
||||||
``post_update_nudge_count``, ``navigate_called``, and
|
``post_update_nudge_count``, ``navigate_called``, and
|
||||||
``observation_after_navigate``.
|
``observation_after_navigate``.
|
||||||
|
|
||||||
A true client disconnect -- detected by ``stream.is_disconnected()`` or a
|
A client disconnect does NOT cancel the agent run: we continue to iterate
|
||||||
failed ``stream.send()`` -- cancels the agent run and raises
|
``result.stream_events()`` so the agent completes whatever work it is
|
||||||
``CopilotClientDisconnectedError``. A plain ``asyncio.CancelledError``
|
in the middle of and the caller can persist the reply to the DB. Events
|
||||||
from some other source (task-group cancel, upstream timeout, parent
|
sent through ``stream.send`` after disconnect are silently dropped by the
|
||||||
abort) is allowed to propagate unchanged so asyncio's cancellation
|
stream, so the queue cannot grow unbounded.
|
||||||
machinery runs normally; callers that want the two to share a UX path
|
|
||||||
should catch both at the call site.
|
Real asyncio cancellation (server shutdown, parent task cancelled for
|
||||||
|
reasons unrelated to a dropped client) is re-raised unchanged so
|
||||||
|
asyncio's cancellation machinery still runs normally.
|
||||||
"""
|
"""
|
||||||
from agents.stream_events import RunItemStreamEvent
|
from agents.stream_events import RunItemStreamEvent
|
||||||
|
|
||||||
|
|
@ -69,39 +70,40 @@ async def stream_to_sse(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for event in result.stream_events():
|
async for event in result.stream_events():
|
||||||
if await stream.is_disconnected():
|
|
||||||
result.cancel()
|
|
||||||
raise CopilotClientDisconnectedError()
|
|
||||||
if not isinstance(event, RunItemStreamEvent):
|
if not isinstance(event, RunItemStreamEvent):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip emission work (serialization, redaction) once the client
|
||||||
|
# is gone, but keep draining the SDK stream so the agent can
|
||||||
|
# finish. stream.send below would drop the payload anyway.
|
||||||
|
client_gone = await stream.is_disconnected()
|
||||||
|
|
||||||
if event.name == "tool_called":
|
if event.name == "tool_called":
|
||||||
raw = event.item.raw_item
|
raw = event.item.raw_item
|
||||||
call_id = _get_raw_field(raw, "call_id") or _get_raw_field(raw, "id") or ""
|
call_id = _get_raw_field(raw, "call_id") or _get_raw_field(raw, "id") or ""
|
||||||
tool_name = _get_raw_field(raw, "name") or "unknown"
|
tool_name = _get_raw_field(raw, "name") or "unknown"
|
||||||
call_id_to_name[call_id] = tool_name
|
call_id_to_name[call_id] = tool_name
|
||||||
|
|
||||||
raw_args = _get_raw_field(raw, "arguments")
|
if not client_gone:
|
||||||
tool_input: dict[str, Any] = {}
|
raw_args = _get_raw_field(raw, "arguments")
|
||||||
if isinstance(raw_args, str):
|
tool_input: dict[str, Any] = {}
|
||||||
try:
|
if isinstance(raw_args, str):
|
||||||
tool_input = json.loads(raw_args)
|
try:
|
||||||
except (json.JSONDecodeError, TypeError):
|
tool_input = json.loads(raw_args)
|
||||||
tool_input = {"raw": raw_args}
|
except (json.JSONDecodeError, TypeError):
|
||||||
elif isinstance(raw_args, dict):
|
tool_input = {"raw": raw_args}
|
||||||
tool_input = raw_args
|
elif isinstance(raw_args, dict):
|
||||||
|
tool_input = raw_args
|
||||||
|
|
||||||
if not await stream.send(
|
await stream.send(
|
||||||
WorkflowCopilotToolCallUpdate(
|
WorkflowCopilotToolCallUpdate(
|
||||||
type=WorkflowCopilotStreamMessageType.TOOL_CALL,
|
type=WorkflowCopilotStreamMessageType.TOOL_CALL,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
tool_input=_sanitize_input(tool_input),
|
tool_input=_sanitize_input(tool_input),
|
||||||
iteration=iteration,
|
iteration=iteration,
|
||||||
tool_call_id=call_id,
|
tool_call_id=call_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
):
|
|
||||||
result.cancel()
|
|
||||||
raise CopilotClientDisconnectedError()
|
|
||||||
|
|
||||||
elif event.name == "tool_output":
|
elif event.name == "tool_output":
|
||||||
raw = event.item.raw_item
|
raw = event.item.raw_item
|
||||||
|
|
@ -110,31 +112,27 @@ async def stream_to_sse(
|
||||||
|
|
||||||
output = getattr(event.item, "output", None)
|
output = getattr(event.item, "output", None)
|
||||||
parsed = parse_tool_output(output)
|
parsed = parse_tool_output(output)
|
||||||
summary = summarize_tool_result(tool_name, parsed)
|
|
||||||
success = parsed.get("ok", True)
|
|
||||||
|
|
||||||
if not await stream.send(
|
if not client_gone:
|
||||||
WorkflowCopilotToolResultUpdate(
|
summary = summarize_tool_result(tool_name, parsed)
|
||||||
type=WorkflowCopilotStreamMessageType.TOOL_RESULT,
|
success = parsed.get("ok", True)
|
||||||
tool_name=tool_name,
|
await stream.send(
|
||||||
success=success,
|
WorkflowCopilotToolResultUpdate(
|
||||||
summary=summary,
|
type=WorkflowCopilotStreamMessageType.TOOL_RESULT,
|
||||||
iteration=iteration,
|
tool_name=tool_name,
|
||||||
tool_call_id=call_id,
|
success=success,
|
||||||
|
summary=summary,
|
||||||
|
iteration=iteration,
|
||||||
|
tool_call_id=call_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
):
|
|
||||||
result.cancel()
|
|
||||||
raise CopilotClientDisconnectedError()
|
|
||||||
|
|
||||||
_update_enforcement_from_tool(ctx, tool_name, parsed)
|
_update_enforcement_from_tool(ctx, tool_name, parsed)
|
||||||
iteration += 1
|
iteration += 1
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Don't relabel generic cancellation as a client disconnect -- the
|
# Real cancellation (server shutdown, upstream abort). Propagate so
|
||||||
# inline is_disconnected / send-failure branches above already raise
|
# asyncio's task machinery sees the cancel; also cancel the SDK
|
||||||
# CopilotClientDisconnectedError when there is real evidence of a
|
# run to free provider resources.
|
||||||
# dropped client. Preserve cancellation semantics by re-raising so
|
|
||||||
# asyncio's task machinery sees the cancel and any upstream
|
|
||||||
# except Exception does NOT swallow it.
|
|
||||||
result.cancel()
|
result.cancel()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -826,9 +826,12 @@ async def _run_blocks_and_collect_debug(
|
||||||
final_status = run.status
|
final_status = run.status
|
||||||
break
|
break
|
||||||
|
|
||||||
if await ctx.stream.is_disconnected():
|
# Deliberately do NOT short-circuit on client disconnect here.
|
||||||
await _cancel_run_task_if_not_final(run_task, workflow_run.workflow_run_id)
|
# The agent loop is allowed to run to completion after the SSE
|
||||||
return {"ok": False, "error": "Client disconnected during block execution."}
|
# stream is gone (see SKY-8986) so its reply can be persisted;
|
||||||
|
# aborting an in-flight block execution would leave the
|
||||||
|
# workflow run in a limbo state and the agent would have no
|
||||||
|
# debug output to summarize in the final chat message.
|
||||||
|
|
||||||
run = await app.DATABASE.workflow_runs.get_workflow_run(
|
run = await app.DATABASE.workflow_runs.get_workflow_run(
|
||||||
workflow_run_id=workflow_run.workflow_run_id,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,22 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Awaitable, Callable, Protocol
|
from typing import Any, Awaitable, Callable, Protocol
|
||||||
|
|
||||||
|
import structlog
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
|
from sse_starlette import EventSourceResponse, JSONServerSentEvent, ServerSentEvent
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10
|
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 10
|
||||||
|
|
||||||
|
# Strong references to handler tasks that outlive their SSE response. When a
|
||||||
|
# client disconnects mid-stream, we let the handler keep running so any
|
||||||
|
# in-flight work (agent loops, DB persistence) finishes cleanly. Without a
|
||||||
|
# strong reference, the event loop could garbage-collect the task while it
|
||||||
|
# is still running.
|
||||||
|
_BACKGROUND_HANDLER_TASKS: set["asyncio.Task[None]"] = set()
|
||||||
|
|
||||||
|
|
||||||
class EventSourceStream(Protocol):
|
class EventSourceStream(Protocol):
|
||||||
"""Protocol for Server-Sent Events (SSE) streams."""
|
"""Protocol for Server-Sent Events (SSE) streams."""
|
||||||
|
|
@ -16,7 +26,11 @@ class EventSourceStream(Protocol):
|
||||||
Send data as an SSE event.
|
Send data as an SSE event.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the event was queued successfully, False if disconnected or closed.
|
True if the event was accepted (queued for delivery or dropped
|
||||||
|
because the client is gone). False only if the stream has been
|
||||||
|
explicitly closed. Callers should treat a False return as a
|
||||||
|
terminal state and stop emitting; they should NOT treat a
|
||||||
|
client disconnect as a reason to abort in-flight work.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
@ -33,8 +47,12 @@ class FastAPIEventSourceStream:
|
||||||
"""
|
"""
|
||||||
FastAPI implementation of EventSourceStream.
|
FastAPI implementation of EventSourceStream.
|
||||||
|
|
||||||
This class provides a cleaner interface for sending SSE updates from async functions
|
Sending is decoupled from client presence. When the client disconnects,
|
||||||
instead of using yield-based generators directly.
|
the handler task keeps running (so agent loops and DB persistence
|
||||||
|
complete) and subsequent send() calls silently drop their payload
|
||||||
|
instead of growing the queue. The generator stops yielding and
|
||||||
|
sse-starlette closes the HTTP response; the handler runs out in the
|
||||||
|
background.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@app.post("/stream")
|
@app.post("/stream")
|
||||||
|
|
@ -51,22 +69,52 @@ class FastAPIEventSourceStream:
|
||||||
self._request = request
|
self._request = request
|
||||||
self._queue: asyncio.Queue[Any] = asyncio.Queue()
|
self._queue: asyncio.Queue[Any] = asyncio.Queue()
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
# Latches once the client goes away so repeated is_disconnected()
|
||||||
|
# calls don't hit the ASGI receive channel after the response has
|
||||||
|
# been torn down.
|
||||||
|
self._client_gone = False
|
||||||
|
|
||||||
async def send(self, data: Any) -> bool:
|
async def send(self, data: Any) -> bool:
|
||||||
"""
|
"""Send data as an SSE event. Accepts Pydantic models or dicts.
|
||||||
Send data as an SSE event. Accepts Pydantic models or dicts.
|
|
||||||
|
|
||||||
Returns:
|
When the client is still connected, the event is queued for
|
||||||
True if the event was queued successfully, False if disconnected or closed.
|
delivery. When the client has disconnected, the event is dropped
|
||||||
|
silently so the handler can continue to completion without
|
||||||
|
growing memory. Returns False only if the stream has been
|
||||||
|
explicitly closed.
|
||||||
"""
|
"""
|
||||||
if self._closed or await self.is_disconnected():
|
if self._closed:
|
||||||
return False
|
return False
|
||||||
|
if await self.is_disconnected():
|
||||||
|
return True
|
||||||
await self._queue.put(data)
|
await self._queue.put(data)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def is_disconnected(self) -> bool:
|
async def is_disconnected(self) -> bool:
|
||||||
"""Check if the client has disconnected."""
|
"""Check if the client has disconnected.
|
||||||
return await self._request.is_disconnected()
|
|
||||||
|
Caches the first positive result so callers made after the ASGI
|
||||||
|
response has been torn down don't try to pull from a receive
|
||||||
|
channel that may no longer be live.
|
||||||
|
"""
|
||||||
|
if self._client_gone:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
disconnected = await self._request.is_disconnected()
|
||||||
|
except Exception as exc:
|
||||||
|
# Starlette's is_disconnected can raise various errors once
|
||||||
|
# the ASGI receive channel is gone (RuntimeError on closed
|
||||||
|
# queues, anyio/asyncio cancellation oddities, etc.). Treat
|
||||||
|
# any such failure as "client gone" -- we're polling to
|
||||||
|
# decide whether to skip emitting events, and if we can't
|
||||||
|
# tell we'd rather stop emitting than spin. Log so a
|
||||||
|
# genuinely new failure mode shows up in telemetry instead
|
||||||
|
# of hiding behind the disconnect path.
|
||||||
|
LOG.debug("is_disconnected raised; treating as disconnect", error=str(exc))
|
||||||
|
disconnected = True
|
||||||
|
if disconnected:
|
||||||
|
self._client_gone = True
|
||||||
|
return disconnected
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Signal that the stream is complete."""
|
"""Signal that the stream is complete."""
|
||||||
|
|
@ -111,19 +159,27 @@ class FastAPIEventSourceStream:
|
||||||
An EventSourceResponse that can be returned from a FastAPI endpoint
|
An EventSourceResponse that can be returned from a FastAPI endpoint
|
||||||
"""
|
"""
|
||||||
stream = cls(request)
|
stream = cls(request)
|
||||||
|
task = asyncio.create_task(cls._run_handler(stream, handler))
|
||||||
|
# Hold a strong reference so the event loop can't GC the task if the
|
||||||
|
# generator is torn down first (SSE client disconnect). The set is
|
||||||
|
# bounded in practice by the handler's own timeout — every copilot
|
||||||
|
# handler must eventually return (see TOTAL_TIMEOUT_SECONDS in
|
||||||
|
# skyvern/forge/sdk/copilot/enforcement.py) which removes the task
|
||||||
|
# from the set via the done_callback below. If you wire a new
|
||||||
|
# EventSource endpoint through here, give its handler a similar
|
||||||
|
# hard cap or this set becomes an unbounded leak.
|
||||||
|
_BACKGROUND_HANDLER_TASKS.add(task)
|
||||||
|
task.add_done_callback(_BACKGROUND_HANDLER_TASKS.discard)
|
||||||
|
|
||||||
async def event_generator() -> Any:
|
async def event_generator() -> Any:
|
||||||
task = asyncio.create_task(cls._run_handler(stream, handler))
|
async for event in stream._generate():
|
||||||
try:
|
yield event
|
||||||
async for event in stream._generate():
|
# Intentionally do NOT cancel the handler task here. SSE
|
||||||
yield event
|
# disconnect must not kill in-flight work (the copilot agent
|
||||||
finally:
|
# often needs tens of seconds to finish and then persist its
|
||||||
if not task.done():
|
# reply to the chat history). The handler keeps running in
|
||||||
task.cancel()
|
# the background; stream.send() silently drops events once
|
||||||
try:
|
# the client is gone so the queue cannot grow unbounded.
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def ping_message_factory() -> ServerSentEvent:
|
def ping_message_factory() -> ServerSentEvent:
|
||||||
return ServerSentEvent(comment="keep-alive")
|
return ServerSentEvent(comment="keep-alive")
|
||||||
|
|
@ -142,5 +198,11 @@ class FastAPIEventSourceStream:
|
||||||
"""Run the handler and ensure the stream is closed when done."""
|
"""Run the handler and ensure the stream is closed when done."""
|
||||||
try:
|
try:
|
||||||
await handler(stream)
|
await handler(stream)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Process/server shutdown — propagate. Client disconnect alone
|
||||||
|
# no longer cancels this task (see create()).
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
LOG.exception("SSE handler failed")
|
||||||
finally:
|
finally:
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
|
||||||
|
|
@ -795,12 +795,9 @@ async def _new_copilot_chat_post(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if await stream.is_disconnected():
|
# No early exit on disconnect (SKY-8986): the agent runs to
|
||||||
LOG.info(
|
# completion even after the SSE stream drops so its reply is
|
||||||
"Workflow copilot v2 chat request is disconnected before agent loop",
|
# persisted to the chat history and visible after reconnect.
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
original_workflow = await app.DATABASE.workflows.get_workflow_by_permanent_id(
|
original_workflow = await app.DATABASE.workflows.get_workflow_by_permanent_id(
|
||||||
workflow_permanent_id=chat_request.workflow_permanent_id,
|
workflow_permanent_id=chat_request.workflow_permanent_id,
|
||||||
|
|
@ -838,15 +835,11 @@ async def _new_copilot_chat_post(
|
||||||
updated_workflow = agent_result.updated_workflow
|
updated_workflow = agent_result.updated_workflow
|
||||||
updated_global_llm_context = agent_result.global_llm_context
|
updated_global_llm_context = agent_result.global_llm_context
|
||||||
|
|
||||||
if await stream.is_disconnected():
|
# Persist rollback / proposed-workflow state and the chat
|
||||||
LOG.info(
|
# messages regardless of whether the SSE client is still
|
||||||
"Workflow copilot v2 chat request is disconnected after agent loop",
|
# connected: the user needs to see the reply on reconnect.
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
# SKY-8986: client disconnect used to short-circuit this block
|
||||||
)
|
# and leave the chat history without the AI response.
|
||||||
if _should_restore_persisted_workflow(chat.auto_accept, agent_result):
|
|
||||||
await _restore_workflow_definition(original_workflow, organization.organization_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
if chat.auto_accept is not True:
|
if chat.auto_accept is not True:
|
||||||
if _should_restore_persisted_workflow(chat.auto_accept, agent_result):
|
if _should_restore_persisted_workflow(chat.auto_accept, agent_result):
|
||||||
await _restore_workflow_definition(original_workflow, organization.organization_id)
|
await _restore_workflow_definition(original_workflow, organization.organization_id)
|
||||||
|
|
@ -1017,13 +1010,9 @@ async def workflow_copilot_chat_post(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if await stream.is_disconnected():
|
# SKY-8986: do not short-circuit on client disconnect. The LLM
|
||||||
LOG.info(
|
# call and the DB persistence below must complete so the reply
|
||||||
"Workflow copilot chat request is disconnected before LLM call",
|
# is in the chat history when the user reconnects.
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
|
user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
|
||||||
stream,
|
stream,
|
||||||
organization.organization_id,
|
organization.organization_id,
|
||||||
|
|
@ -1033,13 +1022,6 @@ async def workflow_copilot_chat_post(
|
||||||
debug_run_info_text,
|
debug_run_info_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
if await stream.is_disconnected():
|
|
||||||
LOG.info(
|
|
||||||
"Workflow copilot chat request is disconnected after LLM call",
|
|
||||||
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if updated_workflow and chat.auto_accept is not True:
|
if updated_workflow and chat.auto_accept is not True:
|
||||||
await app.DATABASE.workflow_params.update_workflow_copilot_chat(
|
await app.DATABASE.workflow_params.update_workflow_copilot_chat(
|
||||||
organization_id=chat.organization_id,
|
organization_id=chat.organization_id,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from skyvern.forge.sdk.copilot.exceptions import CopilotClientDisconnectedError
|
|
||||||
from skyvern.forge.sdk.copilot.streaming_adapter import _sanitize_input, stream_to_sse
|
from skyvern.forge.sdk.copilot.streaming_adapter import _sanitize_input, stream_to_sse
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -85,38 +84,50 @@ async def _stream_events_from(*events: Any) -> Any:
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_to_sse_raises_when_send_reports_disconnect() -> None:
|
async def test_stream_to_sse_keeps_running_after_client_disconnect() -> None:
|
||||||
|
"""SKY-8986 regression: a dropped SSE client must NOT cancel the agent run.
|
||||||
|
|
||||||
|
The handler task outlives the SSE response so the agent's reply can be
|
||||||
|
persisted to the chat history. stream_to_sse keeps draining the SDK's
|
||||||
|
event stream; emissions turn into no-ops when is_disconnected() returns
|
||||||
|
True, but result.cancel() is never called and no exception escapes.
|
||||||
|
"""
|
||||||
from agents.items import RunItem
|
from agents.items import RunItem
|
||||||
from agents.stream_events import RunItemStreamEvent
|
from agents.stream_events import RunItemStreamEvent
|
||||||
|
|
||||||
raw = {"call_id": "c1", "name": "click", "arguments": "{}"}
|
raw_call = {"call_id": "c1", "name": "click", "arguments": "{}"}
|
||||||
item = MagicMock(spec=RunItem)
|
call_item = MagicMock(spec=RunItem)
|
||||||
item.raw_item = raw
|
call_item.raw_item = raw_call
|
||||||
event = RunItemStreamEvent(name="tool_called", item=item)
|
tool_call = RunItemStreamEvent(name="tool_called", item=call_item)
|
||||||
|
|
||||||
|
raw_output = {"call_id": "c1"}
|
||||||
|
output_item = MagicMock(spec=RunItem)
|
||||||
|
output_item.raw_item = raw_output
|
||||||
|
output_item.output = None
|
||||||
|
tool_output = RunItemStreamEvent(name="tool_output", item=output_item)
|
||||||
|
|
||||||
result = MagicMock()
|
result = MagicMock()
|
||||||
result.stream_events = lambda: _stream_events_from(event)
|
result.stream_events = lambda: _stream_events_from(tool_call, tool_output)
|
||||||
result.cancel = MagicMock()
|
result.cancel = MagicMock()
|
||||||
|
|
||||||
stream = MagicMock()
|
stream = MagicMock()
|
||||||
stream.is_disconnected = AsyncMock(return_value=False)
|
stream.is_disconnected = AsyncMock(return_value=True)
|
||||||
stream.send = AsyncMock(return_value=False)
|
stream.send = AsyncMock(return_value=True)
|
||||||
|
|
||||||
ctx = SimpleNamespace()
|
ctx = SimpleNamespace()
|
||||||
|
|
||||||
with pytest.raises(CopilotClientDisconnectedError):
|
await stream_to_sse(result, stream, ctx)
|
||||||
await stream_to_sse(result, stream, ctx)
|
|
||||||
|
|
||||||
result.cancel.assert_called_once()
|
result.cancel.assert_not_called()
|
||||||
|
stream.send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_to_sse_propagates_cancelled_error() -> None:
|
async def test_stream_to_sse_propagates_cancelled_error() -> None:
|
||||||
"""A generic asyncio.CancelledError must not be relabeled as a client
|
"""A generic asyncio.CancelledError must propagate up from stream_to_sse so
|
||||||
disconnect. Relabeling would silence cancellation (CopilotClientDisconnectedError
|
the event loop's cancellation machinery still works for task-group cancel,
|
||||||
is a plain Exception, CancelledError is BaseException) and break the event
|
upstream timeout, or parent abort. The adapter must not catch it and turn
|
||||||
loop's cancellation machinery for non-disconnect cancels such as task-group
|
it into a normal return.
|
||||||
cancel, upstream timeout, or parent abort.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _raises_cancelled() -> Any:
|
async def _raises_cancelled() -> Any:
|
||||||
|
|
|
||||||
151
tests/unit/test_event_source_stream_disconnect.py
Normal file
151
tests/unit/test_event_source_stream_disconnect.py
Normal file
|
|
@ -0,0 +1,151 @@
|
||||||
|
"""Regression tests for SKY-8986: SSE disconnect must not kill the handler.
|
||||||
|
|
||||||
|
The SSE stream is a view of work the backend is doing on the client's behalf
|
||||||
|
(e.g., the workflow copilot agent). Closing the browser tab or losing the TCP
|
||||||
|
connection mid-stream used to cancel the handler task, which in turn
|
||||||
|
cancelled the agent run and lost the unpersisted chat reply. The fix in
|
||||||
|
SKY-8986 decouples the handler from the SSE response lifecycle: the handler
|
||||||
|
runs to completion even after the client goes away, and subsequent send()
|
||||||
|
calls drop their payload silently instead of backing up the in-memory queue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.routes.event_source_stream import (
|
||||||
|
_BACKGROUND_HANDLER_TASKS,
|
||||||
|
FastAPIEventSourceStream,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request(is_disconnected_values: list[bool]) -> MagicMock:
|
||||||
|
"""Build a fake Starlette Request whose is_disconnected() replays a script."""
|
||||||
|
request = MagicMock()
|
||||||
|
request.is_disconnected = AsyncMock(side_effect=is_disconnected_values)
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_drops_events_silently_after_disconnect() -> None:
|
||||||
|
"""Once the client is gone, send() returns True but does not queue events.
|
||||||
|
|
||||||
|
If send() queued instead, a long-running agent would grow the queue
|
||||||
|
unbounded (no one is reading it) and leak memory until the process
|
||||||
|
was restarted.
|
||||||
|
"""
|
||||||
|
request = _make_request([False, True, True, True])
|
||||||
|
stream = FastAPIEventSourceStream(request)
|
||||||
|
|
||||||
|
first = await stream.send({"n": 1})
|
||||||
|
second = await stream.send({"n": 2})
|
||||||
|
third = await stream.send({"n": 3})
|
||||||
|
|
||||||
|
assert first is True
|
||||||
|
assert second is True
|
||||||
|
assert third is True
|
||||||
|
# Only the first send (when connected) queued anything.
|
||||||
|
assert stream._queue.qsize() == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_disconnected_latches_after_first_positive() -> None:
|
||||||
|
"""Avoid hammering the ASGI receive channel after the response is torn down.
|
||||||
|
|
||||||
|
The underlying Request.is_disconnected reads from the receive channel,
|
||||||
|
which may not be live after the ASGI task group has exited. Once we
|
||||||
|
observe disconnect once, cache it so later calls don't hit the channel.
|
||||||
|
"""
|
||||||
|
request = MagicMock()
|
||||||
|
# If we didn't cache, the second call would raise.
|
||||||
|
request.is_disconnected = AsyncMock(side_effect=[True, RuntimeError("channel closed")])
|
||||||
|
stream = FastAPIEventSourceStream(request)
|
||||||
|
|
||||||
|
assert await stream.is_disconnected() is True
|
||||||
|
assert await stream.is_disconnected() is True
|
||||||
|
# Only one underlying call thanks to caching.
|
||||||
|
assert request.is_disconnected.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_is_disconnected_treats_exception_as_disconnect() -> None:
    """A failing disconnect probe is interpreted as "client gone".

    Handlers may keep polling is_disconnected after the response has been
    torn down; a stale ASGI receive channel must degrade to
    "disconnected" rather than raise into the handler.
    """
    request = MagicMock()
    request.is_disconnected = AsyncMock(side_effect=RuntimeError("closed"))
    stream = FastAPIEventSourceStream(request)

    assert await stream.is_disconnected() is True
@pytest.mark.asyncio
async def test_handler_runs_to_completion_after_sse_generator_exits() -> None:
    """SKY-8986 regression: client disconnect must not cancel the handler.

    A client that is gone from the start makes the SSE generator exit
    without yielding any events; the handler must nevertheless run to
    completion in the background. The previous implementation cancelled
    the handler task in the generator's finally block, which killed an
    in-flight copilot agent.
    """
    request = MagicMock()
    request.is_disconnected = AsyncMock(return_value=True)

    handler_finished = asyncio.Event()

    async def handler(stream: Any) -> None:
        # Pretend to be agent work that outlives the moment the SSE
        # generator decides the client is gone.
        await asyncio.sleep(0.01)
        await stream.send({"progress": "halfway"})
        await asyncio.sleep(0.01)
        await stream.send({"progress": "done"})
        handler_finished.set()

    response = FastAPIEventSourceStream.create(request, handler)

    # With the client already disconnected, the EventSourceResponse body
    # iterator should close right away while the handler keeps running in
    # the background. All we care about is that the handler finishes.
    async for _ in response.body_iterator:  # drain (should be empty)
        pass

    await asyncio.wait_for(handler_finished.wait(), timeout=2.0)
    # A task can linger in the registry briefly until its done callback
    # fires; yield once so that callback gets to run.
    await asyncio.sleep(0)
    # Once the task is fully done, the registry holds no live tasks.
    assert all(t.done() for t in _BACKGROUND_HANDLER_TASKS)
@pytest.mark.asyncio
async def test_handler_exception_does_not_break_other_streams() -> None:
    """A handler that raises after disconnect must not crash the process.

    After disconnect the handler is a detached background task; without
    the catch-and-log inside _run_handler, its exception would surface
    only as an asyncio "never retrieved" warning at GC time.
    """
    request = MagicMock()
    request.is_disconnected = AsyncMock(return_value=True)

    async def handler(stream: Any) -> None:
        raise RuntimeError("boom")

    response = FastAPIEventSourceStream.create(request, handler)
    async for _ in response.body_iterator:
        pass

    # Give the failing task time to complete and run its done-callback.
    await asyncio.sleep(0.05)
    # No exception propagated out of the ASGI response iteration, and
    # completed background tasks were cleaned from the registry.
    assert all(t.done() for t in _BACKGROUND_HANDLER_TASKS)
@ -15,6 +15,7 @@ real database -- all DB / LLM / agent surfaces are patched.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
|
@ -144,7 +145,7 @@ def _setup_new_copilot_mocks(
|
||||||
get_workflow_copilot_chat_messages=AsyncMock(return_value=[]),
|
get_workflow_copilot_chat_messages=AsyncMock(return_value=[]),
|
||||||
update_workflow_copilot_chat=AsyncMock(),
|
update_workflow_copilot_chat=AsyncMock(),
|
||||||
create_workflow_copilot_chat_message=AsyncMock(
|
create_workflow_copilot_chat_message=AsyncMock(
|
||||||
return_value=SimpleNamespace(created_at=SimpleNamespace(isoformat=lambda: "2026-04-14T00:00:00Z"))
|
return_value=SimpleNamespace(created_at=datetime(2026, 4, 14, tzinfo=timezone.utc))
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
app.DATABASE.workflows = SimpleNamespace(
|
app.DATABASE.workflows = SimpleNamespace(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue