mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 11:40:32 +00:00
160 lines
6.3 KiB
Python
160 lines
6.3 KiB
Python
"""Tests for the summarize-output endpoint and helpers."""
|
|
|
|
from contextlib import contextmanager
|
|
from datetime import datetime, timezone
|
|
from typing import Iterator
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from pydantic import ValidationError
|
|
|
|
from skyvern.forge import app
|
|
from skyvern.forge.sdk.api.llm.exceptions import (
|
|
EmptyLLMResponseError,
|
|
InvalidLLMResponseFormat,
|
|
InvalidLLMResponseType,
|
|
LLMProviderError,
|
|
)
|
|
from skyvern.forge.sdk.routes.prompts import summarize_output
|
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
|
from skyvern.forge.sdk.schemas.prompts import SummarizeOutputRequest, SummarizeOutputResponse
|
|
from skyvern.utils.strings import escape_code_fences
|
|
from tests.unit.helpers import make_organization
|
|
|
|
|
|
class TestEscapeCodeFences:
    """Unit tests for ``escape_code_fences``."""

    def test_none_returns_empty_string(self) -> None:
        result = escape_code_fences(None)
        assert result == ""

    def test_triple_backticks_are_neutralized(self) -> None:
        # Every run of three backticks should come back space-separated,
        # so the sequence no longer forms a literal triple-backtick fence.
        escaped = escape_code_fences("a ```evil``` b")
        assert escaped == "a ` ` `evil` ` ` b"

    def test_triple_tildes_are_neutralized(self) -> None:
        # Same treatment for the tilde fence variant.
        escaped = escape_code_fences("a ~~~evil~~~ b")
        assert escaped == "a ~ ~ ~evil~ ~ ~ b"

    def test_fullwidth_backticks_normalized_then_escaped(self) -> None:
        # U+FF40 (FULLWIDTH GRAVE ACCENT) folds to an ASCII backtick under
        # NFKC; three of them must not survive as a triple-backtick fence.
        fullwidth_fence = "\uff40\uff40\uff40"
        assert "```" not in escape_code_fences(fullwidth_fence)
|
|
|
|
|
|
class TestSummarizeOutputRequest:
    """Validation rules enforced by the ``SummarizeOutputRequest`` schema."""

    def test_valid_json_accepted(self) -> None:
        payload = '{"a": 1}'
        parsed = SummarizeOutputRequest(output_json=payload)
        assert parsed.output_json == payload

    def test_invalid_json_rejected(self) -> None:
        with pytest.raises(ValidationError):
            SummarizeOutputRequest(output_json="not json")

    def test_empty_output_rejected(self) -> None:
        with pytest.raises(ValidationError):
            SummarizeOutputRequest(output_json="")

    def test_oversized_output_rejected(self) -> None:
        # A JSON string literal carrying 100_001 characters between its quotes.
        oversized = '"' + "x" * 100_001 + '"'
        with pytest.raises(ValidationError):
            SummarizeOutputRequest(output_json=oversized)

    def test_oversized_title_rejected(self) -> None:
        long_title = "x" * 501
        with pytest.raises(ValidationError):
            SummarizeOutputRequest(output_json="{}", workflow_title=long_title)

    def test_deeply_nested_json_rejected(self) -> None:
        # 10k levels of nesting exceeds Python's recursion limit inside json.loads.
        depth = 10_000
        nested = "[" * depth + "]" * depth
        with pytest.raises(ValidationError):
            SummarizeOutputRequest(output_json=nested)
|
|
|
|
|
|
def _fake_org() -> Organization:
    """Return a throwaway Organization built by ``make_organization`` from the current UTC time."""
    now_utc = datetime.now(timezone.utc)
    return make_organization(now_utc)
|
|
|
|
|
|
@contextmanager
|
|
def _patch_llm_handler(handler: AsyncMock) -> Iterator[None]:
|
|
"""Temporarily install an LLM handler on the app for a single test."""
|
|
sentinel = object()
|
|
original = getattr(app, "LLM_API_HANDLER", sentinel)
|
|
app.LLM_API_HANDLER = handler # type: ignore[attr-defined]
|
|
try:
|
|
yield
|
|
finally:
|
|
if original is sentinel:
|
|
delattr(app, "LLM_API_HANDLER")
|
|
else:
|
|
app.LLM_API_HANDLER = original # type: ignore[attr-defined]
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSummarizeOutputRoute:
    """Route-level tests for ``summarize_output`` with a stubbed LLM handler."""

    @staticmethod
    async def _summarize(handler: AsyncMock) -> SummarizeOutputResponse:
        """Call the route with a fixed ``{"a": 1}`` request while ``handler`` is installed."""
        with _patch_llm_handler(handler):
            return await summarize_output(
                request=SummarizeOutputRequest(output_json='{"a": 1}'),
                current_org=_fake_org(),
            )

    async def test_success_returns_stripped_summary(self) -> None:
        handler = AsyncMock(return_value={"summary": " a summary "})
        result = await self._summarize(handler)
        assert isinstance(result, SummarizeOutputResponse)
        assert result.summary == "a summary"
        assert result.error is None

    async def test_non_dict_response_returns_structured_error(self) -> None:
        handler = AsyncMock(return_value="just a string")
        result = await self._summarize(handler)
        assert result.summary == ""
        assert result.error == "LLM response is not valid JSON."

    async def test_missing_summary_key_returns_structured_error(self) -> None:
        handler = AsyncMock(return_value={"other_field": "x"})
        result = await self._summarize(handler)
        assert result.summary == ""
        assert result.error == "LLM response missing 'summary' field."

    async def test_non_string_summary_returns_structured_error(self) -> None:
        handler = AsyncMock(return_value={"summary": 42})
        result = await self._summarize(handler)
        assert result.summary == ""
        assert result.error == "LLM 'summary' field is not a string."

    @pytest.mark.parametrize(
        "exc",
        [
            InvalidLLMResponseFormat("bad"),
            InvalidLLMResponseType("list"),
            EmptyLLMResponseError("empty"),
        ],
    )
    async def test_malformed_llm_output_returns_structured_error(self, exc: Exception) -> None:
        # Every "malformed output" exception collapses to the same structured error.
        result = await self._summarize(AsyncMock(side_effect=exc))
        assert result.summary == ""
        assert result.error == "LLM response is not valid JSON."

    async def test_llm_provider_error_raises_503(self) -> None:
        handler = AsyncMock(side_effect=LLMProviderError("down"))
        with pytest.raises(HTTPException) as exc_info:
            await self._summarize(handler)
        assert exc_info.value.status_code == 503

    async def test_unexpected_exception_raises_500(self) -> None:
        handler = AsyncMock(side_effect=RuntimeError("boom"))
        with pytest.raises(HTTPException) as exc_info:
            await self._summarize(handler)
        assert exc_info.value.status_code == 500
|