diff --git a/skyvern/forge/sdk/db/mixins/credentials.py b/skyvern/forge/sdk/db/mixins/credentials.py index 32cf95cd9..d658899fc 100644 --- a/skyvern/forge/sdk/db/mixins/credentials.py +++ b/skyvern/forge/sdk/db/mixins/credentials.py @@ -14,8 +14,6 @@ from skyvern.forge.sdk.schemas.organization_bitwarden_collections import Organiz if TYPE_CHECKING: from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory -from skyvern.forge.sdk.db._sentinels import _UNSET - class CredentialsMixin: """Database operations for credential and Bitwarden collection management.""" @@ -100,8 +98,10 @@ class CredentialsMixin: credential_id: str, organization_id: str, name: str | None = None, - browser_profile_id: str | None | object = _UNSET, - tested_url: str | None | object = _UNSET, + browser_profile_id: str | None = None, + tested_url: str | None = None, + user_context: str | None = None, + save_browser_session_intent: bool | None = None, ) -> Credential: async with self.Session() as session: credential = ( @@ -116,10 +116,14 @@ class CredentialsMixin: raise NotFoundError(f"Credential {credential_id} not found") if name is not None: credential.name = name - if browser_profile_id is not _UNSET: + if browser_profile_id is not None: credential.browser_profile_id = browser_profile_id - if tested_url is not _UNSET: + if tested_url is not None: credential.tested_url = tested_url + if user_context is not None: + credential.user_context = user_context + if save_browser_session_intent is not None: + credential.save_browser_session_intent = save_browser_session_intent await session.commit() await session.refresh(credential) return Credential.model_validate(credential) diff --git a/skyvern/forge/sdk/db/repositories/browser_sessions.py b/skyvern/forge/sdk/db/repositories/browser_sessions.py index 7af27782c..55cc25ad8 100644 --- a/skyvern/forge/sdk/db/repositories/browser_sessions.py +++ b/skyvern/forge/sdk/db/repositories/browser_sessions.py @@ -241,6 +241,7 @@ class BrowserSessionsRepository(BaseRepository): timeout_minutes: int | None = None, organization_id: str | None = None, completed_at: datetime | None = None, + started_at: datetime | None = None, ) -> PersistentBrowserSession: async with self.Session() as session: persistent_browser_session = ( @@ -260,6 +261,8 @@ class BrowserSessionsRepository(BaseRepository): persistent_browser_session.timeout_minutes = timeout_minutes if completed_at: persistent_browser_session.completed_at = completed_at + if started_at: + persistent_browser_session.started_at = started_at await session.commit() await session.refresh(persistent_browser_session) diff --git a/skyvern/forge/sdk/db/repositories/credentials.py b/skyvern/forge/sdk/db/repositories/credentials.py index 72df1d3e2..8bf474854 100644 --- a/skyvern/forge/sdk/db/repositories/credentials.py +++ b/skyvern/forge/sdk/db/repositories/credentials.py @@ -5,7 +5,6 @@ from datetime import datetime, timezone from sqlalchemy import select from skyvern.forge.sdk.db._error_handling import db_operation -from skyvern.forge.sdk.db._sentinels import _UNSET from skyvern.forge.sdk.db.base_repository import BaseRepository from skyvern.forge.sdk.db.exceptions import NotFoundError from skyvern.forge.sdk.db.models import CredentialModel, OrganizationBitwardenCollectionModel @@ -94,8 +93,10 @@ class CredentialRepository(BaseRepository): credential_id: str, organization_id: str, name: str | None = None, - browser_profile_id: str | None | object = _UNSET, - tested_url: str | None | object = _UNSET, + browser_profile_id: str | None = None, + tested_url: str | None = None, + user_context: str | None = None, + save_browser_session_intent: bool | None = None, ) -> Credential: async with self.Session() as session: credential = ( @@ -110,10 +111,14 @@ class CredentialRepository(BaseRepository): raise NotFoundError(f"Credential {credential_id} not found") if name is not None: credential.name = name - if browser_profile_id is not _UNSET: + if browser_profile_id is not None: credential.browser_profile_id = browser_profile_id - if tested_url is not _UNSET: + if tested_url is not None: credential.tested_url = tested_url + if user_context is not None: + credential.user_context = user_context + if save_browser_session_intent is not None: + credential.save_browser_session_intent = save_browser_session_intent await session.commit() await session.refresh(credential) return Credential.model_validate(credential) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a0ca9f347..2fc98eed2 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,4 +1,6 @@ # -- begin speed up unit tests +from unittest.mock import AsyncMock, MagicMock + import pytest from tests.unit.force_stub_app import start_forge_stub_app @@ -28,3 +30,32 @@ from tests.unit.force_stub_app import start_forge_stub_app def setup_forge_stub_app(): start_forge_stub_app() yield + + +# -- shared helpers for repository unit tests -- + + +class MockAsyncSessionCtx: + """Async context manager wrapping a mock SQLAlchemy session for repository tests.""" + + def __init__(self, session: AsyncMock): + self._session = session + + async def __aenter__(self): + return self._session + + async def __aexit__(self, *args): + pass + + +def make_mock_session(mock_model: MagicMock) -> AsyncMock: + """Create a mock SQLAlchemy session that returns mock_model from scalars().first().""" + scalars_result = MagicMock() + scalars_result.first.return_value = mock_model + + mock_session = AsyncMock() + mock_session.scalars.return_value = scalars_result + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + return mock_session diff --git a/tests/unit/test_browser_session_update_started_at.py b/tests/unit/test_browser_session_update_started_at.py new file mode 100644 index 000000000..f1e67e49a --- /dev/null +++ b/tests/unit/test_browser_session_update_started_at.py @@ -0,0 +1,58 @@ +"""Test that BrowserSessionsRepository.update_persistent_browser_session() accepts started_at.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository +from tests.unit.conftest import MockAsyncSessionCtx, make_mock_session + + +def _make_browser_repo(mock_pbs: MagicMock) -> BrowserSessionsRepository: + mock_session = make_mock_session(mock_pbs) + return BrowserSessionsRepository(session_factory=lambda: MockAsyncSessionCtx(mock_session)) + + +@pytest.mark.asyncio +async def test_update_persistent_browser_session_accepts_started_at() -> None: + """update_persistent_browser_session() must accept started_at without raising TypeError.""" + mock_pbs = MagicMock() + mock_pbs.status = "running" + mock_pbs.started_at = None + repo = _make_browser_repo(mock_pbs) + + now = datetime.now(timezone.utc) + with patch( + "skyvern.forge.sdk.schemas.persistent_browser_sessions.PersistentBrowserSession.model_validate", + return_value=MagicMock(), + ): + await repo.update_persistent_browser_session( + "pbs_123", + organization_id="org_123", + started_at=now, + ) + + assert mock_pbs.started_at == now + + +@pytest.mark.asyncio +async def test_update_persistent_browser_session_without_started_at() -> None: + """When started_at is not passed, the field should not be touched.""" + mock_pbs = MagicMock() + mock_pbs.status = "created" + original_started_at = datetime(2026, 1, 1, tzinfo=timezone.utc) + mock_pbs.started_at = original_started_at + repo = _make_browser_repo(mock_pbs) + + with patch( + "skyvern.forge.sdk.schemas.persistent_browser_sessions.PersistentBrowserSession.model_validate", + return_value=MagicMock(), + ): + await repo.update_persistent_browser_session( + "pbs_123", + organization_id="org_123", + status="running", + ) + + assert mock_pbs.started_at == original_started_at diff --git a/tests/unit/test_credential_update_params.py b/tests/unit/test_credential_update_params.py new file mode 100644 index 000000000..2c0270220 --- /dev/null +++ b/tests/unit/test_credential_update_params.py @@ -0,0 +1,114 @@ +"""Tests that update_credential() accepts user_context and save_browser_session_intent +on both CredentialRepository and CredentialsMixin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin +from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository +from tests.unit.conftest import MockAsyncSessionCtx, make_mock_session + + +def _make_credential_repo(mock_credential: MagicMock) -> CredentialRepository: + mock_session = make_mock_session(mock_credential) + return CredentialRepository(session_factory=lambda: MockAsyncSessionCtx(mock_session)) + + +def _make_credential_mixin(mock_credential: MagicMock) -> CredentialsMixin: + mock_session = make_mock_session(mock_credential) + mixin = CredentialsMixin.__new__(CredentialsMixin) + mixin.Session = lambda: MockAsyncSessionCtx(mock_session) # type: ignore[assignment] + return mixin + + +# --- CredentialRepository tests --- + + +@pytest.mark.asyncio +async def test_repo_update_credential_accepts_user_context() -> None: + mock_credential = MagicMock() + mock_credential.name = "test" + mock_credential.user_context = None + repo = _make_credential_repo(mock_credential) + + with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()): + await repo.update_credential( + credential_id="cred_123", + organization_id="org_123", + user_context="Click SSO button first", + ) + + assert mock_credential.user_context == "Click SSO button first" + + +@pytest.mark.asyncio +async def test_repo_update_credential_accepts_save_browser_session_intent() -> None: + mock_credential = MagicMock() + mock_credential.name = "test" + mock_credential.save_browser_session_intent = False + repo = _make_credential_repo(mock_credential) + + with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()): + await repo.update_credential( + credential_id="cred_123", + organization_id="org_123", + save_browser_session_intent=True, + ) + + assert mock_credential.save_browser_session_intent is True + + +@pytest.mark.asyncio +async def test_repo_update_credential_unset_params_not_applied() -> None: + mock_credential = MagicMock() + mock_credential.name = "test" + mock_credential.user_context = "existing" + mock_credential.save_browser_session_intent = True + repo = _make_credential_repo(mock_credential) + + with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()): + await repo.update_credential( + credential_id="cred_123", + organization_id="org_123", + ) + + assert mock_credential.user_context == "existing" + assert mock_credential.save_browser_session_intent is True + + +# --- CredentialsMixin tests --- + + +@pytest.mark.asyncio +async def test_mixin_update_credential_accepts_user_context() -> None: + mock_credential = MagicMock() + mock_credential.name = "test" + mock_credential.user_context = None + mixin = _make_credential_mixin(mock_credential) + + with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()): + await mixin.update_credential( + credential_id="cred_123", + organization_id="org_123", + user_context="Click SSO button first", + ) + + assert mock_credential.user_context == "Click SSO button first" + + +@pytest.mark.asyncio +async def test_mixin_update_credential_accepts_save_browser_session_intent() -> None: + mock_credential = MagicMock() + mock_credential.name = "test" + mock_credential.save_browser_session_intent = False + mixin = _make_credential_mixin(mock_credential) + + with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()): + await mixin.update_credential( + credential_id="cred_123", + organization_id="org_123", + save_browser_session_intent=True, + ) + + assert mock_credential.save_browser_session_intent is True