# Skyvern/tests/unit/test_copilot_streaming_adapter.py
#
# 316 lines
# 10 KiB
# Python

from __future__ import annotations
import asyncio
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.copilot.streaming_adapter import _sanitize_input, stream_to_sse
def test_strips_workflow_yaml() -> None:
    """The raw workflow YAML is dropped while other keys pass through untouched."""
    sanitized = _sanitize_input({"workflow_yaml": "title: x", "block_labels": ["a"]})

    assert "workflow_yaml" not in sanitized
    assert sanitized["block_labels"] == ["a"]
def test_redacts_password_in_parameters() -> None:
    """A password under ``parameters`` is masked while the username is kept as-is."""
    payload = {
        "workflow_yaml": "...",
        "parameters": {"username": "u", "password": "p"},
    }
    sanitized_params = _sanitize_input(payload)["parameters"]

    assert sanitized_params["username"] == "u"
    assert sanitized_params["password"] == "****"
def test_redacts_totp_and_api_key() -> None:
    """TOTP codes, API keys and MFA codes under ``parameters`` are all masked."""
    secret_params = {
        "totp": "123456",
        "api_key": "sk-abc",
        "mfa_code": "999",
    }
    masked = _sanitize_input({"parameters": secret_params})["parameters"]

    # Checked in the same order as the keys were declared above.
    for key in ("totp", "api_key", "mfa_code"):
        assert masked[key] == "****", key
def test_does_not_redact_benign_identifiers() -> None:
    """Identifiers such as credential IDs, page tokens and search terms stay readable."""
    benign = {
        "credential_id": "cred_abc",
        "page_token": "pt_xyz",
        "username": "user1",
        "search_term": "apple",
    }
    # Pass a copy so the expected values cannot be altered by an in-place redaction.
    sanitized_params = _sanitize_input({"parameters": dict(benign)})["parameters"]

    for key, expected in benign.items():
        assert sanitized_params[key] == expected, key
def test_redacts_nested_dict() -> None:
    """Redaction also applies to dictionaries nested inside ``parameters``."""
    sanitized = _sanitize_input({"parameters": {"outer": {"password": "p", "label": "ok"}}})
    outer = sanitized["parameters"]["outer"]

    assert outer["password"] == "****"
    assert outer["label"] == "ok"
def test_empty_input() -> None:
    """An empty payload sanitizes to an empty dict."""
    sanitized = _sanitize_input({})
    assert sanitized == {}
async def _stream_events_from(*events: Any) -> Any:
    """Async generator used as a fake ``result.stream_events`` source.

    Yields each positional argument exactly once, in the order given.
    """
    for item in events:
        yield item
@pytest.mark.asyncio
async def test_stream_to_sse_keeps_running_after_client_disconnect() -> None:
    """SKY-8986 regression: losing the SSE client must not abort the agent run.

    The handler task outlives the SSE response so the agent's reply can still
    be persisted to the chat history. With ``is_disconnected()`` reporting
    True for the whole run, ``stream_to_sse`` should keep draining the SDK
    event stream and send nothing. It must also never call
    ``result.cancel()`` and return without raising.
    """
    from agents.items import RunItem
    from agents.stream_events import RunItemStreamEvent

    def _run_item(raw: dict[str, Any], **extra: Any) -> Any:
        item = MagicMock(spec=RunItem)
        item.raw_item = raw
        for attr, value in extra.items():
            setattr(item, attr, value)
        return item

    events = (
        RunItemStreamEvent(
            name="tool_called",
            item=_run_item({"call_id": "c1", "name": "click", "arguments": "{}"}),
        ),
        RunItemStreamEvent(
            name="tool_output",
            item=_run_item({"call_id": "c1"}, output=None),
        ),
    )

    result = MagicMock()
    result.stream_events = lambda: _stream_events_from(*events)
    result.cancel = MagicMock()

    stream = MagicMock()
    stream.is_disconnected = AsyncMock(return_value=True)
    stream.send = AsyncMock(return_value=True)

    await stream_to_sse(result, stream, SimpleNamespace())

    result.cancel.assert_not_called()
    stream.send.assert_not_called()
@pytest.mark.asyncio
async def test_stream_to_sse_propagates_cancelled_error() -> None:
    """An ``asyncio.CancelledError`` from the SDK stream must escape ``stream_to_sse``.

    Task-group cancellation, upstream timeouts and parent aborts depend on the
    event loop's cancellation machinery seeing this error. The adapter
    therefore must not catch it and turn it into a normal return. The test
    additionally expects ``result.cancel()`` to be called exactly once along
    the way.
    """

    async def _cancelled_stream() -> Any:
        raise asyncio.CancelledError()
        yield  # unreachable; its presence makes this an async generator

    result = MagicMock()
    result.stream_events = _cancelled_stream
    result.cancel = MagicMock()

    stream = MagicMock()
    stream.is_disconnected = AsyncMock(return_value=False)
    stream.send = AsyncMock(return_value=True)

    with pytest.raises(asyncio.CancelledError):
        await stream_to_sse(result, stream, SimpleNamespace())

    result.cancel.assert_called_once()
