diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index cdb1f6ff3..6c10d7876 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -87,6 +87,7 @@ from skyvern.forge.sdk.schemas.tasks import ( ) from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline from skyvern.forge.sdk.services import org_auth_service +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.workflow.exceptions import ( FailedToCreateWorkflow, FailedToUpdateWorkflow, @@ -1699,7 +1700,7 @@ async def models() -> ModelsResponse: """ Get a list of available models. """ - mapping = settings.get_model_name_to_llm_key() + mapping = SettingsManager.get_settings().get_model_name_to_llm_key() just_labels = {k: v["label"] for k, v in mapping.items() if "anthropic" not in k.lower()} return ModelsResponse(models=just_labels) diff --git a/skyvern/forge/sdk/schemas/task_v2.py b/skyvern/forge/sdk/schemas/task_v2.py index 64f642f9b..4c35166fe 100644 --- a/skyvern/forge/sdk/schemas/task_v2.py +++ b/skyvern/forge/sdk/schemas/task_v2.py @@ -5,7 +5,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_validator -from skyvern.config import settings +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.schemas.runs import GeoTarget, ProxyLocation, ProxyLocationInput from skyvern.utils.url_validators import validate_url @@ -95,7 +95,7 @@ class TaskV2(BaseModel): if self.model: model_name = self.model.get("model_name") if model_name: - mapping = settings.get_model_name_to_llm_key() + mapping = SettingsManager.get_settings().get_model_name_to_llm_key() llm_key = mapping.get(model_name, {}).get("llm_key") if llm_key: return llm_key diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 42ec74e12..139891a49 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -8,7 +8,6 @@ from fastapi import status from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from skyvern.config import settings from skyvern.exceptions import ( InvalidTaskStatusTransition, SkyvernHTTPException, @@ -17,6 +16,7 @@ from skyvern.exceptions import ( ) from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.schemas.files import FileInfo +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.schemas.docs.doc_strings import PROXY_LOCATION_DOC_STRING from skyvern.schemas.runs import ProxyLocationInput from skyvern.utils.url_validators import validate_url @@ -303,7 +303,7 @@ class Task(TaskBase): if self.model: model_name = self.model.get("model_name") if model_name: - mapping = settings.get_model_name_to_llm_key() + mapping = SettingsManager.get_settings().get_model_name_to_llm_key() return mapping.get(model_name, {}).get("llm_key") return None diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 881b4fd18..87f5df707 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -76,6 +76,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2Status from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus from skyvern.forge.sdk.services.bitwarden import BitwardenConstants from skyvern.forge.sdk.services.credentials import AzureVaultConstants, OnePasswordConstants +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.trace import traced from skyvern.forge.sdk.utils.pdf_parser import extract_pdf_file, validate_pdf_file from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text @@ -182,7 +183,7 @@ class Block(BaseModel, abc.ABC): if self.model: model_name = self.model.get("model_name") if model_name: - mapping = settings.get_model_name_to_llm_key() + mapping = SettingsManager.get_settings().get_model_name_to_llm_key() return mapping.get(model_name, {}).get("llm_key") return None diff --git a/tests/unit/test_text_prompt_block.py b/tests/unit/test_text_prompt_block.py index 6e83afa78..2f8026951 100644 --- a/tests/unit/test_text_prompt_block.py +++ b/tests/unit/test_text_prompt_block.py @@ -4,7 +4,9 @@ from unittest.mock import AsyncMock import pytest +from skyvern.config import settings as base_settings from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.workflow.models.block import TextPromptBlock from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType @@ -20,6 +22,8 @@ block_module = sys.modules["skyvern.forge.sdk.workflow.models.block"] ], ) async def test_text_prompt_block_uses_selected_model(monkeypatch, model_name, expected_llm_key): + # Reset SettingsManager to base settings so cloud overrides from earlier tests don't leak + monkeypatch.setattr(SettingsManager, "_SettingsManager__instance", base_settings) now = datetime.now(timezone.utc) output_parameter = OutputParameter( parameter_type=ParameterType.OUTPUT,