feat: add COGS visibility dbt models and LLM model tracking (#5500)

This commit is contained in:
Aaron Perez 2026-04-14 15:49:12 -05:00 committed by GitHub
parent 68bc01051c
commit 2f0ab4e453
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 222 additions and 5 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest # type: ignore[import-not-found]
@ -10,6 +11,7 @@ from skyvern.forge.sdk.api.llm.api_handler_factory import (
LLMAPIHandlerFactory,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.models import Step, StepStatus
from tests.unit.helpers import FakeLLMResponse
@ -224,3 +226,94 @@ async def test_openai_caching_injected_for_extract_actions(monkeypatch: pytest.M
assert any(part.get("text") == "This is the extract-action-static prompt content" for part in system_content), (
f"System message should contain cached_static_prompt, got: {system_content}"
)
def test_normalize_llm_model_strips_provider_prefix() -> None:
"""LiteLLM returns model names with provider prefixes; dbt expects the bare name."""
assert api_handler_factory._normalize_llm_model("vertex_ai/gemini-2.5-flash") == "gemini-2.5-flash"
assert api_handler_factory._normalize_llm_model("openai/gpt-4.1-mini") == "gpt-4.1-mini"
assert api_handler_factory._normalize_llm_model("gpt-4") == "gpt-4"
assert api_handler_factory._normalize_llm_model(None) is None
def test_assert_step_thought_exclusive_rejects_both_set() -> None:
with pytest.raises(ValueError, match="mutually exclusive"):
api_handler_factory._assert_step_thought_exclusive(MagicMock(), MagicMock())
def test_assert_step_thought_exclusive_allows_single_or_neither() -> None:
api_handler_factory._assert_step_thought_exclusive(None, None)
api_handler_factory._assert_step_thought_exclusive(MagicMock(), None)
api_handler_factory._assert_step_thought_exclusive(None, MagicMock())
@pytest.mark.asyncio
async def test_handler_persists_response_model_not_router_group(monkeypatch: pytest.MonkeyPatch) -> None:
"""The handler must persist response.model (normalized), not the config key used to resolve the handler."""
context = MagicMock()
context.vertex_cache_name = None
context.use_prompt_caching = False
context.cached_static_prompt = None
context.hashed_href_map = {}
context.use_artifact_bundling = False
context.workflow_run_id = None
context.task_id = None
llm_config = LLMConfig(
model_name="GEMINI_2_5_FLASH_WITH_FALLBACK", # router group name, not what response.model returns
required_env_vars=[],
supports_vision=True,
add_assistant_prefix=False,
)
monkeypatch.setattr(
"skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", lambda _: llm_config
)
monkeypatch.setattr(
"skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.is_router_config", lambda _: False
)
monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.skyvern_context.current", lambda: context)
monkeypatch.setattr(
api_handler_factory, "llm_messages_builder", AsyncMock(return_value=[{"role": "user", "content": "test"}])
)
monkeypatch.setattr(api_handler_factory.litellm, "completion_cost", lambda _: 0.01)
# LiteLLM returns the actual backing model with its provider prefix
async def mock_acompletion(*args, **kwargs):
return FakeLLMResponse("vertex_ai/gemini-2.5-flash")
monkeypatch.setattr(api_handler_factory.litellm, "acompletion", AsyncMock(side_effect=mock_acompletion))
# Capture update_step kwargs to assert on the llm_model value
captured_kwargs: dict = {}
async def mock_update_step(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
artifact_manager = MagicMock()
artifact_manager.prepare_llm_artifact = AsyncMock(return_value=None)
artifact_manager.bulk_create_artifacts = AsyncMock()
monkeypatch.setattr("skyvern.forge.sdk.api.llm.api_handler_factory.app.ARTIFACT_MANAGER", artifact_manager)
monkeypatch.setattr(
"skyvern.forge.sdk.api.llm.api_handler_factory.app.DATABASE.tasks.update_step", mock_update_step
)
now = datetime.now()
step = Step(
created_at=now,
modified_at=now,
task_id="tsk_test",
step_id="stp_test",
status=StepStatus.running,
order=0,
is_last=False,
retry_index=0,
organization_id="org_test",
)
handler = LLMAPIHandlerFactory.get_llm_api_handler("GEMINI_2_5_FLASH_WITH_FALLBACK")
await handler(prompt="test prompt", prompt_name=EXTRACT_ACTION_PROMPT_NAME, step=step)
# The persisted model should be the bare response.model, not the router group key
assert captured_kwargs.get("last_llm_model") == "gemini-2.5-flash"