class TestParseToolOutput:
    """Normalization performed by ``parse_tool_output`` on the various output shapes."""

    @staticmethod
    def _parse(output: Any) -> dict[str, Any]:
        from skyvern.forge.sdk.copilot.streaming_adapter import parse_tool_output

        return parse_tool_output(output)

    def test_parse_none(self) -> None:
        """``None`` is reported as a bare success."""
        parsed = self._parse(None)
        assert parsed == {"ok": True}

    def test_parse_plain_json_string(self) -> None:
        parsed = self._parse('{"ok": true}')
        assert parsed == {"ok": True}

    def test_parse_json_string_with_data(self) -> None:
        parsed = self._parse('{"ok": true, "data": {"count": 5}}')
        assert parsed["ok"] is True
        assert parsed["data"]["count"] == 5

    def test_parse_error_json_string(self) -> None:
        parsed = self._parse('{"ok": false, "error": "something broke"}')
        assert parsed["ok"] is False
        assert parsed["error"] == "something broke"

    def test_parse_non_json_string(self) -> None:
        """Text that is not JSON is wrapped as successful ``data``."""
        text = "just plain text"
        parsed = self._parse(text)
        assert parsed["ok"] is True
        assert parsed["data"] == text

    def test_parse_list_with_text_dict(self) -> None:
        text_part = {"type": "text", "text": '{"ok": false, "error": "fail"}'}
        assert self._parse([text_part]) == {"ok": False, "error": "fail"}

    def test_parse_list_with_text_object(self) -> None:
        """The SDK may return ToolOutputText objects rather than dicts."""

        class FakeTextOutput:
            type = "text"
            text = '{"ok": true, "data": "hello"}'

        assert self._parse([FakeTextOutput()]) == {"ok": True, "data": "hello"}

    def test_parse_list_skips_image_items(self) -> None:
        text_part = {"type": "text", "text": '{"ok": true}'}
        image_part = {"type": "image", "image_url": "data:image/png;base64,abc"}
        assert self._parse([text_part, image_part]) == {"ok": True}

    def test_parse_wrapped_text_dict(self) -> None:
        wrapped = {"type": "text", "text": '{"ok": true}'}
        assert self._parse(wrapped) == {"ok": True}

    def test_parse_direct_copilot_dict(self) -> None:
        """A dict already in copilot ``{"ok": ...}`` form comes back unchanged."""
        payload = {"ok": True, "data": {"workflow_id": "wf_1"}}
        assert self._parse(payload) == payload

    def test_parse_dict_without_ok_or_type(self) -> None:
        """Any other dict is wrapped as successful ``data``."""
        payload = {"some_key": "some_value"}
        parsed = self._parse(payload)
        assert parsed["ok"] is True
        assert parsed["data"] == payload

    def test_parse_object_with_text_attr(self) -> None:
        class FakeOutput:
            type = "text"
            text = '{"ok": true, "data": 42}'

        assert self._parse(FakeOutput()) == {"ok": True, "data": 42}

    def test_parse_empty_list(self) -> None:
        assert self._parse([])["ok"] is True

    def test_run_blocks_summary_handles_non_dict_data(self) -> None:
        """``summarize_tool_result`` copes with a list-valued ``data`` field."""
        from skyvern.forge.sdk.copilot.output_utils import summarize_tool_result

        tool_result = {"ok": True, "data": [{"type": "text", "text": '{"ok": true}'}]}
        summary = summarize_tool_result("run_blocks_and_collect_debug", tool_result)
        assert summary == "Run debug completed"
class TestEnforcementStateUpdates:
    """How ``_update_enforcement_from_tool`` updates the enforcement flags on ctx."""

    def _make_ctx(self) -> Any:
        """Return a mock ctx with every enforcement flag at its initial value."""
        return MagicMock(
            update_workflow_called=False,
            test_after_update_done=False,
            post_update_nudge_count=0,
            navigate_called=False,
            observation_after_navigate=False,
        )

    @staticmethod
    def _apply(ctx: Any, tool_name: str, tool_result: dict[str, Any]) -> None:
        """Run ``_update_enforcement_from_tool`` for one tool result."""
        from skyvern.forge.sdk.copilot.streaming_adapter import _update_enforcement_from_tool

        _update_enforcement_from_tool(ctx, tool_name, tool_result)

    def test_update_workflow_sets_flags(self) -> None:
        ctx = self._make_ctx()
        self._apply(ctx, "update_workflow", {"ok": True, "data": {"block_count": 2}})

        assert ctx.update_workflow_called is True
        assert ctx.test_after_update_done is False
        assert ctx.post_update_nudge_count == 0

    def test_run_blocks_sets_test_done(self) -> None:
        ctx = self._make_ctx()
        self._apply(ctx, "run_blocks_and_collect_debug", {"ok": True})

        assert ctx.test_after_update_done is True

    def test_update_and_run_blocks_sets_both_flags(self) -> None:
        """update_and_run_blocks is a composite tool — it must set update AND test flags."""
        ctx = self._make_ctx()
        self._apply(ctx, "update_and_run_blocks", {"ok": True, "data": {"block_count": 2}})

        assert ctx.update_workflow_called is True
        assert ctx.test_after_update_done is True

    def test_navigate_sets_flags(self) -> None:
        ctx = self._make_ctx()
        self._apply(ctx, "navigate_browser", {"ok": True})

        assert ctx.navigate_called is True
        assert ctx.observation_after_navigate is False

    def test_observation_tool_sets_flag(self) -> None:
        # _make_ctx already starts observation_after_navigate at False.
        ctx = self._make_ctx()
        self._apply(ctx, "get_browser_screenshot", {"ok": True})

        assert ctx.observation_after_navigate is True

    def test_update_without_blocks_does_not_set_flag(self) -> None:
        ctx = self._make_ctx()
        self._apply(ctx, "update_workflow", {"ok": True, "data": {"block_count": 0}})

        assert ctx.update_workflow_called is False