from __future__ import annotations

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

import pytest  # type: ignore[import-not-found]

from skyvern.forge.sdk.api.llm import api_handler_factory
from skyvern.forge.sdk.api.llm.api_handler_factory import (
    EXTRACT_ACTION_PROMPT_NAME,
    LLMAPIHandlerFactory,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.models import Step, StepStatus
from tests.unit.helpers import FakeLLMResponse


@pytest.mark.asyncio
async def test_cached_content_not_added_for_non_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that cached_content is NOT added to non-Gemini models like GPT-4."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = "projects/123/locations/us-central1/cachedContents/456"
    context.use_prompt_caching = True
    context.cached_static_prompt = "some static prompt"
    context.hashed_href_map = {}

    # Setup non-Gemini config
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    # Mock litellm.acompletion to capture the parameters
    completion_params = {}

    async def mock_acompletion(*args, **kwargs):
        completion_params.update(kwargs)
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # Verify cached_content was NOT passed
    assert "cached_content" not in completion_params
    assert completion_params["model"] == "gpt-4"


@pytest.mark.asyncio
async def test_cached_content_added_for_gemini(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that cached_content IS added for Gemini models."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = "projects/123/locations/us-central1/cachedContents/456"
    context.use_prompt_caching = True
    context.cached_static_prompt = "some static prompt"
    context.hashed_href_map = {}

    # Setup Gemini config
    llm_config = LLMConfig(
        model_name="gemini-1.5-pro",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    # Mock litellm.acompletion to capture the parameters
    completion_params = {}

    async def mock_acompletion(*args, **kwargs):
        completion_params.update(kwargs)
        return FakeLLMResponse("gemini-1.5-pro")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gemini-1.5-pro")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # Verify cached_content WAS passed
    assert "cached_content" in completion_params
    assert completion_params["cached_content"] == "projects/123/locations/us-central1/cachedContents/456"
    assert completion_params["model"] == "gemini-1.5-pro"
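

# The two tests above pin down how the handler forwards the Vertex prompt cache. A minimal
# sketch of the assumed gating (illustrative only, not the actual handler source; the helper
# name is made up): the cache handle is passed to litellm.acompletion solely for Gemini models.
def _sketch_cached_content_kwargs(model_name: str, vertex_cache_name: str | None) -> dict:
    # Return the extra completion kwargs the handler is assumed to add for this model.
    if vertex_cache_name and "gemini" in model_name.lower():
        return {"cached_content": vertex_cache_name}
    return {}

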
@pytest.mark.asyncio
async def test_openai_caching_not_injected_for_check_user_goal(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that OpenAI context caching system message is NOT injected for check-user-goal prompts.

    This is a regression test for a bug where the extract-action-static.j2 prompt was being
    injected as a system message for ALL prompts on OpenAI models, causing the LLM to return
    CLICK actions when running check-user-goal (which should only return COMPLETE/TERMINATE).
    """
    # Setup context with caching enabled (simulating state after extract-action ran)
    context = MagicMock()
    context.vertex_cache_name = None
    context.use_prompt_caching = True
    context.cached_static_prompt = "This is the extract-action-static prompt content"
    context.hashed_href_map = {}

    # Setup OpenAI config (GPT-4)
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)

    # Capture messages passed to LLM
    captured_messages: list = []

    async def mock_llm_messages_builder(prompt, screenshots, add_assistant_prefix):
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", mock_llm_messages_builder)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    async def mock_acompletion(*args, **kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it with check-user-goal prompt (NOT extract-actions)
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="check-user-goal prompt content", prompt_name="check-user-goal")

    # Verify the cached_static_prompt was NOT injected as a system message
    # There should only be the user message, no system message with the cached content
    system_messages = [m for m in captured_messages if m.get("role") == "system"]
    assert len(system_messages) == 0, (
        f"Expected no system messages with cached content for check-user-goal, but found: {system_messages}"
    )


@pytest.mark.asyncio
async def test_openai_caching_injected_for_extract_actions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that OpenAI context caching system message IS injected for extract-actions prompts."""
    # Setup context with caching enabled
    context = MagicMock()
    context.vertex_cache_name = None
    context.use_prompt_caching = True
    context.cached_static_prompt = "This is the extract-action-static prompt content"
    context.hashed_href_map = {}

    # Setup OpenAI config (GPT-4)
    llm_config = LLMConfig(
        model_name="gpt-4",
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)

    # Capture messages passed to LLM
    captured_messages: list = []

    async def mock_llm_messages_builder(prompt, screenshots, add_assistant_prefix):
        return [{"role": "user", "content": prompt}]

    monkeypatch.setattr(api_handler_factory, "llm_messages_builder", mock_llm_messages_builder)
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.0)

    async def mock_acompletion(*args, **kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return FakeLLMResponse("gpt-4")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Get handler and call it with extract-actions prompt
    handler = LLMAPIHandlerFactory.get_llm_api_handler("gpt-4")
    await handler(prompt="extract-actions prompt content", prompt_name=EXTRACT_ACTION_PROMPT_NAME)

    # Verify the cached_static_prompt WAS injected as a system message
    system_messages = [m for m in captured_messages if m.get("role") == "system"]
    assert len(system_messages) == 1, (
        f"Expected 1 system message with cached content for extract-actions, "
        f"but found {len(system_messages)}: {system_messages}"
    )
    # Check the system message contains the cached content
    system_content = system_messages[0].get("content", [])
    assert any(part.get("text") == "This is the extract-action-static prompt content" for part in system_content), (
        f"System message should contain cached_static_prompt, got: {system_content}"
    )
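

# Both OpenAI caching tests above hinge on a single condition. A minimal sketch of the
# assumed check (illustrative only; the helper name is made up, not part of the handler):
# the cached static prompt is injected as a system message only for extract-action prompts.
def _sketch_should_inject_cached_system_prompt(prompt_name: str, cached_static_prompt: str | None) -> bool:
    # check-user-goal and other prompts must never receive the extract-action static prefix.
    return bool(cached_static_prompt) and prompt_name == EXTRACT_ACTION_PROMPT_NAME

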
def test_normalize_llm_model_strips_provider_prefix() -> None:
    """LiteLLM returns model names with provider prefixes; dbt expects the bare name."""
    assert api_handler_factory._normalize_llm_model("vertex_ai/gemini-2.5-flash") == "gemini-2.5-flash"
    assert api_handler_factory._normalize_llm_model("openai/gpt-4.1-mini") == "gpt-4.1-mini"
    assert api_handler_factory._normalize_llm_model("gpt-4") == "gpt-4"
    assert api_handler_factory._normalize_llm_model(None) is None
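

# A minimal sketch of the normalization the assertions above assume (illustrative only, not
# the real _normalize_llm_model source): drop the provider prefix and pass None through.
def _sketch_normalize_llm_model(model: str | None) -> str | None:
    if model is None:
        return None
    return model.split("/", 1)[-1]

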
def test_assert_step_thought_block_exclusive_rejects_both_set() -> None:
    with pytest.raises(ValueError, match="mutually exclusive"):
        api_handler_factory._assert_step_thought_block_exclusive(MagicMock(), MagicMock(), None)


def test_assert_step_thought_block_exclusive_rejects_step_and_block() -> None:
    with pytest.raises(ValueError, match="mutually exclusive"):
        api_handler_factory._assert_step_thought_block_exclusive(MagicMock(), None, "wfb_123")


def test_assert_step_thought_block_exclusive_allows_single_or_neither() -> None:
    api_handler_factory._assert_step_thought_block_exclusive(None, None, None)
    api_handler_factory._assert_step_thought_block_exclusive(MagicMock(), None, None)
    api_handler_factory._assert_step_thought_block_exclusive(None, MagicMock(), None)
    api_handler_factory._assert_step_thought_block_exclusive(None, None, "wfb_123")
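

# A minimal sketch of the guard shape the three tests above assume (illustrative only, not
# the real _assert_step_thought_block_exclusive source): at most one owner may be provided.
def _sketch_assert_step_thought_block_exclusive(
    step: object | None, thought: object | None, block_id: str | None
) -> None:
    if sum(arg is not None for arg in (step, thought, block_id)) > 1:
        raise ValueError("step, thought, and block are mutually exclusive")

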
@pytest.mark.asyncio
async def test_handler_persists_response_model_not_router_group(monkeypatch: pytest.MonkeyPatch) -> None:
    """The handler must persist response.model (normalized), not the config key used to resolve the handler."""
    context = MagicMock()
    context.vertex_cache_name = None
    context.use_prompt_caching = False
    context.cached_static_prompt = None
    context.hashed_href_map = {}
    context.use_artifact_bundling = False
    context.workflow_run_id = None
    context.task_id = None

    llm_config = LLMConfig(
        model_name="GEMINI_2_5_FLASH_WITH_FALLBACK",  # router group name, not what response.model returns
        required_env_vars=[],
        supports_vision=True,
        add_assistant_prefix=False,
    )

    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
    )
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
    )
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
    monkeypatch.setattr(
        api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
    )
    monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.01)

    # LiteLLM returns the actual backing model with its provider prefix
    async def mock_acompletion(*args, **kwargs):
        return FakeLLMResponse("vertex_ai/gemini-2.5-flash")

    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))

    # Capture update_step kwargs to assert on the llm_model value
    captured_kwargs: dict = {}

    async def mock_update_step(**kwargs):
        captured_kwargs.update(kwargs)
        return MagicMock()

    artifact_manager = MagicMock()
    artifact_manager.prepare_llm_artifact = AsyncMock(return_value=None)
    artifact_manager.bulk_create_artifacts = AsyncMock()
    monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.app.ARTIFACT_MANAGER", artifact_manager)
    monkeypatch.setattr(
        "skyvern.forge.sdk.api.llm.api_handler_factory.app.DATABASE.tasks.update_step", mock_update_step
    )

    now = datetime.now()
    step = Step(
        created_at=now,
        modified_at=now,
        task_id="tsk_test",
        step_id="stp_test",
        status=StepStatus.running,
        order=0,
        is_last=False,
        retry_index=0,
        organization_id="org_test",
    )

    handler = LLMAPIHandlerFactory.get_llm_api_handler("GEMINI_2_5_FLASH_WITH_FALLBACK")
    await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME, step=step)

    # The persisted model should be the bare response.model, not the router group key
    assert captured_kwargs.get("last_llm_model") == "gemini-2.5-flash"
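

# Assumed persistence path being asserted above (simplified sketch, not the actual source):
# the handler reads response.model from the LiteLLM response, strips the provider prefix
# (see _sketch_normalize_llm_model above), and passes the bare name to
# app.DATABASE.tasks.update_step(..., last_llm_model=<normalized model>) rather than
# persisting the router group key the handler was resolved with.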