feat(SKY-8879) copilot-stack/06: MCP tools surface + orphan-task cancellation (#5517)

This commit is contained in:
Andrew Neilson 2026-04-15 21:20:16 -07:00 committed by GitHub
parent d58ea46163
commit faa2b233cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 2588 additions and 23 deletions

View file

@ -111,6 +111,8 @@ export const helpTooltips = {
...baseHelpTooltipContent,
fileUrl:
"Since we're in beta this section isn't fully customizable yet, contact us if you'd like to integrate it into your workflow.",
fileType:
"The format of the file to parse. Auto-detected from the URL extension when possible.",
},
wait: {
...baseHelpTooltipContent,

View file

@ -1,8 +1,15 @@
import { HelpTooltip } from "@/components/HelpTooltip";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Handle, NodeProps, Position } from "@xyflow/react";
import { helpTooltips } from "../../helpContent";
import { type FileParserNode } from "./types";
import { type FileParserNode, type FileParserFileType } from "./types";
import { WorkflowBlockInput } from "@/components/WorkflowBlockInput";
import { useIsFirstBlockInWorkflow } from "../../hooks/useIsFirstNodeInWorkflow";
import { cn } from "@/util/utils";
@ -16,6 +23,45 @@ import { useUpdate } from "@/routes/workflows/editor/useUpdate";
import { ModelSelector } from "@/components/ModelSelector";
import { useRecordingStore } from "@/store/useRecordingStore";
// Options rendered in the File Type <Select>; array order is the dropdown
// order, and each `value` is written straight into node data on change.
const FILE_TYPE_OPTIONS: Array<{ value: FileParserFileType; label: string }> = [
  { value: "csv", label: "CSV" },
  { value: "excel", label: "Excel" },
  { value: "pdf", label: "PDF" },
  { value: "image", label: "Image" },
  { value: "docx", label: "DOCX" },
];
// Lowercased URL extension -> parser file type. Used by detectFileTypeFromUrl
// to pre-select the File Type dropdown when the user enters a file URL.
const FILE_EXTENSION_TO_TYPE: Record<string, FileParserFileType> = {
  csv: "csv",
  xlsx: "excel",
  xls: "excel",
  pdf: "pdf",
  png: "image",
  jpg: "image",
  jpeg: "image",
  gif: "image",
  webp: "image",
  docx: "docx",
};
/**
 * Best-effort detection of the parser file type from a file URL.
 *
 * Absolute URLs are parsed with `URL` so query strings and fragments never
 * leak into the extension. Non-URL strings (relative paths, bare filenames)
 * fall back to a plain extension match after stripping any query/fragment.
 *
 * @param url - user-entered file URL or path (may be empty or invalid)
 * @returns the matched FileParserFileType, or null when no known extension
 */
function detectFileTypeFromUrl(url: string): FileParserFileType | null {
  try {
    const pathname = new URL(url).pathname;
    const ext = pathname.split(".").pop()?.toLowerCase();
    if (ext && ext in FILE_EXTENSION_TO_TYPE) {
      return FILE_EXTENSION_TO_TYPE[ext] ?? null;
    }
  } catch {
    // Not an absolute URL. Strip query AND fragment first so an extension
    // inside a query value (e.g. "file?name=report.pdf") or before a hash
    // (e.g. "file.pdf#page=2") is handled correctly, then match the final
    // dot segment.
    const path = url.split(/[?#]/)[0] ?? url;
    const ext = path.split(".").pop()?.toLowerCase();
    if (ext && ext in FILE_EXTENSION_TO_TYPE) {
      return FILE_EXTENSION_TO_TYPE[ext] ?? null;
    }
  }
  return null;
}
function FileParserNode({ id, data }: NodeProps<FileParserNode>) {
const { editable, label } = data;
const { blockLabel: urlBlockLabel } = useParams();
@ -30,6 +76,15 @@ function FileParserNode({ id, data }: NodeProps<FileParserNode>) {
const update = useUpdate<FileParserNode["data"]>({ id, editable });
const recordingStore = useRecordingStore();
// Persist the URL edit; when the extension maps to a known parser type,
// update the file type in the same write so the dropdown tracks the URL.
function handleFileUrlChange(value: string) {
  const detectedType = detectFileTypeFromUrl(value);
  update(
    detectedType === null
      ? { fileUrl: value }
      : { fileUrl: value, fileType: detectedType },
  );
}
return (
<div
className={cn({
@ -83,12 +138,38 @@ function FileParserNode({ id, data }: NodeProps<FileParserNode>) {
<WorkflowBlockInput
nodeId={id}
value={data.fileUrl}
onChange={(value) => {
update({ fileUrl: value });
}}
onChange={handleFileUrlChange}
className="nopan text-xs"
/>
</div>
<div className="space-y-2">
<div className="flex gap-2">
<Label className="text-xs text-slate-300">File Type</Label>
<HelpTooltip content={helpTooltips["fileParser"]["fileType"]} />
</div>
<Select
value={data.fileType}
onValueChange={(value) => {
update({ fileType: value as FileParserFileType });
}}
disabled={!editable}
>
<SelectTrigger className="nopan w-36 text-xs">
<SelectValue placeholder="Select type" />
</SelectTrigger>
<SelectContent>
{FILE_TYPE_OPTIONS.map((option) => (
<SelectItem
key={option.value}
value={option.value}
className="text-xs"
>
{option.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<WorkflowDataSchemaInputGroup
exampleValue={dataSchemaExampleForFileExtraction}
value={data.jsonSchema}

View file

@ -6,8 +6,11 @@ import {
WorkflowModel,
} from "@/routes/workflows/types/workflowTypes";
// File formats the file_url_parser block can ingest; mirrors the backend
// block's `file_type` union.
export type FileParserFileType = "csv" | "excel" | "pdf" | "image" | "docx";

export type FileParserNodeData = NodeBaseData & {
  fileUrl: string;
  // Parse format; the editor auto-detects this from the URL extension when possible.
  fileType: FileParserFileType;
  // JSON-schema string for the extraction output ("null" when unset).
  jsonSchema: string;
  model: WorkflowModel | null;
};
@ -19,6 +22,7 @@ export const fileParserNodeDefaultData: FileParserNodeData = {
editable: true,
label: "",
fileUrl: "",
fileType: "csv",
continueOnFailure: false,
jsonSchema: "null",
model: null,

View file

@ -864,6 +864,7 @@ function convertToNode(
data: {
...commonData,
fileUrl: block.file_url,
fileType: block.file_type ?? "csv",
jsonSchema: JSON.stringify(block.json_schema, null, 2),
model: block.model,
},
@ -2511,7 +2512,7 @@ function getWorkflowBlock(
...base,
block_type: "file_url_parser",
file_url: node.data.fileUrl,
file_type: "csv", // Backend will auto-detect based on file extension
file_type: node.data.fileType,
json_schema: JSONParseSafe(node.data.jsonSchema),
};
}

View file

@ -410,7 +410,7 @@ export type SendEmailBlock = WorkflowBlockBase & {
// Backend block shape for file_url_parser. The stale pre-docx `file_type`
// member (a leftover diff line) is removed — duplicate members with differing
// types are a TypeScript error; the docx-inclusive union is the current one.
export type FileURLParserBlock = WorkflowBlockBase & {
  block_type: "file_url_parser";
  file_url: string;
  // Mirrors FileParserFileType on the editor side.
  file_type: "csv" | "excel" | "pdf" | "image" | "docx";
  json_schema: Record<string, unknown> | null;
};

View file

@ -352,7 +352,7 @@ export type SendEmailBlockYAML = BlockYAMLBase & {
// YAML serialization shape for file_url_parser. The stale pre-docx
// `file_type` member (a leftover diff line) is removed — duplicate members
// with differing types are a TypeScript error; the docx-inclusive union stays.
export type FileUrlParserBlockYAML = BlockYAMLBase & {
  block_type: "file_url_parser";
  file_url: string;
  file_type: "csv" | "excel" | "pdf" | "image" | "docx";
  json_schema?: Record<string, unknown> | null;
};

View file

@ -154,7 +154,14 @@ async def skyvern_click(
"Include visual cues, position, or surrounding text when the page has similar elements."
),
] = None,
selector: Annotated[str | None, Field(description="CSS selector or XPath for the element to click")] = None,
selector: Annotated[
str | None,
Field(
description="Standard CSS selector or XPath for the element to click. "
"jQuery pseudo-selectors like :contains(), :eq(), :first are NOT valid. "
"Use standard CSS: 'button.class', 'a[href*=\"pdf\"]', '#id', ':nth-of-type()'."
),
] = None,
timeout: Annotated[
int,
Field(

View file

@ -82,15 +82,45 @@ def compute_failure_signature(
def _canonical_block_config(block: Any) -> dict[str, Any]:
"""Stable dict view of a block's material config, with fields that don't
affect downstream behavior (``output_parameter``) dropped.
"""
dump = getattr(block, "model_dump", None)
if callable(dump):
try:
return dump(mode="json", exclude_none=True)
cfg = dump(mode="json", exclude_none=True)
except TypeError:
return dump()
if isinstance(block, dict):
return block
return {"repr": repr(block)}
cfg = dump()
elif isinstance(block, dict):
cfg = dict(block)
else:
return {"repr": repr(block)}
cfg.pop("output_parameter", None)
return cfg
def compute_action_sequence_fingerprint(results: list[dict[str, Any]]) -> str | None:
    """Hash the ordered ``(action_type, element_id)`` pairs across every
    block's ``action_trace`` in ``results``. Returns ``None`` when the trace is
    empty (e.g. a fully-successful run where ``_attach_action_traces`` did not
    attach anything). Stable across runs: a form-fill → click → re-fill loop
    that retargets the same elements will produce the same fingerprint.
    """
    tokens: list[str] = []
    for entry in results:
        trace = entry.get("action_trace")
        if not isinstance(trace, list):
            continue
        for action in trace:
            if not isinstance(action, dict):
                continue
            # \x1f (unit separator) joins the fields of one pair; falsy
            # action/element values collapse to the empty string.
            action_type = action.get("action") or ""
            element = action.get("element") or ""
            tokens.append(f"{action_type}\x1f{element}")
    if not tokens:
        return None
    # \x1e (record separator) joins the ordered pairs before hashing.
    digest_input = "\x1e".join(tokens).encode("utf-8")
    return hashlib.sha256(digest_input).hexdigest()
def compute_frontier_fingerprint(
@ -118,9 +148,7 @@ def compute_frontier_fingerprint(
if block is None:
payload.append({"label": label, "missing": True})
continue
config = _canonical_block_config(block)
config.pop("output_parameter", None)
payload.append({"label": label, "config": config})
payload.append({"label": label, "config": _canonical_block_config(block)})
try:
serialized = json.dumps(payload, sort_keys=True, default=str, separators=(",", ":"))
except (TypeError, ValueError):

View file

@ -0,0 +1,275 @@
"""SDK-native MCP server with schema overlays for the Skyvern copilot."""
from __future__ import annotations
import json
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
import structlog
from agents.agent import AgentBase
from agents.mcp.server import MCPServer
from agents.run_context import RunContextWrapper
from fastmcp import Client
from mcp import Tool as MCPTool
from mcp.types import (
CallToolResult,
GetPromptResult,
ListPromptsResult,
TextContent,
)
from skyvern.forge.sdk.copilot.loop_detection import detect_tool_loop
from skyvern.forge.sdk.copilot.output_utils import sanitize_tool_result_for_llm
from skyvern.forge.sdk.copilot.runtime import (
AgentContext,
ensure_browser_session,
mcp_browser_context,
mcp_to_copilot,
)
from skyvern.forge.sdk.copilot.screenshot_utils import enqueue_screenshot_from_result
# Hook signatures: a pre-hook receives (arguments, ctx) and may short-circuit
# the tool call by returning a result dict (None means "proceed"); a post-hook
# receives (copilot_result, raw_mcp_result, ctx) and returns the rewritten
# copilot result.
PreHook = Callable[[dict[str, Any], AgentContext], Awaitable[dict[str, Any] | None]]
PostHook = Callable[[dict[str, Any], dict[str, Any], AgentContext], Awaitable[dict[str, Any]]]
@dataclass
class SchemaOverlay:
    """Schema overlay for MCP tools — hides params, renames args, injects forced values."""

    # Replacement tool description shown to the copilot; falls back to the
    # MCP tool's own description when None.
    description: str | None = None
    # Parameter names removed from the exposed schema and stripped from calls.
    hide_params: frozenset[str] = frozenset()
    # When set, completely replaces the schema's ``required`` list.
    required_overrides: list[str] | None = None
    # copilot_param -> mcp_param renames applied to the schema and call args.
    arg_transforms: dict[str, str] = field(default_factory=dict)
    # Values always merged into the MCP call; hidden from the copilot schema.
    forced_args: dict[str, Any] = field(default_factory=dict)
    # When True, dispatch ensures a browser session and injects session_id.
    requires_browser: bool = False
    # Per-tool timeout; not consumed in this module — presumably seconds,
    # TODO confirm against the caller that reads it.
    timeout: int | None = None
    # Runs before dispatch; a non-None return short-circuits the tool call.
    pre_hook: PreHook | None = None
    # Post-processes the copilot-shaped result after dispatch.
    post_hook: PostHook | None = None
LOG = structlog.get_logger()
def _apply_schema_overlay(
input_schema: dict[str, Any],
overlay: SchemaOverlay,
) -> dict[str, Any]:
props = dict(input_schema.get("properties", {}))
required = list(input_schema.get("required", []))
for p in overlay.hide_params | frozenset(overlay.forced_args):
props.pop(p, None)
if p in required:
required.remove(p)
for copilot_param, mcp_param in overlay.arg_transforms.items():
if mcp_param in props:
props[copilot_param] = props.pop(mcp_param)
if mcp_param in required:
required.remove(mcp_param)
required.append(copilot_param)
if overlay.required_overrides is not None:
required = overlay.required_overrides
return {
"type": input_schema.get("type", "object"),
"properties": props,
"required": required,
}
def _transform_args(
arguments: dict[str, Any],
overlay: SchemaOverlay,
) -> dict[str, Any]:
mcp_args = {k: v for k, v in arguments.items() if k not in overlay.hide_params}
for copilot_param, mcp_param in overlay.arg_transforms.items():
if copilot_param in mcp_args:
mcp_args[mcp_param] = mcp_args.pop(copilot_param)
mcp_args.update(overlay.forced_args)
return mcp_args
def _copilot_to_call_tool_result(
    copilot_result: dict[str, Any],
) -> CallToolResult:
    """Wrap a copilot ``{ok, data, error}`` dict as an MCP ``CallToolResult``.

    The payload is sanitized for LLM consumption first; ``isError`` mirrors a
    falsy ``ok`` (a missing ``ok`` key counts as success).
    """
    payload = sanitize_tool_result_for_llm("", copilot_result)
    body: list[TextContent] = [TextContent(type="text", text=json.dumps(payload))]
    failed = not copilot_result.get("ok", True)
    return CallToolResult(content=body, isError=failed)
class SkyvernOverlayMCPServer(MCPServer):
    """MCP server that wraps a FastMCP transport with schema overlays and
    copilot-specific dispatch logic (loop detection, browser injection, hooks).
    """

    def __init__(
        self,
        transport: Any,
        overlays: dict[str, SchemaOverlay],
        alias_map: dict[str, str],
        allowlist: frozenset[str],
        context_provider: Callable[[], Any],
    ) -> None:
        # Structured content off: results are serialized to JSON text via
        # _copilot_to_call_tool_result instead.
        super().__init__(use_structured_content=False)
        self._transport = transport
        self._overlays = overlays
        self._alias_map = alias_map  # copilot_name -> mcp_name
        # Inverted view for mapping raw MCP tool names back to copilot aliases.
        self._reverse_alias: dict[str, str] = {v: k for k, v in alias_map.items()}
        self._allowlist = allowlist
        # Called per dispatch to fetch the current copilot AgentContext.
        self._context_provider = context_provider
        self._client: Client | None = None
        self._exit_stack: AsyncExitStack | None = None
        # Tool list is cached after the first list_tools() until cleanup().
        self._cached_tools: list[MCPTool] | None = None

    @property
    def name(self) -> str:
        return "skyvern"

    async def connect(self) -> None:
        # The exit stack owns the client's lifetime so cleanup() can tear the
        # connection down symmetrically.
        stack = AsyncExitStack()
        await stack.__aenter__()
        client = Client(self._transport)
        await stack.enter_async_context(client)
        self._client = client
        self._exit_stack = stack

    async def cleanup(self) -> None:
        # Idempotent teardown: also drops the tool cache so a reconnect
        # re-lists tools.
        if self._exit_stack:
            await self._exit_stack.__aexit__(None, None, None)
        self._client = None
        self._exit_stack = None
        self._cached_tools = None

    async def list_tools(
        self,
        run_context: RunContextWrapper[Any] | None = None,
        agent: AgentBase | None = None,
    ) -> list[MCPTool]:
        # Return allowlisted tools with their overlays applied; cached after
        # the first call.
        if self._cached_tools is not None:
            return self._cached_tools
        if not self._client:
            raise RuntimeError("Not connected — call connect() first")
        raw_tools = await self._client.list_tools()
        result: list[MCPTool] = []
        for tool in raw_tools:
            if tool.name not in self._allowlist:
                continue
            # Expose the tool under its copilot alias when one exists.
            copilot_name = self._reverse_alias.get(tool.name, tool.name)
            overlay = self._overlays.get(copilot_name, SchemaOverlay())
            schema = _apply_schema_overlay(tool.inputSchema, overlay)
            description = overlay.description or tool.description or ""
            result.append(
                MCPTool(
                    name=copilot_name,
                    description=description,
                    inputSchema=schema,
                )
            )
        self._cached_tools = result
        return result

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        meta: dict[str, Any] | None = None,
    ) -> CallToolResult:
        # Dispatch order: loop detection -> pre-hook -> arg translation ->
        # optional browser injection -> MCP call -> post-hook -> screenshot.
        if not self._client:
            raise RuntimeError("Not connected — call connect() first")
        arguments = arguments or {}
        copilot_ctx = self._context_provider()
        overlay = self._overlays.get(tool_name, SchemaOverlay())
        # Tolerate contexts without a tracker attribute (getattr default).
        tracker = getattr(copilot_ctx, "consecutive_tool_tracker", None)
        loop_error = detect_tool_loop(tracker, tool_name) if isinstance(tracker, list) else None
        if loop_error:
            LOG.warning(
                "Tool loop detected, skipping execution",
                tool_name=tool_name,
            )
            return CallToolResult(
                content=[
                    TextContent(
                        type="text",
                        text=json.dumps(
                            {
                                "ok": False,
                                "error": loop_error,
                            }
                        ),
                    )
                ],
                isError=True,
            )
        if overlay.pre_hook:
            # A non-None pre-hook result short-circuits the MCP call entirely.
            hook_result = await overlay.pre_hook(arguments, copilot_ctx)
            if hook_result is not None:
                return _copilot_to_call_tool_result(hook_result)
        mcp_name = self._alias_map.get(tool_name, tool_name)
        mcp_args = _transform_args(arguments, overlay)
        if overlay.requires_browser:
            err = await ensure_browser_session(copilot_ctx)
            if err:
                return _copilot_to_call_tool_result(err)
            mcp_args["session_id"] = copilot_ctx.browser_session_id
        try:
            call = self._client.call_tool(mcp_name, mcp_args, raise_on_error=False)
            if overlay.requires_browser:
                # Browser-backed calls are awaited inside the browser context
                # manager.
                async with mcp_browser_context(copilot_ctx):
                    raw_result = await call
            else:
                raw_result = await call
        except Exception as e:
            # Transport/client failures become an error result, not a raise.
            LOG.warning(
                "MCP tool call failed",
                tool=tool_name,
                error=str(e),
                exc_info=True,
            )
            return _copilot_to_call_tool_result({"ok": False, "error": f"{tool_name} failed: {e}"})
        # Copy fastmcp's structured_content so mutations below stay local to
        # this call — the client may reuse or cache the response object.
        raw_mcp = dict(raw_result.structured_content or {})
        if raw_result.is_error:
            raw_mcp["ok"] = False
            if not raw_result.structured_content and raw_result.content:
                # No structured payload: salvage an error string from any
                # text content parts.
                text_parts = [c.text for c in raw_result.content if hasattr(c, "text")]
                raw_mcp["error"] = " ".join(text_parts) if text_parts else "Unknown MCP error"
            else:
                raw_mcp["error"] = raw_mcp.get("error") or "Unknown MCP error"
        copilot_result = mcp_to_copilot(raw_mcp)
        if overlay.post_hook:
            copilot_result = await overlay.post_hook(copilot_result, raw_mcp, copilot_ctx)
        enqueue_screenshot_from_result(copilot_ctx, copilot_result)
        return _copilot_to_call_tool_result(copilot_result)

    async def list_prompts(self) -> ListPromptsResult:
        # Prompts are not part of the copilot surface.
        return ListPromptsResult(prompts=[])

    async def get_prompt(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
    ) -> GetPromptResult:
        raise ValueError(f"Prompts not supported: {name}")

View file

@ -50,6 +50,18 @@ class AgentContext:
pending_screenshots: list[ScreenshotEntry] = field(default_factory=list)
tool_activity: list[dict[str, Any]] = field(default_factory=list)
# Cross-turn agent state accumulated by tools.py as the agent runs.
# Read back by failure_tracking / loop_detection to detect stuck loops,
# preserve verified prefixes across partial runs, etc. All optional —
# downstream accessors use ``getattr(ctx, name, default)`` where
# tolerant-to-unset is the right default.
last_requested_block_labels: list[str] = field(default_factory=list)
last_executed_block_labels: list[str] = field(default_factory=list)
last_frontier_start_label: str | None = None
pending_action_sequence_fingerprint: str | None = None
verified_block_outputs: dict[str, Any] = field(default_factory=dict)
verified_prefix_labels: list[str] = field(default_factory=list)
def mcp_to_copilot(mcp_result: dict[str, Any]) -> dict[str, Any]:
"""Convert an MCP result dict to the copilot {ok, data, error} format."""

File diff suppressed because it is too large Load diff

View file

@ -1077,6 +1077,13 @@ class WorkflowService:
name=f"browser_session_renewal_{workflow_run_id}",
)
# Captured inside the try and consumed in the outer finally so status
# finalization runs even when the body is cancelled or raises. Stays
# None if we were cancelled before the block-execution step completed;
# in that case there's no terminal-state intent to restore.
pre_finally_status: WorkflowRunStatus | None = None
pre_finally_failure_reason: str | None = None
try:
# Check if there's a related workflow script that should be used instead
workflow_script, _, script_is_pinned = await workflow_script_service.get_workflow_script(
@ -1205,14 +1212,37 @@ class WorkflowService:
organization=organization,
browser_session_id=browser_session_id,
)
workflow_run = await self._finalize_workflow_run_status(
workflow_run_id=workflow_run_id,
workflow_run=workflow_run,
pre_finally_status=pre_finally_status,
pre_finally_failure_reason=pre_finally_failure_reason,
)
finally:
# Shielded finalize runs even when the try body was cancelled
# mid-flight (e.g. the copilot tool's orphan-task cancel path, or
# any outer caller that cancels execute_workflow). Without this,
# cancellation between the temporary ``running`` write above and
# the original finalize call leaked ``running``/``canceled`` rows
# in place of the real terminal reason. When pre_finally_status is
# still ``None`` (cancellation landed before block execution
# completed), there's no captured intent to restore and we skip.
if pre_finally_status is not None:
try:
workflow_run = await asyncio.shield(
self._finalize_workflow_run_status(
workflow_run_id=workflow_run_id,
workflow_run=workflow_run,
pre_finally_status=pre_finally_status,
pre_finally_failure_reason=pre_finally_failure_reason,
)
)
except BaseException:
# Catch BaseException (not Exception) so a second
# ``CancelledError`` arriving during the shielded await —
# plausible when the copilot's detached cancellation
# fallback re-cancels ``run_task`` — does not escape this
# block and skip ``clean_up_workflow`` below.
LOG.warning(
"Finalize failed during execute_workflow cleanup",
workflow_run_id=workflow_run_id,
exc_info=True,
)
if renewal_task is not None and not renewal_task.done():
renewal_task.cancel()
try:

View file

@ -0,0 +1,96 @@
"""Tests for the action-sequence fingerprint compute and the hard-abort
short-circuit in ``_tool_loop_error``.
The streak counter that drives the abort is owned by
``failure_tracking.update_repeated_failure_state`` (stack 03). These tests
cover only the tools.py-side behavior:
- ``compute_action_sequence_fingerprint`` is stable across runs that fire
the same action shape (independent of reasoning text / status).
- ``compute_action_sequence_fingerprint`` distinguishes different sequences.
- ``_tool_loop_error`` returns a hard-abort message when the streak crosses
``REPEATED_ACTION_STREAK_ABORT_AT`` for a block-running tool.
- The hard abort does NOT fire for non-block-running tools, regardless of
streak height.
"""
from __future__ import annotations
from types import SimpleNamespace
from skyvern.forge.sdk.copilot.failure_tracking import compute_action_sequence_fingerprint
from skyvern.forge.sdk.copilot.tools import (
REPEATED_ACTION_STREAK_ABORT_AT,
_tool_loop_error,
)
def _ctx_with_streak(streak: int) -> SimpleNamespace:
return SimpleNamespace(
consecutive_tool_tracker=[],
repeated_action_fingerprint_streak_count=streak,
)
def test_compute_action_sequence_fingerprint_stable_for_same_sequence() -> None:
    # Two traces with identical (action, element) shape but different
    # reasoning text must fingerprint identically.
    shape = [("input_text", "elem-name"), ("input_text", "elem-email"), ("click", "elem-submit")]

    def build_trace(reasonings: list[str]) -> list[dict[str, str]]:
        return [
            {"action": action, "element": element, "reasoning": reasoning, "status": "failed"}
            for (action, element), reasoning in zip(shape, reasonings)
        ]

    fp_a = compute_action_sequence_fingerprint([{"action_trace": build_trace(["r1", "r2", "r3"])}])
    fp_b = compute_action_sequence_fingerprint(
        [{"action_trace": build_trace(["different_text", "other", "third"])}]
    )
    assert fp_a is not None
    # Reasoning / status are excluded from the fingerprint on purpose — only
    # the (action, element) shape matters for detecting a retry loop.
    assert fp_a == fp_b
def test_compute_action_sequence_fingerprint_none_when_trace_missing() -> None:
    # Empty results, a result without a trace, and an empty trace all yield None.
    for results in ([], [{"status": "completed"}], [{"action_trace": []}]):
        assert compute_action_sequence_fingerprint(results) is None
def test_compute_action_sequence_fingerprint_distinguishes_different_sequences() -> None:
    def single_pair_fp(action: str, element: str) -> str | None:
        return compute_action_sequence_fingerprint(
            [{"action_trace": [{"action": action, "element": element}]}]
        )

    fingerprints = {
        single_pair_fp("click", "btn-a"),
        single_pair_fp("click", "btn-b"),
        single_pair_fp("input_text", "btn-a"),
    }
    # Varying either the action or the element must change the hash, so all
    # three fingerprints are pairwise distinct.
    assert len(fingerprints) == 3
def test_tool_loop_error_fires_hard_abort_on_block_running_tools_when_streak_high() -> None:
    ctx = _ctx_with_streak(REPEATED_ACTION_STREAK_ABORT_AT)
    # Both block-running tools must hard-abort at the threshold.
    for tool_name in ("run_blocks_and_collect_debug", "update_and_run_blocks"):
        message = _tool_loop_error(ctx, tool_name)
        assert message is not None
        assert "Repeated-action abort" in message
def test_tool_loop_error_does_not_fire_for_non_block_running_tools() -> None:
    ctx = _ctx_with_streak(REPEATED_ACTION_STREAK_ABORT_AT + 5)
    # Planning/metadata tools keep their existing loop-detection behavior
    # and are not affected by the block-run streak.
    for tool_name in ("update_workflow", "list_credentials", "get_run_results"):
        assert _tool_loop_error(ctx, tool_name) is None
def test_tool_loop_error_does_not_fire_below_threshold() -> None:
    # One below the abort threshold: no hard abort even for a block runner.
    just_below = _ctx_with_streak(REPEATED_ACTION_STREAK_ABORT_AT - 1)
    assert _tool_loop_error(just_below, "run_blocks_and_collect_debug") is None

View file

@ -0,0 +1,148 @@
"""Tests for the copilot orphan-workflow cancellation helpers.
Covers:
- ``_cancel_run_task_if_not_final`` cancels ``run_task`` and writes the
conditional cancel exactly once.
- A SUCCESS path (run_task completes on its own) still calls the conditional
cancel, but because the row is terminal the real helper would be a no-op.
- An SDK-realistic ``asyncio.wait_for`` timeout around the tool coroutine does
not leave ``run_task`` running in the background.
"""
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from skyvern.forge.sdk.copilot.tools import _cancel_run_task_if_not_final
class _FakeService:
def __init__(self) -> None:
self.mark_calls: list[str] = []
self.raise_on_mark: Exception | None = None
async def mark_workflow_run_as_canceled_if_not_final(
self,
workflow_run_id: str,
) -> Any:
self.mark_calls.append(workflow_run_id)
if self.raise_on_mark is not None:
raise self.raise_on_mark
return None
@pytest.mark.asyncio
async def test_cancel_helper_cancels_task_and_writes_conditional_cancel(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Happy path: the helper cancels a still-running task and records exactly
    one conditional cancel for the workflow run."""
    from skyvern.forge import app as forge_app

    fake = _FakeService()
    monkeypatch.setattr(forge_app, "WORKFLOW_SERVICE", fake)

    async def never_finishes() -> None:
        await asyncio.sleep(60)

    task = asyncio.create_task(never_finishes())
    await _cancel_run_task_if_not_final(task, workflow_run_id="wr_1")

    assert task.cancelled() or task.done()
    assert fake.mark_calls == ["wr_1"]
@pytest.mark.asyncio
async def test_cancel_helper_does_not_raise_on_mark_failure(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Secondary errors during cleanup are logged, not propagated — otherwise
    they would replace the original timeout/cancellation surface."""
    from skyvern.forge import app as forge_app

    failing = _FakeService()
    failing.raise_on_mark = RuntimeError("DB is down")
    monkeypatch.setattr(forge_app, "WORKFLOW_SERVICE", failing)

    async def never_finishes() -> None:
        await asyncio.sleep(60)

    task = asyncio.create_task(never_finishes())
    # Must not raise despite the mark raising.
    await _cancel_run_task_if_not_final(task, workflow_run_id="wr_2")
@pytest.mark.asyncio
async def test_cancel_helper_handles_already_completed_run_task(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When run_task has already finished (natural completion), the helper
    should still issue the conditional cancel — it is a no-op at the DB layer
    if the row is already terminal, so the result is harmless."""
    from skyvern.forge import app as forge_app

    fake = _FakeService()
    monkeypatch.setattr(forge_app, "WORKFLOW_SERVICE", fake)

    async def finishes_immediately() -> None:
        return

    task = asyncio.create_task(finishes_immediately())
    await task
    await _cancel_run_task_if_not_final(task, workflow_run_id="wr_3")

    assert fake.mark_calls == ["wr_3"]
@pytest.mark.asyncio
async def test_sdk_style_wait_for_timeout_does_not_leak_background_work(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Exercises the production failure mode: the OpenAI Agents SDK wraps the
    tool coroutine in ``asyncio.wait_for(..., timeout=N)`` and cancels it on
    timeout. Our CancelledError branch must cancel ``run_task`` through the
    helper so no orphan work is left behind."""
    from skyvern.forge import app as forge_app

    service = _FakeService()
    monkeypatch.setattr(forge_app, "WORKFLOW_SERVICE", service)

    run_task_ref: dict[str, asyncio.Task] = {}
    workflow_work_completed = asyncio.Event()

    async def tool_body() -> None:
        async def workflow_body() -> None:
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                # Set only when cancellation actually reached the workflow.
                workflow_work_completed.set()
                raise

        run_task = asyncio.create_task(workflow_body())
        run_task_ref["run_task"] = run_task
        try:
            # Simulate the inner poll loop.
            while True:
                await asyncio.sleep(0.05)
        except asyncio.CancelledError:
            try:
                await asyncio.shield(_cancel_run_task_if_not_final(run_task, workflow_run_id="wr_sdk"))
            except asyncio.CancelledError:
                # Detached fallback mirror of the production path.
                fallback = asyncio.ensure_future(_cancel_run_task_if_not_final(run_task, workflow_run_id="wr_sdk"))
                await asyncio.wait_for(asyncio.shield(fallback), timeout=5.0)
            raise

    tool_task = asyncio.ensure_future(tool_body())
    with pytest.raises(asyncio.TimeoutError):
        await asyncio.wait_for(tool_task, timeout=0.2)

    # Workflow's CancelledError handler should have fired via our helper.
    await asyncio.wait_for(workflow_work_completed.wait(), timeout=1.0)
    assert "run_task" in run_task_ref
    assert run_task_ref["run_task"].cancelled() or run_task_ref["run_task"].done()
    assert service.mark_calls == ["wr_sdk"]

View file

@ -6,10 +6,14 @@
- ``execute_workflow_webhook`` returns cleanly when the workflow row has been
soft-deleted mid-run — it must not raise ``WorkflowNotFound`` from the
cleanup path.
- The cancellation-safe finalize pattern used in ``execute_workflow``'s outer
``finally`` runs ``_finalize_workflow_run_status`` via ``asyncio.shield``
so an outer cancel mid-body still restores the real terminal status.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -113,3 +117,89 @@ async def test_build_status_response_uses_filter_deleted_false_when_allowed(
by_wpid.assert_awaited_once()
assert by_wpid.call_args.kwargs["filter_deleted"] is False
@pytest.mark.asyncio
async def test_shielded_finalize_runs_when_outer_cancelled_mid_body() -> None:
    """Contract test for the ``execute_workflow`` cancellation-safe pattern:
    when the try body is cancelled after ``pre_finally_status`` is captured,
    the outer ``finally`` must still run ``_finalize_workflow_run_status``
    via ``asyncio.shield`` so the row ends up terminal rather than stuck as
    transient ``running``. Mirrors the structure of
    ``WorkflowService.execute_workflow``; if anyone removes the ``shield`` or
    moves finalize back into the try, this test breaks.
    """
    finalize_calls: list[WorkflowRunStatus] = []
    clean_up_called = False
    body_entered = asyncio.Event()

    async def finalize(status: WorkflowRunStatus) -> None:
        # Simulate a non-trivial DB write so shield cancellation-protection
        # matters rather than being invisible.
        await asyncio.sleep(0.05)
        finalize_calls.append(status)

    async def clean_up() -> None:
        nonlocal clean_up_called
        clean_up_called = True

    async def simulated_execute_workflow() -> None:
        pre_finally_status: WorkflowRunStatus | None = None
        try:
            pre_finally_status = WorkflowRunStatus.failed
            body_entered.set()
            # Simulate the finally-block execution phase that our copilot
            # cancel lands inside of.
            await asyncio.sleep(10)
        finally:
            if pre_finally_status is not None:
                try:
                    await asyncio.shield(finalize(pre_finally_status))
                except Exception:
                    # Mirrors production: a finalize failure must not stop
                    # clean_up below from running.
                    pass
            await clean_up()

    task = asyncio.create_task(simulated_execute_workflow())
    await body_entered.wait()
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    assert finalize_calls == [WorkflowRunStatus.failed], (
        "shielded finalize must run with the captured pre_finally_status"
    )
    assert clean_up_called, "clean_up_workflow must still run in the outer finally"
@pytest.mark.asyncio
async def test_shielded_finalize_skipped_when_pre_finally_status_unset() -> None:
    """If cancellation lands before block execution captures
    ``pre_finally_status``, there's no intended terminal state to restore —
    the outer ``finally`` must skip finalize, not call it with ``None``.
    """
    did_finalize = False

    async def finalize(status: WorkflowRunStatus) -> None:
        nonlocal did_finalize
        did_finalize = True

    async def simulated_execute_workflow() -> None:
        pre_finally_status: WorkflowRunStatus | None = None
        try:
            await asyncio.sleep(10)
            pre_finally_status = WorkflowRunStatus.failed  # pragma: no cover
        finally:
            if pre_finally_status is not None:
                await asyncio.shield(finalize(pre_finally_status))

    runner = asyncio.create_task(simulated_execute_workflow())
    await asyncio.sleep(0)  # let the task enter its body
    runner.cancel()
    with pytest.raises(asyncio.CancelledError):
        await runner

    assert not did_finalize, "finalize must not run when pre_finally_status is unset"

View file

@ -16,6 +16,7 @@ import json
import yaml
from skyvern.forge.sdk.routes.workflow_copilot import _process_workflow_yaml
from skyvern.utils.yaml_loader import safe_load_no_dates
ISO_BLOB = """
@ -75,3 +76,40 @@ def test_safe_load_no_dates_preserves_other_implicit_types() -> None:
assert parsed["a_bool"] is True
assert parsed["a_null"] is None
assert parsed["a_list"] == [1, 2, 3]
def test_process_workflow_yaml_keeps_json_parameter_iso_strings() -> None:
    """ISO-8601 timestamps inside a JSON workflow parameter's default_value
    must survive YAML processing as plain strings, not be coerced into
    datetime objects."""
    # NOTE(review): the YAML literal's indentation was reconstructed from the
    # diff rendering — verify the nesting matches what _process_workflow_yaml
    # expects.
    workflow = _process_workflow_yaml(
        workflow_id="wf-123",
        workflow_permanent_id="wfp-123",
        organization_id="org-123",
        workflow_yaml="""
title: Test
workflow_definition:
  parameters:
    - parameter_type: workflow
      key: payload
      workflow_parameter_type: json
      default_value:
        id: "12345"
        metadata:
          created_at: 2023-10-27T10:00:00Z
          updated_at: 2023-10-28T14:30:00Z
  blocks:
    - block_type: navigation
      label: step1
      url: https://example.com
      title: Step 1
      navigation_goal: Open the page
""",
    )
    parameter = workflow.get_parameter("payload")
    assert parameter is not None
    assert parameter.default_value is not None
    metadata = parameter.default_value["metadata"]
    # Unquoted ISO timestamps would normally load as datetime; the no-dates
    # loader must keep them as the original strings.
    assert metadata["created_at"] == "2023-10-27T10:00:00Z"
    assert metadata["updated_at"] == "2023-10-28T14:30:00Z"
    assert isinstance(metadata["created_at"], str)
    assert isinstance(metadata["updated_at"], str)