diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py
index 6729577ed..c8f921583 100644
--- a/skyvern/forge/sdk/db/agent_db.py
+++ b/skyvern/forge/sdk/db/agent_db.py
@@ -13,20 +13,21 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
 from skyvern.config import settings
 from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB
-from skyvern.forge.sdk.db.mixins.artifacts import ArtifactsMixin
-from skyvern.forge.sdk.db.mixins.browser_sessions import BrowserSessionsMixin
-from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin
-from skyvern.forge.sdk.db.mixins.debug import DebugMixin
-from skyvern.forge.sdk.db.mixins.folders import FoldersMixin
-from skyvern.forge.sdk.db.mixins.observer import ObserverMixin
-from skyvern.forge.sdk.db.mixins.organizations import OrganizationsMixin
-from skyvern.forge.sdk.db.mixins.otp import OTPMixin
-from skyvern.forge.sdk.db.mixins.schedules import ScheduleLimitExceededError, SchedulesMixin
-from skyvern.forge.sdk.db.mixins.scripts import ScriptsMixin
-from skyvern.forge.sdk.db.mixins.tasks import TasksMixin
-from skyvern.forge.sdk.db.mixins.workflow_parameters import WorkflowParametersMixin
-from skyvern.forge.sdk.db.mixins.workflow_runs import WorkflowRunsMixin
-from skyvern.forge.sdk.db.mixins.workflows import WorkflowsMixin
+from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError  # noqa: F401
+from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository
+from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository
+from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
+from skyvern.forge.sdk.db.repositories.debug import DebugRepository
+from skyvern.forge.sdk.db.repositories.folders import FoldersRepository
+from skyvern.forge.sdk.db.repositories.observer import ObserverRepository
+from skyvern.forge.sdk.db.repositories.organizations import OrganizationsRepository
+from skyvern.forge.sdk.db.repositories.otp import OTPRepository
+from skyvern.forge.sdk.db.repositories.schedules import SchedulesRepository
+from skyvern.forge.sdk.db.repositories.scripts import ScriptsRepository
+from skyvern.forge.sdk.db.repositories.tasks import TasksRepository
+from skyvern.forge.sdk.db.repositories.workflow_parameters import WorkflowParametersRepository
+from skyvern.forge.sdk.db.repositories.workflow_runs import WorkflowRunsRepository
+from skyvern.forge.sdk.db.repositories.workflows import WorkflowsRepository
 from skyvern.forge.sdk.db.utils import (
     _custom_json_serializer,
 )
@@ -104,23 +105,7 @@ def _build_engine(database_string: str) -> AsyncEngine:
 __all__ = ["AgentDB", "ScheduleLimitExceededError"]
 
-class AgentDB(
-    TasksMixin,
-    WorkflowsMixin,
-    WorkflowRunsMixin,
-    WorkflowParametersMixin,
-    SchedulesMixin,
-    ArtifactsMixin,
-    BrowserSessionsMixin,
-    ScriptsMixin,
-    OTPMixin,
-    CredentialsMixin,
-    FoldersMixin,
-    OrganizationsMixin,
-    ObserverMixin,
-    DebugMixin,
-    BaseAlchemyDB,
-):
+class AgentDB(BaseAlchemyDB):
     def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
         super().__init__(db_engine or _build_engine(database_string))
         self.debug_enabled = debug_enabled
@@ -131,6 +116,819 @@ class AgentDB(
             asyncio.Lock() if self.engine.dialect.name == "sqlite" else None
         )
+
+        # -- Zero-dependency repositories --
+        self.tasks = TasksRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.workflows = WorkflowsRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.workflow_params = WorkflowParametersRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.credentials = CredentialRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.otp = OTPRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.debug = DebugRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.organizations = OrganizationsRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.scripts = ScriptsRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.browser_sessions = BrowserSessionsRepository(self.Session, debug_enabled, self.is_retryable_error)
+        self.schedules = SchedulesRepository(
+            self.Session,
+            debug_enabled,
+            self.is_retryable_error,
+            sqlite_schedule_lock=self._sqlite_schedule_lock,
+        )
+
+        # -- Cross-dependency repositories --
+        self.workflow_runs = WorkflowRunsRepository(
+            self.Session,
+            debug_enabled,
+            self.is_retryable_error,
+            workflow_parameter_reader=self.workflow_params,
+            dialect_name=self.engine.dialect.name,
+        )
+        self.artifacts = ArtifactsRepository(
+            self.Session,
+            debug_enabled,
+            self.is_retryable_error,
+            run_reader=self.workflow_runs,
+        )
+        self.folders = FoldersRepository(
+            self.Session,
+            debug_enabled,
+            self.is_retryable_error,
+            workflow_reader=self.workflows,
+        )
+        self.observer = ObserverRepository(
+            self.Session,
+            debug_enabled,
+            self.is_retryable_error,
+            task_reader=self.tasks,
+        )
+
     def is_retryable_error(self, error: SQLAlchemyError) -> bool:
         error_msg = str(error).lower()
         return "server closed the connection" in error_msg
+
+    # ======================================================================
+    # Backward-compatible delegate methods
+    # TODO(SKY-62): These delegates erase type information (*args: Any -> Any).
+    # Migrate callers to use typed repository attributes directly
+    # (e.g., db.tasks.get_task(...) instead of db.get_task(...)), then remove.
+    # ======================================================================
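The TODO above names the migration path. A minimal sketch of the two call styles, assuming an `AgentDB` instance `db` and the `get_task` signature carried over from the mixins:

```python
async def fetch(db, task_id: str, org_id: str):
    # Legacy style, routed through the untyped delegate defined below:
    legacy = await db.get_task(task_id, organization_id=org_id)
    # Target style, using the typed repository attribute directly:
    typed = await db.tasks.get_task(task_id, organization_id=org_id)
    assert legacy == typed  # both paths hit the same TasksRepository
    return typed
```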
+
+    # -- Task delegates --
+
+    async def create_task(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.create_task(*args, **kwargs)
+
+    async def create_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.create_step(*args, **kwargs)
+
+    async def get_task(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_task(*args, **kwargs)
+
+    async def get_tasks_by_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_tasks_by_ids(*args, **kwargs)
+
+    async def get_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_step(*args, **kwargs)
+
+    async def get_task_steps(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_task_steps(*args, **kwargs)
+
+    async def get_steps_by_task_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_steps_by_task_ids(*args, **kwargs)
+
+    async def get_total_unique_step_order_count_by_task_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_total_unique_step_order_count_by_task_ids(*args, **kwargs)
+
+    async def get_task_step_models(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_task_step_models(*args, **kwargs)
+
+    async def get_task_step_count(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_task_step_count(*args, **kwargs)
+
+    async def get_task_actions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_task_actions(*args, **kwargs)
+
+    async def get_task_actions_hydrated(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_task_actions_hydrated(*args, **kwargs)
+
+    async def get_tasks_actions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_tasks_actions(*args, **kwargs)
+
+    async def get_action_count_for_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_action_count_for_step(*args, **kwargs)
+
+    async def get_first_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_first_step(*args, **kwargs)
+
+    async def get_latest_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_latest_step(*args, **kwargs)
+
+    async def update_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.update_step(*args, **kwargs)
+
+    async def clear_task_failure_reason(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.clear_task_failure_reason(*args, **kwargs)
+
+    async def update_task(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.update_task(*args, **kwargs)
+
+    async def update_task_2fa_state(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.update_task_2fa_state(*args, **kwargs)
+
+    async def bulk_update_tasks(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.bulk_update_tasks(*args, **kwargs)
+
+    async def get_tasks(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_tasks(*args, **kwargs)
+
+    async def get_tasks_count(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_tasks_count(*args, **kwargs)
+
+    async def get_running_tasks_info_globally(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_running_tasks_info_globally(*args, **kwargs)
+
+    async def get_latest_task_by_workflow_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_latest_task_by_workflow_id(*args, **kwargs)
+
+    async def get_last_task_for_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_last_task_for_workflow_run(*args, **kwargs)
+
+    async def get_tasks_by_workflow_run_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_tasks_by_workflow_run_id(*args, **kwargs)
+
+    async def delete_task_steps(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.delete_task_steps(*args, **kwargs)
+
+    async def get_previous_actions_for_task(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.get_previous_actions_for_task(*args, **kwargs)
+
+    async def delete_task_actions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.tasks.delete_task_actions(*args, **kwargs)
+
+    # -- Workflow delegates --
+
+    async def create_workflow(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.create_workflow(*args, **kwargs)
+
+    async def soft_delete_workflow_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.soft_delete_workflow_by_id(*args, **kwargs)
+
+    async def get_workflow(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_workflow(*args, **kwargs)
+
+    async def get_workflow_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_workflow_by_permanent_id(*args, **kwargs)
+
+    async def get_workflow_for_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_workflow_for_workflow_run(*args, **kwargs)
+
+    async def get_workflow_versions_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_workflow_versions_by_permanent_id(*args, **kwargs)
+
+    async def get_workflows_by_permanent_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_workflows_by_permanent_ids(*args, **kwargs)
+
+    async def get_workflows_by_organization_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_workflows_by_organization_id(*args, **kwargs)
+
+    async def update_workflow(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.update_workflow(*args, **kwargs)
+
+    async def soft_delete_workflow_and_schedules_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.soft_delete_workflow_and_schedules_by_permanent_id(*args, **kwargs)
+
+    async def add_workflow_template(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.add_workflow_template(*args, **kwargs)
+
+    async def remove_workflow_template(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.remove_workflow_template(*args, **kwargs)
+
+    async def get_org_template_permanent_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.get_org_template_permanent_ids(*args, **kwargs)
+
+    async def is_workflow_template(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflows.is_workflow_template(*args, **kwargs)
+
+    # -- Workflow run delegates --
+
+    async def get_running_workflow_runs_info_globally(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_running_workflow_runs_info_globally(*args, **kwargs)
+
+    async def create_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.create_workflow_run(*args, **kwargs)
+
+    async def update_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.update_workflow_run(*args, **kwargs)
+
+    async def bulk_update_workflow_runs(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.bulk_update_workflow_runs(*args, **kwargs)
+
+    async def clear_workflow_run_failure_reason(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.clear_workflow_run_failure_reason(*args, **kwargs)
+
+    async def get_all_runs(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_all_runs(*args, **kwargs)
+
+    async def get_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_run(*args, **kwargs)
+
+    async def get_last_queued_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_last_queued_workflow_run(*args, **kwargs)
+
+    async def get_workflow_runs_by_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_runs_by_ids(*args, **kwargs)
+
+    async def get_last_running_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_last_running_workflow_run(*args, **kwargs)
+
+    async def get_last_workflow_run_for_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_last_workflow_run_for_browser_session(*args, **kwargs)
+
+    async def get_last_workflow_run_for_browser_address(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_last_workflow_run_for_browser_address(*args, **kwargs)
+
+    async def get_workflows_depending_on(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflows_depending_on(*args, **kwargs)
+
+    async def get_workflow_runs(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_runs(*args, **kwargs)
+
+    async def get_workflow_runs_count(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_runs_count(*args, **kwargs)
+
+    async def get_workflow_runs_for_workflow_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_runs_for_workflow_permanent_id(*args, **kwargs)
+
+    async def get_workflow_runs_by_parent_workflow_run_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(*args, **kwargs)
+
+    async def get_workflow_run_output_parameters(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_run_output_parameters(*args, **kwargs)
+
+    async def get_workflow_run_output_parameter_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_run_output_parameter_by_id(*args, **kwargs)
+
+    async def create_or_update_workflow_run_output_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.create_or_update_workflow_run_output_parameter(*args, **kwargs)
+
+    async def update_workflow_run_output_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.update_workflow_run_output_parameter(*args, **kwargs)
+
+    async def create_workflow_run_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.create_workflow_run_parameter(*args, **kwargs)
+
+    async def get_workflow_run_parameters(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_run_parameters(*args, **kwargs)
+
+    async def get_workflow_run_block_errors(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_runs.get_workflow_run_block_errors(*args, **kwargs)
+
+    # -- Workflow parameter delegates --
+
+    async def create_workflow_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_workflow_parameter(*args, **kwargs)
+
+    async def create_aws_secret_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_aws_secret_parameter(*args, **kwargs)
+
+    async def create_output_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_output_parameter(*args, **kwargs)
+
+    async def save_workflow_definition_parameters(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.save_workflow_definition_parameters(*args, **kwargs)
+
+    async def get_workflow_output_parameters(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_workflow_output_parameters(*args, **kwargs)
+
+    async def get_workflow_output_parameters_by_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_workflow_output_parameters_by_ids(*args, **kwargs)
+
+    async def get_workflow_parameters(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_workflow_parameters(*args, **kwargs)
+
+    async def get_workflow_parameter(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_workflow_parameter(*args, **kwargs)
+
+    async def create_task_generation(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_task_generation(*args, **kwargs)
+
+    async def create_ai_suggestion(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_ai_suggestion(*args, **kwargs)
+
+    async def create_workflow_copilot_chat(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_workflow_copilot_chat(*args, **kwargs)
+
+    async def update_workflow_copilot_chat(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.update_workflow_copilot_chat(*args, **kwargs)
+
+    async def create_workflow_copilot_chat_message(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_workflow_copilot_chat_message(*args, **kwargs)
+
+    async def get_workflow_copilot_chat_messages(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_workflow_copilot_chat_messages(*args, **kwargs)
+
+    async def get_workflow_copilot_chat_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_workflow_copilot_chat_by_id(*args, **kwargs)
+
+    async def get_latest_workflow_copilot_chat(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_latest_workflow_copilot_chat(*args, **kwargs)
+
+    async def get_task_generation_by_prompt_hash(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_task_generation_by_prompt_hash(*args, **kwargs)
+
+    async def create_action(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_action(*args, **kwargs)
+
+    async def update_action_reasoning(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.update_action_reasoning(*args, **kwargs)
+
+    async def retrieve_action_plan(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.retrieve_action_plan(*args, **kwargs)
+
+    async def create_task_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.create_task_run(*args, **kwargs)
+
+    async def update_task_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.update_task_run(*args, **kwargs)
+
+    async def update_job_run_compute_cost(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.update_job_run_compute_cost(*args, **kwargs)
+
+    async def cache_task_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.cache_task_run(*args, **kwargs)
+
+    async def get_cached_task_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_cached_task_run(*args, **kwargs)
+
+    async def get_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.workflow_params.get_run(*args, **kwargs)
+
+    # -- Artifact delegates --
+
+    async def create_artifact(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.create_artifact(*args, **kwargs)
+
+    async def bulk_create_artifacts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.bulk_create_artifacts(*args, **kwargs)
+
+    async def get_artifacts_for_task_v2(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifacts_for_task_v2(*args, **kwargs)
+
+    async def get_artifacts_for_task_step(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifacts_for_task_step(*args, **kwargs)
+
+    async def get_artifacts_for_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifacts_for_run(*args, **kwargs)
+
+    async def get_artifact_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifact_by_id(*args, **kwargs)
+
+    async def get_artifacts_by_ids(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifacts_by_ids(*args, **kwargs)
+
+    async def get_artifacts_by_entity_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifacts_by_entity_id(*args, **kwargs)
+
+    async def get_artifact_by_entity_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifact_by_entity_id(*args, **kwargs)
+
+    async def get_artifact(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifact(*args, **kwargs)
+
+    async def get_artifact_for_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifact_for_run(*args, **kwargs)
+
+    async def get_latest_artifact(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_latest_artifact(*args, **kwargs)
+
+    async def get_latest_n_artifacts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_latest_n_artifacts(*args, **kwargs)
+
+    async def delete_task_artifacts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.delete_task_artifacts(*args, **kwargs)
+
+    async def delete_task_v2_artifacts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.delete_task_v2_artifacts(*args, **kwargs)
+
+    async def update_action_screenshot_artifact_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.update_action_screenshot_artifact_id(*args, **kwargs)
+
+    # -- Browser session delegates --
+
+    async def create_browser_profile(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.create_browser_profile(*args, **kwargs)
+
+    async def get_browser_profile(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_browser_profile(*args, **kwargs)
+
+    async def list_browser_profiles(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.list_browser_profiles(*args, **kwargs)
+
+    async def delete_browser_profile(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.delete_browser_profile(*args, **kwargs)
+
+    async def get_active_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_active_persistent_browser_sessions(*args, **kwargs)
+
+    async def get_persistent_browser_sessions_history(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_persistent_browser_sessions_history(*args, **kwargs)
+
+    async def get_persistent_browser_session_by_runnable_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_persistent_browser_session_by_runnable_id(*args, **kwargs)
+
+    async def get_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_persistent_browser_session(*args, **kwargs)
+
+    async def create_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.create_persistent_browser_session(*args, **kwargs)
+
+    async def update_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.update_persistent_browser_session(*args, **kwargs)
+
+    async def set_persistent_browser_session_browser_address(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.set_persistent_browser_session_browser_address(*args, **kwargs)
+
+    async def update_persistent_browser_session_compute_cost(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.update_persistent_browser_session_compute_cost(*args, **kwargs)
+
+    async def mark_persistent_browser_session_deleted(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.mark_persistent_browser_session_deleted(*args, **kwargs)
+
+    async def occupy_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.occupy_persistent_browser_session(*args, **kwargs)
+
+    async def release_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.release_persistent_browser_session(*args, **kwargs)
+
+    async def close_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.close_persistent_browser_session(*args, **kwargs)
+
+    async def get_all_active_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_all_active_persistent_browser_sessions(*args, **kwargs)
+
+    async def archive_browser_session_address(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.archive_browser_session_address(*args, **kwargs)
+
+    async def get_uncompleted_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_uncompleted_persistent_browser_sessions(*args, **kwargs)
+
+    async def get_debug_session_by_browser_session_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.browser_sessions.get_debug_session_by_browser_session_id(*args, **kwargs)
+
+    # -- Schedule delegates --
+
+    async def create_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.create_workflow_schedule(*args, **kwargs)
+
+    async def create_workflow_schedule_with_limit(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.create_workflow_schedule_with_limit(*args, **kwargs)
+
+    async def set_temporal_schedule_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.set_temporal_schedule_id(*args, **kwargs)
+
+    async def update_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.update_workflow_schedule(*args, **kwargs)
+
+    async def get_workflow_schedule_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.get_workflow_schedule_by_id(*args, **kwargs)
+
+    async def get_workflow_schedules(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.get_workflow_schedules(*args, **kwargs)
+
+    async def get_all_enabled_schedules(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.get_all_enabled_schedules(*args, **kwargs)
+
+    async def has_schedule_fired_since(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.has_schedule_fired_since(*args, **kwargs)
+
+    async def update_workflow_schedule_enabled(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.update_workflow_schedule_enabled(*args, **kwargs)
+
+    async def delete_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.delete_workflow_schedule(*args, **kwargs)
+
+    async def restore_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.restore_workflow_schedule(*args, **kwargs)
+
+    async def count_workflow_schedules(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.count_workflow_schedules(*args, **kwargs)
+
+    async def list_organization_schedules(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.list_organization_schedules(*args, **kwargs)
+
+    async def soft_delete_orphaned_schedules(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.schedules.soft_delete_orphaned_schedules(*args, **kwargs)
+
+    # -- Script delegates --
+
+    async def create_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.create_script(*args, **kwargs)
+
+    async def get_scripts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_scripts(*args, **kwargs)
+
+    async def get_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script(*args, **kwargs)
+
+    async def get_script_revision(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_revision(*args, **kwargs)
+
+    async def get_latest_script_version(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_latest_script_version(*args, **kwargs)
+
+    async def get_script_versions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_versions(*args, **kwargs)
+
+    async def get_script_version_stats(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_version_stats(*args, **kwargs)
+
+    async def soft_delete_script_by_revision(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.soft_delete_script_by_revision(*args, **kwargs)
+
+    async def create_script_file(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.create_script_file(*args, **kwargs)
+
+    async def create_script_block(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.create_script_block(*args, **kwargs)
+
+    async def update_script_block(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.update_script_block(*args, **kwargs)
+
+    async def get_script_files(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_files(*args, **kwargs)
+
+    async def get_script_file_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_file_by_id(*args, **kwargs)
+
+    async def get_script_file_by_path(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_file_by_path(*args, **kwargs)
+
+    async def update_script_file(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.update_script_file(*args, **kwargs)
+
+    async def get_script_block(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_block(*args, **kwargs)
+
+    async def get_script_block_by_label(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_block_by_label(*args, **kwargs)
+
+    async def get_script_blocks_by_script_revision_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_blocks_by_script_revision_id(*args, **kwargs)
+
+    async def create_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.create_workflow_script(*args, **kwargs)
+
+    async def get_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_workflow_script(*args, **kwargs)
+
+    async def get_workflow_script_by_cache_key_value(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_workflow_script_by_cache_key_value(*args, **kwargs)
+
+    async def get_workflow_cache_key_count(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_workflow_cache_key_count(*args, **kwargs)
+
+    async def get_workflow_cache_key_values(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_workflow_cache_key_values(*args, **kwargs)
+
+    async def delete_workflow_cache_key_value(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.delete_workflow_cache_key_value(*args, **kwargs)
+
+    async def delete_workflow_scripts_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.delete_workflow_scripts_by_permanent_id(*args, **kwargs)
+
+    async def get_workflow_scripts_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_workflow_scripts_by_permanent_id(*args, **kwargs)
+
+    async def get_workflow_runs_for_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_workflow_runs_for_script(*args, **kwargs)
+
+    async def get_script_run_stats(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_script_run_stats(*args, **kwargs)
+
+    async def is_script_pinned(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.is_script_pinned(*args, **kwargs)
+
+    async def pin_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.pin_workflow_script(*args, **kwargs)
+
+    async def unpin_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.unpin_workflow_script(*args, **kwargs)
+
+    async def create_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.create_fallback_episode(*args, **kwargs)
+
+    async def get_unreviewed_episodes(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_unreviewed_episodes(*args, **kwargs)
+
+    async def update_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.update_fallback_episode(*args, **kwargs)
+
+    async def delete_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.delete_fallback_episode(*args, **kwargs)
+
+    async def get_fallback_episodes(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_fallback_episodes(*args, **kwargs)
+
+    async def get_fallback_episodes_count(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_fallback_episodes_count(*args, **kwargs)
+
+    async def get_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_fallback_episode(*args, **kwargs)
+
+    async def mark_episode_reviewed(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.mark_episode_reviewed(*args, **kwargs)
+
+    async def get_recent_reviewed_episodes(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_recent_reviewed_episodes(*args, **kwargs)
+
+    async def record_branch_hit(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.record_branch_hit(*args, **kwargs)
+
+    async def get_stale_branches(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.scripts.get_stale_branches(*args, **kwargs)
+
+    # -- OTP delegates --
+
+    async def get_otp_codes(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.otp.get_otp_codes(*args, **kwargs)
+
+    async def get_otp_codes_by_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.otp.get_otp_codes_by_run(*args, **kwargs)
+
+    async def get_recent_otp_codes(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.otp.get_recent_otp_codes(*args, **kwargs)
+
+    async def create_otp_code(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.otp.create_otp_code(*args, **kwargs)
+
+    # -- Credential delegates --
+
+    async def create_credential(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.create_credential(*args, **kwargs)
+
+    async def get_credential(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.get_credential(*args, **kwargs)
+
+    async def get_credentials(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.get_credentials(*args, **kwargs)
+
+    async def update_credential(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.update_credential(*args, **kwargs)
+
+    async def update_credential_vault_data(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.update_credential_vault_data(*args, **kwargs)
+
+    async def delete_credential(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.delete_credential(*args, **kwargs)
+
+    async def create_organization_bitwarden_collection(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.create_organization_bitwarden_collection(*args, **kwargs)
+
+    async def get_organization_bitwarden_collection(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.credentials.get_organization_bitwarden_collection(*args, **kwargs)
+
+    # -- Folder delegates --
+
+    async def create_folder(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.create_folder(*args, **kwargs)
+
+    async def get_folders(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.get_folders(*args, **kwargs)
+
+    async def get_folder(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.get_folder(*args, **kwargs)
+
+    async def update_folder(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.update_folder(*args, **kwargs)
+
+    async def get_workflow_permanent_ids_in_folder(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.get_workflow_permanent_ids_in_folder(*args, **kwargs)
+
+    async def soft_delete_folder(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.soft_delete_folder(*args, **kwargs)
+
+    async def get_folder_workflow_count(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.get_folder_workflow_count(*args, **kwargs)
+
+    async def get_folder_workflow_counts_batch(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.get_folder_workflow_counts_batch(*args, **kwargs)
+
+    async def update_workflow_folder(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.folders.update_workflow_folder(*args, **kwargs)
+
+    # -- Organization delegates --
+
+    async def get_active_verification_requests(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.get_active_verification_requests(*args, **kwargs)
+
+    async def get_all_organizations(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.get_all_organizations(*args, **kwargs)
+
+    async def get_organization(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.get_organization(*args, **kwargs)
+
+    async def get_organization_by_domain(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.get_organization_by_domain(*args, **kwargs)
+
+    async def create_organization(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.create_organization(*args, **kwargs)
+
+    async def update_organization(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.update_organization(*args, **kwargs)
+
+    async def get_valid_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.get_valid_org_auth_token(*args, **kwargs)
+
+    async def get_valid_org_auth_tokens(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.get_valid_org_auth_tokens(*args, **kwargs)
+
+    async def validate_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.validate_org_auth_token(*args, **kwargs)
+
+    async def create_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.create_org_auth_token(*args, **kwargs)
+
+    async def invalidate_org_auth_tokens(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.invalidate_org_auth_tokens(*args, **kwargs)
+
+    # -- Observer delegates --
+
+    async def get_task_v2(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_task_v2(*args, **kwargs)
+
+    async def delete_thoughts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.delete_thoughts(*args, **kwargs)
+
+    async def get_task_v2_by_workflow_run_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_task_v2_by_workflow_run_id(*args, **kwargs)
+
+    async def get_thought(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_thought(*args, **kwargs)
+
+    async def get_thoughts(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_thoughts(*args, **kwargs)
+
+    async def create_task_v2(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.create_task_v2(*args, **kwargs)
+
+    async def create_thought(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.create_thought(*args, **kwargs)
+
+    async def update_thought(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.update_thought(*args, **kwargs)
+
+    async def update_task_v2(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.update_task_v2(*args, **kwargs)
+
+    async def create_workflow_run_block(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.create_workflow_run_block(*args, **kwargs)
+
+    async def delete_workflow_run_blocks(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.delete_workflow_run_blocks(*args, **kwargs)
+
+    async def update_workflow_run_block(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.update_workflow_run_block(*args, **kwargs)
+
+    async def get_workflow_run_block(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_workflow_run_block(*args, **kwargs)
+
+    async def get_workflow_run_block_by_task_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_workflow_run_block_by_task_id(*args, **kwargs)
+
+    async def get_workflow_run_blocks(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.observer.get_workflow_run_blocks(*args, **kwargs)
+
+    # -- Debug delegates --
+
+    async def get_debug_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.get_debug_session(*args, **kwargs)
+
+    async def get_latest_block_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.get_latest_block_run(*args, **kwargs)
+
+    async def get_latest_completed_block_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.get_latest_completed_block_run(*args, **kwargs)
+
+    async def create_block_run(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.create_block_run(*args, **kwargs)
+
+    async def get_latest_debug_session_for_user(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.get_latest_debug_session_for_user(*args, **kwargs)
+
+    async def get_debug_session_by_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.get_debug_session_by_id(*args, **kwargs)
+
+    async def get_workflow_runs_by_debug_session_id(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.get_workflow_runs_by_debug_session_id(*args, **kwargs)
+
+    async def complete_debug_sessions(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.complete_debug_sessions(*args, **kwargs)
+
+    async def create_debug_session(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.debug.create_debug_session(*args, **kwargs)
+
+    # -- NEW delegate methods (missing from branch) --
+
+    async def get_artifact_by_id_no_org(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.artifacts.get_artifact_by_id_no_org(*args, **kwargs)
+
+    async def replace_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
+        return await self.organizations.replace_org_auth_token(*args, **kwargs)
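One payoff of composing repositories instead of inheriting mixins is that each repository can be constructed and exercised in isolation. A hedged sketch, not from this diff: it assumes `_SessionFactory` resolves to SQLAlchemy's `async_sessionmaker` and uses an illustrative in-memory database.

```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from skyvern.forge.sdk.db.repositories.tasks import TasksRepository

# Hypothetical test wiring; the sqlite URL and expire_on_commit choice
# are illustrative, not taken from this diff.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
session_factory = async_sessionmaker(engine, expire_on_commit=False)

# No AgentDB needed: the repository only wants a session factory.
tasks_repo = TasksRepository(session_factory, debug_enabled=True)
```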
+ """ + + def __init__( + self, + session_factory: _SessionFactory, + debug_enabled: bool = False, + is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None, + ) -> None: + self.Session = session_factory + self.debug_enabled = debug_enabled + self._is_retryable_error_fn = is_retryable_error_fn + + def is_retryable_error(self, error: SQLAlchemyError) -> bool: + if self._is_retryable_error_fn: + return self._is_retryable_error_fn(error) + return False diff --git a/skyvern/forge/sdk/db/exceptions.py b/skyvern/forge/sdk/db/exceptions.py index c901a5c9f..e1d9bbbb9 100644 --- a/skyvern/forge/sdk/db/exceptions.py +++ b/skyvern/forge/sdk/db/exceptions.py @@ -1,2 +1,13 @@ class NotFoundError(Exception): pass + + +class ScheduleLimitExceededError(Exception): + """Raised when attempting to create a schedule that would exceed the per-workflow limit.""" + + def __init__(self, organization_id: str, workflow_permanent_id: str, current_count: int, max_allowed: int): + self.organization_id = organization_id + self.workflow_permanent_id = workflow_permanent_id + self.current_count = current_count + self.max_allowed = max_allowed + super().__init__(f"Schedule limit {max_allowed} reached (current: {current_count})") diff --git a/skyvern/forge/sdk/db/mixins/__init__.py b/skyvern/forge/sdk/db/mixins/__init__.py index 37b163b9f..cdff30a18 100644 --- a/skyvern/forge/sdk/db/mixins/__init__.py +++ b/skyvern/forge/sdk/db/mixins/__init__.py @@ -1,3 +1,8 @@ +# DEPRECATED: These mixin classes have been superseded by the repository pattern +# in skyvern/forge/sdk/db/repositories/. AgentDB no longer inherits from these +# mixins — it composes repository instances instead. These files are retained +# temporarily to avoid breaking any downstream code that may import from here. +# TODO(SKY-62): Remove after 2026-Q2 once all imports are verified migrated. 
diff --git a/skyvern/forge/sdk/db/mixins/__init__.py b/skyvern/forge/sdk/db/mixins/__init__.py
index 37b163b9f..cdff30a18 100644
--- a/skyvern/forge/sdk/db/mixins/__init__.py
+++ b/skyvern/forge/sdk/db/mixins/__init__.py
@@ -1,3 +1,8 @@
+# DEPRECATED: These mixin classes have been superseded by the repository pattern
+# in skyvern/forge/sdk/db/repositories/. AgentDB no longer inherits from these
+# mixins — it composes repository instances instead. These files are retained
+# temporarily to avoid breaking any downstream code that may import from here.
+# TODO(SKY-62): Remove after 2026-Q2 once all imports are verified migrated.
 from skyvern.forge.sdk.db.mixins.artifacts import ArtifactsMixin
 from skyvern.forge.sdk.db.mixins.base import BaseAlchemyDB, read_retry
 from skyvern.forge.sdk.db.mixins.browser_sessions import BrowserSessionsMixin
diff --git a/skyvern/forge/sdk/db/mixins/schedules.py b/skyvern/forge/sdk/db/mixins/schedules.py
index d7db4cb0e..fcc382edd 100644
--- a/skyvern/forge/sdk/db/mixins/schedules.py
+++ b/skyvern/forge/sdk/db/mixins/schedules.py
@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any
 import structlog
 from sqlalchemy import func, or_, select, text, update
 
-from skyvern.forge.sdk.db._error_handling import db_operation, register_passthrough_exception
+from skyvern.forge.sdk.db._error_handling import db_operation
+from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError
 from skyvern.forge.sdk.db.models import (
     WorkflowModel,
     WorkflowRunModel,
@@ -26,20 +27,6 @@ LOG = structlog.get_logger()
 _UNSET = object()
 
-class ScheduleLimitExceededError(Exception):
-    """Raised when attempting to create a schedule that would exceed the per-workflow limit."""
-
-    def __init__(self, organization_id: str, workflow_permanent_id: str, current_count: int, max_allowed: int):
-        self.organization_id = organization_id
-        self.workflow_permanent_id = workflow_permanent_id
-        self.current_count = current_count
-        self.max_allowed = max_allowed
-        super().__init__(f"Schedule limit {max_allowed} reached (current: {current_count})")
-
-
-register_passthrough_exception(ScheduleLimitExceededError)
-
-
 class SchedulesMixin:
     """Database operations for workflow schedules."""
diff --git a/skyvern/forge/sdk/db/protocols.py b/skyvern/forge/sdk/db/protocols.py
new file mode 100644
index 000000000..43fedb192
--- /dev/null
+++ b/skyvern/forge/sdk/db/protocols.py
@@ -0,0 +1,42 @@
+"""Typed Protocol contracts for cross-repository dependencies."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Protocol, runtime_checkable
+
+if TYPE_CHECKING:
+    from skyvern.forge.sdk.schemas.tasks import Task
+    from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter
+    from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
+
+
+@runtime_checkable
+class TaskReader(Protocol):
+    async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None: ...
+    async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]: ...
+
+
+@runtime_checkable
+class WorkflowReader(Protocol):
+    async def get_workflow_by_permanent_id(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str | None = None,
+        version: int | None = None,
+        ignore_version: int | None = None,
+        filter_deleted: bool = True,
+    ) -> Workflow | None: ...
+
+
+@runtime_checkable
+class WorkflowParameterReader(Protocol):
+    async def get_workflow_parameter(
+        self,
+        workflow_parameter_id: str,
+        organization_id: str | None = None,
+    ) -> WorkflowParameter | None: ...
+
+
+@runtime_checkable
+class RunReader(Protocol):
+    async def get_run(self, run_id: str, organization_id: str | None = None) -> WorkflowRun | None: ...
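Since the protocols are `@runtime_checkable`, any object whose methods match by name satisfies them, which makes test doubles cheap. A sketch with a hypothetical stub, not part of this diff:

```python
from skyvern.forge.sdk.db.protocols import RunReader

class StubRunReader:
    """Hypothetical test double standing in for WorkflowRunsRepository."""

    async def get_run(self, run_id: str, organization_id: str | None = None):
        return None  # pretend the run does not exist

# runtime_checkable isinstance checks method presence only, not signatures.
assert isinstance(StubRunReader(), RunReader)
```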
diff --git a/skyvern/forge/sdk/db/repositories/__init__.py b/skyvern/forge/sdk/db/repositories/__init__.py
new file mode 100644
index 000000000..b2f630134
--- /dev/null
+++ b/skyvern/forge/sdk/db/repositories/__init__.py
@@ -0,0 +1,4 @@
+"""Domain-specific repository classes extracted from AgentDB mixins."""
+
+# Sentinel for distinguishing "not passed" from "passed as None" in update methods.
+_UNSET = object()
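The sentinel comment describes a common pattern: a unique object as the default lets an update method tell "argument omitted" apart from "argument explicitly None". A minimal self-contained sketch; the helper is illustrative, not from this diff:

```python
_UNSET = object()  # module-level sentinel, as in repositories/__init__.py

def build_update_values(description: object = _UNSET) -> dict:
    """Include a column in the UPDATE only if the caller actually passed it."""
    values: dict = {}
    if description is not _UNSET:
        # None is a real value here: it means "clear the column".
        values["description"] = description
    return values

assert build_update_values() == {}
assert build_update_values(description=None) == {"description": None}
```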
diff --git a/skyvern/forge/sdk/db/repositories/artifacts.py b/skyvern/forge/sdk/db/repositories/artifacts.py
new file mode 100644
index 000000000..488723b86
--- /dev/null
+++ b/skyvern/forge/sdk/db/repositories/artifacts.py
@@ -0,0 +1,452 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable
+
+import structlog
+from sqlalchemy import and_, delete, or_, select, update
+from sqlalchemy.exc import SQLAlchemyError
+
+from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
+from skyvern.forge.sdk.db._error_handling import db_operation
+from skyvern.forge.sdk.db.base_repository import BaseRepository
+from skyvern.forge.sdk.db.models import ActionModel, ArtifactModel
+from skyvern.forge.sdk.db.protocols import RunReader
+from skyvern.forge.sdk.db.utils import convert_to_artifact
+
+if TYPE_CHECKING:
+    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
+
+LOG = structlog.get_logger()
+
+
+class ArtifactsRepository(BaseRepository):
+    """Database operations for artifact management."""
+
+    def __init__(
+        self,
+        session_factory: _SessionFactory,
+        debug_enabled: bool = False,
+        is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
+        run_reader: RunReader | None = None,
+    ) -> None:
+        super().__init__(session_factory, debug_enabled, is_retryable_error_fn)
+        self._run_reader = run_reader
+
+    @db_operation("create_artifact")
+    async def create_artifact(
+        self,
+        artifact_id: str,
+        artifact_type: str,
+        uri: str,
+        organization_id: str,
+        step_id: str | None = None,
+        task_id: str | None = None,
+        workflow_run_id: str | None = None,
+        workflow_run_block_id: str | None = None,
+        task_v2_id: str | None = None,
+        run_id: str | None = None,
+        thought_id: str | None = None,
+        ai_suggestion_id: str | None = None,
+    ) -> Artifact:
+        async with self.Session() as session:
+            new_artifact = ArtifactModel(
+                artifact_id=artifact_id,
+                artifact_type=artifact_type,
+                uri=uri,
+                task_id=task_id,
+                step_id=step_id,
+                workflow_run_id=workflow_run_id,
+                workflow_run_block_id=workflow_run_block_id,
+                observer_cruise_id=task_v2_id,
+                observer_thought_id=thought_id,
+                run_id=run_id,
+                ai_suggestion_id=ai_suggestion_id,
+                organization_id=organization_id,
+            )
+            session.add(new_artifact)
+            await session.commit()
+            await session.refresh(new_artifact)
+            return convert_to_artifact(new_artifact, self.debug_enabled)
+
+    @db_operation("bulk_create_artifacts")
+    async def bulk_create_artifacts(
+        self,
+        artifact_models: list[ArtifactModel],
+    ) -> list[Artifact]:
+        """
+        Bulk create multiple artifacts in a single database transaction.
+
+        Args:
+            artifact_models: List of ArtifactModel instances to insert
+
+        Returns:
+            List of created Artifact objects
+        """
+        if not artifact_models:
+            return []
+
+        async with self.Session() as session:
+            session.add_all(artifact_models)
+            await session.commit()
+
+            # Refresh all artifacts to get their created_at and modified_at values
+            for artifact in artifact_models:
+                await session.refresh(artifact)
+
+            return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifact_models]
+
+    @db_operation("get_artifacts_for_task_v2")
+    async def get_artifacts_for_task_v2(
+        self,
+        task_v2_id: str,
+        organization_id: str | None = None,
+        artifact_types: list[ArtifactType] | None = None,
+    ) -> list[Artifact]:
+        async with self.Session() as session:
+            query = (
+                select(ArtifactModel)
+                .filter_by(observer_cruise_id=task_v2_id)
+                .filter_by(organization_id=organization_id)
+            )
+            if artifact_types:
+                query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
+
+            query = query.order_by(ArtifactModel.created_at)
+            if artifacts := (await session.scalars(query)).all():
+                return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
+            else:
+                return []
+
+    @db_operation("get_artifacts_for_task_step")
+    async def get_artifacts_for_task_step(
+        self,
+        task_id: str,
+        step_id: str,
+        organization_id: str | None = None,
+    ) -> list[Artifact]:
+        async with self.Session() as session:
+            if artifacts := (
+                await session.scalars(
+                    select(ArtifactModel)
+                    .filter_by(task_id=task_id)
+                    .filter_by(step_id=step_id)
+                    .filter_by(organization_id=organization_id)
+                    .order_by(ArtifactModel.created_at)
+                )
+            ).all():
+                return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
+            else:
+                return []
+
+    @db_operation("get_artifacts_for_run")
+    async def get_artifacts_for_run(
+        self,
+        run_id: str,
+        organization_id: str,
+        artifact_types: list[ArtifactType] | None = None,
+        group_by_type: bool = False,
+        sort_by: str = "created_at",
+    ) -> dict[ArtifactType, list[Artifact]] | list[Artifact]:
+        """Return artifacts associated with a run.
+
+        Args:
+            run_id: The ID of the run to get artifacts for
+            organization_id: The ID of the organization that owns the run
+            artifact_types: Optional list of artifact types to filter by
+            group_by_type: If True, returns a dictionary mapping artifact types to lists of artifacts.
+                If False, returns a flat list of artifacts. Defaults to False.
+            sort_by: Field to sort artifacts by. Must be one of: 'created_at', 'step_id', 'task_id'.
+                Defaults to 'created_at'.
+
+        Returns:
+            If group_by_type is True, returns a dictionary mapping artifact types to lists of artifacts.
+            If group_by_type is False, returns a list of artifacts sorted by the specified field.
+
+        Raises:
+            ValueError: If sort_by is not one of the allowed values
+        """
+        allowed_sort_fields = {"created_at", "step_id", "task_id"}
+        if sort_by not in allowed_sort_fields:
+            raise ValueError(f"sort_by must be one of {allowed_sort_fields}")
+        if self._run_reader is None:
+            raise RuntimeError("run_reader dependency not set")
+        run = await self._run_reader.get_run(run_id, organization_id=organization_id)
+        if not run:
+            return []
+
+        async with self.Session() as session:
+            query = select(ArtifactModel).filter_by(organization_id=organization_id)
+
+            query = query.filter_by(run_id=run.run_id)
+
+            if artifact_types:
+                query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
+
+            # Apply sorting
+            if sort_by == "created_at":
+                query = query.order_by(ArtifactModel.created_at)
+            elif sort_by == "step_id":
+                query = query.order_by(ArtifactModel.step_id, ArtifactModel.created_at)
+            elif sort_by == "task_id":
+                query = query.order_by(ArtifactModel.task_id, ArtifactModel.created_at)
+
+            # Execute query and convert to Artifact objects
+            artifacts = [
+                convert_to_artifact(artifact, self.debug_enabled) for artifact in (await session.scalars(query)).all()
+            ]
+
+            # Group artifacts by type if requested
+            if group_by_type:
+                result: dict[ArtifactType, list[Artifact]] = {}
+                for artifact in artifacts:
+                    if artifact.artifact_type not in result:
+                        result[artifact.artifact_type] = []
+                    result[artifact.artifact_type].append(artifact)
+                return result
+
+            return artifacts
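A hedged usage sketch for `get_artifacts_for_run`; the IDs are hypothetical and `repo` is assumed to be an `ArtifactsRepository` wired with a `run_reader`, as `AgentDB` does:

```python
async def collect(repo):
    grouped = await repo.get_artifacts_for_run(
        run_id="run_123",            # hypothetical ID
        organization_id="org_456",   # hypothetical ID
        group_by_type=True,
        sort_by="step_id",
    )
    # group_by_type=True yields {ArtifactType: [Artifact, ...]}; the default
    # (False) yields a flat list ordered by (step_id, created_at) here.
    return grouped
```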
+ """ + async with self.Session() as session: + if artifact := (await session.scalars(select(ArtifactModel).filter_by(artifact_id=artifact_id))).first(): + return convert_to_artifact(artifact, self.debug_enabled) + else: + return None + + @db_operation("get_artifacts_by_ids") + async def get_artifacts_by_ids( + self, + artifact_ids: list[str], + organization_id: str, + ) -> list[Artifact]: + if not artifact_ids: + return [] + async with self.Session() as session: + artifacts = ( + await session.scalars( + select(ArtifactModel) + .filter(ArtifactModel.artifact_id.in_(artifact_ids)) + .filter_by(organization_id=organization_id) + ) + ).all() + return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts] + + @db_operation("get_artifacts_by_entity_id") + async def get_artifacts_by_entity_id( + self, + *, + organization_id: str | None, + artifact_type: ArtifactType | None = None, + task_id: str | None = None, + step_id: str | None = None, + workflow_run_id: str | None = None, + workflow_run_block_id: str | None = None, + thought_id: str | None = None, + task_v2_id: str | None = None, + limit: int | None = None, + ) -> list[Artifact]: + async with self.Session() as session: + # Build base query + query = select(ArtifactModel) + + if artifact_type is not None: + query = query.filter_by(artifact_type=artifact_type) + if task_id is not None: + query = query.filter_by(task_id=task_id) + if step_id is not None: + query = query.filter_by(step_id=step_id) + if workflow_run_id is not None: + query = query.filter_by(workflow_run_id=workflow_run_id) + if workflow_run_block_id is not None: + query = query.filter_by(workflow_run_block_id=workflow_run_block_id) + if thought_id is not None: + query = query.filter_by(observer_thought_id=thought_id) + if task_v2_id is not None: + query = query.filter_by(observer_cruise_id=task_v2_id) + # Handle backward compatibility where old artifact rows were stored with organization_id NULL + if organization_id is not None: + query = query.filter( + or_(ArtifactModel.organization_id == organization_id, ArtifactModel.organization_id.is_(None)) + ) + + query = query.order_by(ArtifactModel.created_at.desc()) + + if limit is not None: + query = query.limit(limit) + + artifacts = (await session.scalars(query)).all() + LOG.debug("Artifacts fetched", count=len(artifacts)) + return [convert_to_artifact(a, self.debug_enabled) for a in artifacts] + + @db_operation("get_artifact_by_entity_id") + async def get_artifact_by_entity_id( + self, + *, + artifact_type: ArtifactType, + organization_id: str, + task_id: str | None = None, + step_id: str | None = None, + workflow_run_id: str | None = None, + workflow_run_block_id: str | None = None, + thought_id: str | None = None, + task_v2_id: str | None = None, + ) -> Artifact | None: + artifacts = await self.get_artifacts_by_entity_id( + organization_id=organization_id, + artifact_type=artifact_type, + task_id=task_id, + step_id=step_id, + workflow_run_id=workflow_run_id, + workflow_run_block_id=workflow_run_block_id, + thought_id=thought_id, + task_v2_id=task_v2_id, + limit=1, + ) + return artifacts[0] if artifacts else None + + @db_operation("get_artifact") + async def get_artifact( + self, + task_id: str, + step_id: str, + artifact_type: ArtifactType, + organization_id: str | None = None, + ) -> Artifact | None: + async with self.Session() as session: + artifact = ( + await session.scalars( + select(ArtifactModel) + .filter_by(task_id=task_id) + .filter_by(step_id=step_id) + 
.filter_by(organization_id=organization_id) + .filter_by(artifact_type=artifact_type) + .order_by(ArtifactModel.created_at.desc()) + ) + ).first() + if artifact: + return convert_to_artifact(artifact, self.debug_enabled) + return None + + @db_operation("get_artifact_for_run") + async def get_artifact_for_run( + self, + run_id: str, + artifact_type: ArtifactType, + organization_id: str | None = None, + ) -> Artifact | None: + async with self.Session() as session: + artifact = ( + await session.scalars( + select(ArtifactModel) + .filter(ArtifactModel.run_id == run_id) + .filter(ArtifactModel.artifact_type == artifact_type) + .filter(ArtifactModel.organization_id == organization_id) + .order_by(ArtifactModel.created_at.desc()) + ) + ).first() + if artifact: + return convert_to_artifact(artifact, self.debug_enabled) + return None + + @db_operation("get_latest_artifact") + async def get_latest_artifact( + self, + task_id: str, + step_id: str | None = None, + artifact_types: list[ArtifactType] | None = None, + organization_id: str | None = None, + ) -> Artifact | None: + artifacts = await self.get_latest_n_artifacts( + task_id=task_id, + step_id=step_id, + artifact_types=artifact_types, + organization_id=organization_id, + n=1, + ) + if artifacts: + return artifacts[0] + return None + + @db_operation("get_latest_n_artifacts") + async def get_latest_n_artifacts( + self, + task_id: str, + step_id: str | None = None, + artifact_types: list[ArtifactType] | None = None, + organization_id: str | None = None, + n: int = 1, + ) -> list[Artifact] | None: + async with self.Session() as session: + artifact_query = select(ArtifactModel).filter_by(task_id=task_id) + if organization_id: + artifact_query = artifact_query.filter_by(organization_id=organization_id) + if step_id: + artifact_query = artifact_query.filter_by(step_id=step_id) + if artifact_types: + artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types)) + + artifacts = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).fetchmany(n) + if artifacts: + return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts] + return None + + @db_operation("delete_task_artifacts") + async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None: + async with self.Session() as session: + # delete artifacts by filtering organization_id and task_id + stmt = delete(ArtifactModel).where( + and_( + ArtifactModel.organization_id == organization_id, + ArtifactModel.task_id == task_id, + ) + ) + await session.execute(stmt) + await session.commit() + + @db_operation("delete_task_v2_artifacts") + async def delete_task_v2_artifacts(self, task_v2_id: str, organization_id: str | None = None) -> None: + async with self.Session() as session: + stmt = delete(ArtifactModel).where( + and_( + ArtifactModel.observer_cruise_id == task_v2_id, + ArtifactModel.organization_id == organization_id, + ) + ) + await session.execute(stmt) + await session.commit() + + @db_operation("update_action_screenshot_artifact_id") + async def update_action_screenshot_artifact_id( + self, *, organization_id: str, action_id: str, screenshot_artifact_id: str + ) -> None: + async with self.Session() as session: + await session.execute( + update(ActionModel) + .where(ActionModel.action_id == action_id, ActionModel.organization_id == organization_id) + .values(screenshot_artifact_id=screenshot_artifact_id) + ) + await session.commit() diff --git a/skyvern/forge/sdk/db/repositories/browser_sessions.py 
b/skyvern/forge/sdk/db/repositories/browser_sessions.py new file mode 100644 index 000000000..7af27782c --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/browser_sessions.py @@ -0,0 +1,472 @@ +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone + +import structlog +from sqlalchemy import asc, case, select + +from skyvern.exceptions import BrowserProfileNotFound +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_alchemy_db import read_retry +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import ( + BrowserProfileModel, + DebugSessionModel, + PersistentBrowserSessionModel, +) +from skyvern.forge.sdk.db.utils import serialize_proxy_location +from skyvern.forge.sdk.schemas.browser_profiles import BrowserProfile +from skyvern.forge.sdk.schemas.debug_sessions import DebugSession +from skyvern.forge.sdk.schemas.persistent_browser_sessions import ( + Extensions, + PersistentBrowserSession, + PersistentBrowserType, +) +from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput + +LOG = structlog.get_logger() + + +class BrowserSessionsRepository(BaseRepository): + """Database operations for browser profiles and persistent browser sessions.""" + + @db_operation("create_browser_profile") + async def create_browser_profile( + self, + organization_id: str, + name: str, + description: str | None = None, + ) -> BrowserProfile: + async with self.Session() as session: + browser_profile = BrowserProfileModel( + organization_id=organization_id, + name=name, + description=description, + ) + session.add(browser_profile) + await session.commit() + await session.refresh(browser_profile) + return BrowserProfile.model_validate(browser_profile) + + @db_operation("get_browser_profile") + async def get_browser_profile( + self, + profile_id: str, + organization_id: str, + include_deleted: bool = False, + ) -> BrowserProfile | None: + async with self.Session() as session: + query = ( + select(BrowserProfileModel) + .filter_by(browser_profile_id=profile_id) + .filter_by(organization_id=organization_id) + ) + if not include_deleted: + query = query.filter(BrowserProfileModel.deleted_at.is_(None)) + + browser_profile = (await session.scalars(query)).first() + if not browser_profile: + return None + return BrowserProfile.model_validate(browser_profile) + + @db_operation("list_browser_profiles") + async def list_browser_profiles( + self, + organization_id: str, + include_deleted: bool = False, + ) -> list[BrowserProfile]: + async with self.Session() as session: + query = select(BrowserProfileModel).filter_by(organization_id=organization_id) + if not include_deleted: + query = query.filter(BrowserProfileModel.deleted_at.is_(None)) + browser_profiles = await session.scalars(query.order_by(asc(BrowserProfileModel.created_at))) + return [BrowserProfile.model_validate(profile) for profile in browser_profiles.all()] + + @db_operation("delete_browser_profile") + async def delete_browser_profile( + self, + profile_id: str, + organization_id: str, + ) -> None: + async with self.Session() as session: + query = ( + select(BrowserProfileModel) + .filter_by(browser_profile_id=profile_id) + .filter_by(organization_id=organization_id) + .filter(BrowserProfileModel.deleted_at.is_(None)) + ) + browser_profile = (await session.scalars(query)).first() + if not browser_profile: + raise BrowserProfileNotFound(profile_id=profile_id, 
organization_id=organization_id) + browser_profile.deleted_at = datetime.now(timezone.utc) + await session.commit() + + @db_operation("get_active_persistent_browser_sessions") + async def get_active_persistent_browser_sessions( + self, + organization_id: str, + active_hours: int = 24, + ) -> list[PersistentBrowserSession]: + """Get all active persistent browser sessions for an organization.""" + async with self.Session() as session: + result = await session.execute( + select(PersistentBrowserSessionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + .filter_by(completed_at=None) + .filter( + PersistentBrowserSessionModel.created_at + > datetime.now(timezone.utc) - timedelta(hours=active_hours) + ) + ) + sessions = result.scalars().all() + return [PersistentBrowserSession.model_validate(session) for session in sessions] + + @db_operation("get_persistent_browser_sessions_history") + async def get_persistent_browser_sessions_history( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + lookback_hours: int = 24 * 7, + ) -> list[PersistentBrowserSession]: + """Get persistent browser sessions history for an organization.""" + async with self.Session() as session: + open_first = case( + ( + PersistentBrowserSessionModel.status == "running", + 0, # open + ), + else_=1, # not open + ) + + result = await session.execute( + select(PersistentBrowserSessionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + .filter( + PersistentBrowserSessionModel.created_at + > (datetime.now(timezone.utc) - timedelta(hours=lookback_hours)) + ) + .order_by( + open_first.asc(), # open sessions first + PersistentBrowserSessionModel.created_at.desc(), # then newest within each group + ) + .offset((page - 1) * page_size) + .limit(page_size) + ) + sessions = result.scalars().all() + return [PersistentBrowserSession.model_validate(session) for session in sessions] + + @read_retry() + @db_operation("get_persistent_browser_session_by_runnable_id", log_errors=False) + async def get_persistent_browser_session_by_runnable_id( + self, runnable_id: str, organization_id: str | None = None + ) -> PersistentBrowserSession | None: + """Get a specific persistent browser session.""" + async with self.Session() as session: + query = ( + select(PersistentBrowserSessionModel) + .filter_by(runnable_id=runnable_id) + .filter_by(deleted_at=None) + .filter_by(completed_at=None) + ) + if organization_id: + query = query.filter_by(organization_id=organization_id) + persistent_browser_session = (await session.scalars(query)).first() + if persistent_browser_session: + return PersistentBrowserSession.model_validate(persistent_browser_session) + return None + + @db_operation("get_persistent_browser_session") + async def get_persistent_browser_session( + self, + session_id: str, + organization_id: str | None = None, + ) -> PersistentBrowserSession | None: + """Get a specific persistent browser session.""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + return PersistentBrowserSession.model_validate(persistent_browser_session) + return None + + @db_operation("create_persistent_browser_session") + async def create_persistent_browser_session( + self, + organization_id: str, + runnable_type: str | None 
= None, + runnable_id: str | None = None, + timeout_minutes: int | None = None, + proxy_location: ProxyLocationInput = ProxyLocation.RESIDENTIAL, + extensions: list[Extensions] | None = None, + browser_type: PersistentBrowserType | None = None, + browser_profile_id: str | None = None, + ) -> PersistentBrowserSession: + """Create a new persistent browser session.""" + extensions_str: list[str] | None = ( + [extension.value for extension in extensions] if extensions is not None else None + ) + async with self.Session() as session: + browser_session = PersistentBrowserSessionModel( + organization_id=organization_id, + runnable_type=runnable_type, + runnable_id=runnable_id, + timeout_minutes=timeout_minutes, + proxy_location=serialize_proxy_location(proxy_location), + extensions=extensions_str, + browser_type=browser_type.value if browser_type else None, + browser_profile_id=browser_profile_id, + ) + session.add(browser_session) + await session.commit() + await session.refresh(browser_session) + return PersistentBrowserSession.model_validate(browser_session) + + @db_operation("update_persistent_browser_session") + async def update_persistent_browser_session( + self, + browser_session_id: str, + *, + status: str | None = None, + timeout_minutes: int | None = None, + organization_id: str | None = None, + completed_at: datetime | None = None, + ) -> PersistentBrowserSession: + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=browser_session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if not persistent_browser_session: + raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found") + + if status: + persistent_browser_session.status = status + if timeout_minutes: + persistent_browser_session.timeout_minutes = timeout_minutes + if completed_at: + persistent_browser_session.completed_at = completed_at + + await session.commit() + await session.refresh(persistent_browser_session) + return PersistentBrowserSession.model_validate(persistent_browser_session) + + @db_operation("set_persistent_browser_session_browser_address") + async def set_persistent_browser_session_browser_address( + self, + browser_session_id: str, + browser_address: str | None, + ip_address: str | None, + ecs_task_arn: str | None, + organization_id: str | None = None, + ) -> None: + """Set the browser address for a persistent browser session.""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=browser_session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + if browser_address: + persistent_browser_session.browser_address = browser_address + # once the address is set, the session is started + persistent_browser_session.started_at = datetime.now(timezone.utc) + if ip_address: + persistent_browser_session.ip_address = ip_address + if ecs_task_arn: + persistent_browser_session.ecs_task_arn = ecs_task_arn + await session.commit() + await session.refresh(persistent_browser_session) + else: + raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found") + + @db_operation("update_persistent_browser_session_compute_cost") + async def update_persistent_browser_session_compute_cost( + self, + session_id: str, + 
organization_id: str, + instance_type: str, + vcpu_millicores: int, + memory_mb: int, + duration_ms: int, + compute_cost: float, + ) -> None: + """Update the compute cost fields for a persistent browser session""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + persistent_browser_session.instance_type = instance_type + persistent_browser_session.vcpu_millicores = vcpu_millicores + persistent_browser_session.memory_mb = memory_mb + persistent_browser_session.duration_ms = duration_ms + persistent_browser_session.compute_cost = compute_cost + await session.commit() + await session.refresh(persistent_browser_session) + else: + raise NotFoundError(f"PersistentBrowserSession {session_id} not found") + + @db_operation("mark_persistent_browser_session_deleted") + async def mark_persistent_browser_session_deleted(self, session_id: str, organization_id: str) -> None: + """Mark a persistent browser session as deleted.""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if persistent_browser_session: + persistent_browser_session.deleted_at = datetime.now(timezone.utc) + await session.commit() + await session.refresh(persistent_browser_session) + else: + raise NotFoundError(f"PersistentBrowserSession {session_id} not found") + + @db_operation("occupy_persistent_browser_session") + async def occupy_persistent_browser_session( + self, session_id: str, runnable_type: str, runnable_id: str, organization_id: str + ) -> None: + """Occupy a specific persistent browser session.""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + persistent_browser_session.runnable_type = runnable_type + persistent_browser_session.runnable_id = runnable_id + await session.commit() + await session.refresh(persistent_browser_session) + else: + raise NotFoundError(f"PersistentBrowserSession {session_id} not found") + + @db_operation("release_persistent_browser_session") + async def release_persistent_browser_session( + self, + session_id: str, + organization_id: str, + ) -> PersistentBrowserSession: + """Release a specific persistent browser session.""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + persistent_browser_session.runnable_type = None + persistent_browser_session.runnable_id = None + await session.commit() + await session.refresh(persistent_browser_session) + return PersistentBrowserSession.model_validate(persistent_browser_session) + else: + raise NotFoundError(f"PersistentBrowserSession {session_id} not found") + + @db_operation("close_persistent_browser_session") + async def 
close_persistent_browser_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession: + """Close a specific persistent browser session.""" + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + if persistent_browser_session.completed_at: + return PersistentBrowserSession.model_validate(persistent_browser_session) + persistent_browser_session.completed_at = datetime.now(timezone.utc) + persistent_browser_session.status = "completed" + await session.commit() + await session.refresh(persistent_browser_session) + return PersistentBrowserSession.model_validate(persistent_browser_session) + raise NotFoundError(f"PersistentBrowserSession {session_id} not found") + + @db_operation("get_all_active_persistent_browser_sessions") + async def get_all_active_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]: + """Get all active persistent browser sessions across all organizations.""" + async with self.Session() as session: + result = await session.execute(select(PersistentBrowserSessionModel).filter_by(deleted_at=None)) + return result.scalars().all() + + @db_operation("archive_browser_session_address") + async def archive_browser_session_address(self, session_id: str, organization_id: str) -> None: + """Suffix browser_address with a unique tag so the unique constraint + no longer blocks new sessions that reuse the same local address.""" + async with self.Session() as session: + row = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + + if not row or not row.browser_address: + return + if "::closed::" in row.browser_address: + return + + row.browser_address = f"{row.browser_address}::closed::{uuid.uuid4().hex}" + await session.commit() + + @db_operation("get_uncompleted_persistent_browser_sessions") + async def get_uncompleted_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]: + """Get all browser sessions that have not been completed or deleted.""" + async with self.Session() as session: + result = await session.execute( + select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None) + ) + return result.scalars().all() + + @db_operation("get_debug_session_by_browser_session_id") + async def get_debug_session_by_browser_session_id( + self, + browser_session_id: str, + organization_id: str, + ) -> DebugSession | None: + async with self.Session() as session: + query = ( + select(DebugSessionModel) + .filter_by(browser_session_id=browser_session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + model = (await session.scalars(query)).first() + return DebugSession.model_validate(model) if model else None diff --git a/skyvern/forge/sdk/db/repositories/credentials.py b/skyvern/forge/sdk/db/repositories/credentials.py new file mode 100644 index 000000000..9eb0ecf22 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/credentials.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +from sqlalchemy import select + +from skyvern.forge.sdk.db._error_handling import db_operation +from 
skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import CredentialModel, OrganizationBitwardenCollectionModel +from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType, CredentialVaultType +from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection + +from . import _UNSET + + +class CredentialRepository(BaseRepository): + """Database operations for credential and Bitwarden collection management.""" + + @db_operation("create_credential") + async def create_credential( + self, + organization_id: str, + name: str, + vault_type: CredentialVaultType, + item_id: str, + credential_type: CredentialType, + username: str | None, + totp_type: str, + card_last4: str | None, + card_brand: str | None, + totp_identifier: str | None = None, + secret_label: str | None = None, + ) -> Credential: + async with self.Session() as session: + credential = CredentialModel( + organization_id=organization_id, + name=name, + vault_type=vault_type, + item_id=item_id, + credential_type=credential_type, + username=username, + totp_type=totp_type, + totp_identifier=totp_identifier, + card_last4=card_last4, + card_brand=card_brand, + secret_label=secret_label, + ) + session.add(credential) + await session.commit() + await session.refresh(credential) + return Credential.model_validate(credential) + + @db_operation("get_credential") + async def get_credential(self, credential_id: str, organization_id: str) -> Credential | None: + async with self.Session() as session: + credential = ( + await session.scalars( + select(CredentialModel) + .filter_by(credential_id=credential_id) + .filter_by(organization_id=organization_id) + .filter(CredentialModel.deleted_at.is_(None)) + ) + ).first() + if credential: + return Credential.model_validate(credential) + return None + + @db_operation("get_credentials") + async def get_credentials( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + vault_type: str | None = None, + ) -> list[Credential]: + async with self.Session() as session: + query = ( + select(CredentialModel) + .filter_by(organization_id=organization_id) + .filter(CredentialModel.deleted_at.is_(None)) + ) + if vault_type is not None: + query = query.filter(CredentialModel.vault_type == vault_type) + credentials = ( + await session.scalars( + query.order_by(CredentialModel.created_at.desc()).offset((page - 1) * page_size).limit(page_size) + ) + ).all() + return [Credential.model_validate(credential) for credential in credentials] + + @db_operation("update_credential") + async def update_credential( + self, + credential_id: str, + organization_id: str, + name: str | None = None, + browser_profile_id: str | None | object = _UNSET, + tested_url: str | None | object = _UNSET, + ) -> Credential: + async with self.Session() as session: + credential = ( + await session.scalars( + select(CredentialModel) + .filter_by(credential_id=credential_id) + .filter_by(organization_id=organization_id) + .filter(CredentialModel.deleted_at.is_(None)) + ) + ).first() + if not credential: + raise NotFoundError(f"Credential {credential_id} not found") + if name is not None: + credential.name = name + if browser_profile_id is not _UNSET: + credential.browser_profile_id = browser_profile_id + if tested_url is not _UNSET: + credential.tested_url = tested_url + await session.commit() + await session.refresh(credential) + return Credential.model_validate(credential) + 
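+ # Illustrative note (not part of this diff): the `_UNSET` sentinel lets callers of
+ # update_credential distinguish "leave the field unchanged" from "set it to None",
+ # e.g. (placeholder IDs):
+ #     await db.credentials.update_credential(cid, org_id, tested_url=None)  # clears tested_url
+ #     await db.credentials.update_credential(cid, org_id, name="renamed")   # tested_url untouched
+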
+ @db_operation("update_credential_vault_data") + async def update_credential_vault_data( + self, + credential_id: str, + organization_id: str, + item_id: str, + name: str, + credential_type: CredentialType, + username: str | None = None, + totp_type: str = "none", + totp_identifier: str | None = None, + card_last4: str | None = None, + card_brand: str | None = None, + secret_label: str | None = None, + ) -> Credential: + async with self.Session() as session: + credential = ( + await session.scalars( + select(CredentialModel) + .filter_by(credential_id=credential_id) + .filter_by(organization_id=organization_id) + .filter(CredentialModel.deleted_at.is_(None)) + .with_for_update() + ) + ).first() + if not credential: + raise NotFoundError(f"Credential {credential_id} not found") + credential.item_id = item_id + credential.name = name + credential.credential_type = credential_type + credential.username = username + credential.totp_type = totp_type + credential.totp_identifier = totp_identifier + credential.card_last4 = card_last4 + credential.card_brand = card_brand + credential.secret_label = secret_label + await session.commit() + await session.refresh(credential) + return Credential.model_validate(credential) + + @db_operation("delete_credential") + async def delete_credential(self, credential_id: str, organization_id: str) -> None: + async with self.Session() as session: + credential = ( + await session.scalars( + select(CredentialModel) + .filter_by(credential_id=credential_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if not credential: + raise NotFoundError(f"Credential {credential_id} not found") + credential.deleted_at = datetime.now(timezone.utc) + await session.commit() + await session.refresh(credential) + return None + + @db_operation("create_organization_bitwarden_collection") + async def create_organization_bitwarden_collection( + self, + organization_id: str, + collection_id: str, + ) -> OrganizationBitwardenCollection: + async with self.Session() as session: + organization_bitwarden_collection = OrganizationBitwardenCollectionModel( + organization_id=organization_id, collection_id=collection_id + ) + session.add(organization_bitwarden_collection) + await session.commit() + await session.refresh(organization_bitwarden_collection) + return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection) + + @db_operation("get_organization_bitwarden_collection") + async def get_organization_bitwarden_collection( + self, + organization_id: str, + ) -> OrganizationBitwardenCollection | None: + async with self.Session() as session: + organization_bitwarden_collection = ( + await session.scalars( + select(OrganizationBitwardenCollectionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if organization_bitwarden_collection: + return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection) + return None diff --git a/skyvern/forge/sdk/db/repositories/debug.py b/skyvern/forge/sdk/db/repositories/debug.py new file mode 100644 index 000000000..0627a4d86 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/debug.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from sqlalchemy import select + +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.models import ( + BlockRunModel, + DebugSessionModel, + WorkflowRunModel, +) +from skyvern.forge.sdk.schemas.debug_sessions 
import BlockRun, DebugSession, DebugSessionRun +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus + + +class DebugRepository(BaseRepository): + """Database operations for debug sessions and block runs.""" + + @db_operation("get_debug_session") + async def get_debug_session( + self, + *, + organization_id: str, + user_id: str, + workflow_permanent_id: str, + ) -> DebugSession | None: + async with self.Session() as session: + debug_session = ( + await session.scalars( + select(DebugSessionModel) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(user_id=user_id) + .filter_by(deleted_at=None) + .filter_by(status="created") + .order_by(DebugSessionModel.created_at.desc()) + ) + ).first() + + if not debug_session: + return None + + return DebugSession.model_validate(debug_session) + + @db_operation("get_latest_block_run") + async def get_latest_block_run( + self, + *, + organization_id: str, + user_id: str, + block_label: str, + ) -> BlockRun | None: + async with self.Session() as session: + query = ( + select(BlockRunModel) + .filter_by(organization_id=organization_id) + .filter_by(user_id=user_id) + .filter_by(block_label=block_label) + .order_by(BlockRunModel.created_at.desc()) + ) + + model = (await session.scalars(query)).first() + + return BlockRun.model_validate(model) if model else None + + @db_operation("get_latest_completed_block_run") + async def get_latest_completed_block_run( + self, + *, + organization_id: str, + user_id: str, + block_label: str, + workflow_permanent_id: str, + ) -> BlockRun | None: + async with self.Session() as session: + query = ( + select(BlockRunModel) + .join(WorkflowRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .filter(BlockRunModel.organization_id == organization_id) + .filter(BlockRunModel.user_id == user_id) + .filter(BlockRunModel.block_label == block_label) + .filter(WorkflowRunModel.status == WorkflowRunStatus.completed) + .filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id) + .order_by(BlockRunModel.created_at.desc()) + ) + + model = (await session.scalars(query)).first() + + return BlockRun.model_validate(model) if model else None + + @db_operation("create_block_run") + async def create_block_run( + self, + *, + organization_id: str, + user_id: str, + block_label: str, + output_parameter_id: str, + workflow_run_id: str, + ) -> None: + async with self.Session() as session: + block_run = BlockRunModel( + organization_id=organization_id, + user_id=user_id, + block_label=block_label, + output_parameter_id=output_parameter_id, + workflow_run_id=workflow_run_id, + ) + + session.add(block_run) + + await session.commit() + + @db_operation("get_latest_debug_session_for_user") + async def get_latest_debug_session_for_user( + self, + *, + organization_id: str, + user_id: str, + workflow_permanent_id: str, + ) -> DebugSession | None: + async with self.Session() as session: + query = ( + select(DebugSessionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + .filter_by(status="created") + .filter_by(user_id=user_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .order_by(DebugSessionModel.created_at.desc()) + ) + + model = (await session.scalars(query)).first() + + return DebugSession.model_validate(model) if model else None + + @db_operation("get_debug_session_by_id") + async def get_debug_session_by_id( + self, + debug_session_id: str, + organization_id: str, + ) -> 
DebugSession | None: + async with self.Session() as session: + query = ( + select(DebugSessionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + .filter_by(debug_session_id=debug_session_id) + ) + + model = (await session.scalars(query)).first() + + return DebugSession.model_validate(model) if model else None + + @db_operation("get_workflow_runs_by_debug_session_id") + async def get_workflow_runs_by_debug_session_id( + self, + debug_session_id: str, + organization_id: str, + ) -> list[DebugSessionRun]: + async with self.Session() as session: + query = ( + select(WorkflowRunModel, BlockRunModel) + .join(BlockRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .filter(WorkflowRunModel.organization_id == organization_id) + .filter(WorkflowRunModel.debug_session_id == debug_session_id) + .order_by(WorkflowRunModel.created_at.desc()) + ) + + results = (await session.execute(query)).all() + + debug_session_runs = [] + for workflow_run, block_run in results: + debug_session_runs.append( + DebugSessionRun( + ai_fallback=workflow_run.ai_fallback, + block_label=block_run.block_label, + browser_session_id=workflow_run.browser_session_id, + code_gen=workflow_run.code_gen, + debug_session_id=workflow_run.debug_session_id, + failure_reason=workflow_run.failure_reason, + output_parameter_id=block_run.output_parameter_id, + run_with=workflow_run.run_with, + script_run_id=workflow_run.script_run.get("script_run_id") if workflow_run.script_run else None, + status=workflow_run.status, + workflow_id=workflow_run.workflow_id, + workflow_permanent_id=workflow_run.workflow_permanent_id, + workflow_run_id=workflow_run.workflow_run_id, + created_at=workflow_run.created_at, + queued_at=workflow_run.queued_at, + started_at=workflow_run.started_at, + finished_at=workflow_run.finished_at, + ) + ) + + return debug_session_runs + + @db_operation("complete_debug_sessions") + async def complete_debug_sessions( + self, + *, + organization_id: str, + user_id: str | None = None, + workflow_permanent_id: str | None = None, + ) -> list[DebugSession]: + async with self.Session() as session: + query = ( + select(DebugSessionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + .filter_by(status="created") + ) + + if user_id: + query = query.filter_by(user_id=user_id) + if workflow_permanent_id: + query = query.filter_by(workflow_permanent_id=workflow_permanent_id) + + models = (await session.scalars(query)).all() + + for model in models: + model.status = "completed" + + debug_sessions = [DebugSession.model_validate(model) for model in models] + + await session.commit() + + return debug_sessions + + @db_operation("create_debug_session") + async def create_debug_session( + self, + *, + browser_session_id: str, + organization_id: str, + user_id: str, + workflow_permanent_id: str, + vnc_streaming_supported: bool, + ) -> DebugSession: + async with self.Session() as session: + debug_session = DebugSessionModel( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + user_id=user_id, + browser_session_id=browser_session_id, + vnc_streaming_supported=vnc_streaming_supported, + status="created", + ) + + session.add(debug_session) + await session.commit() + await session.refresh(debug_session) + + return DebugSession.model_validate(debug_session) diff --git a/skyvern/forge/sdk/db/repositories/folders.py b/skyvern/forge/sdk/db/repositories/folders.py new file mode 100644 index 000000000..f85ac5d63 --- /dev/null +++ 
b/skyvern/forge/sdk/db/repositories/folders.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Callable + +from sqlalchemy import func, or_, select, update +from sqlalchemy.exc import SQLAlchemyError + +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.models import FolderModel, WorkflowModel +from skyvern.forge.sdk.db.protocols import WorkflowReader +from skyvern.forge.sdk.db.utils import convert_to_workflow +from skyvern.forge.sdk.workflow.models.workflow import Workflow + +if TYPE_CHECKING: + from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory + + +class FoldersRepository(BaseRepository): + """Database operations for folder management.""" + + def __init__( + self, + session_factory: _SessionFactory, + debug_enabled: bool = False, + is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None, + workflow_reader: WorkflowReader | None = None, + ) -> None: + super().__init__(session_factory, debug_enabled, is_retryable_error_fn) + self._workflow_reader = workflow_reader + + @db_operation("create_folder") + async def create_folder( + self, + organization_id: str, + title: str, + description: str | None = None, + ) -> FolderModel: + """Create a new folder.""" + async with self.Session() as session: + folder = FolderModel( + organization_id=organization_id, + title=title, + description=description, + ) + session.add(folder) + await session.commit() + await session.refresh(folder) + return folder + + @db_operation("get_folders") + async def get_folders( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + search_query: str | None = None, + ) -> list[FolderModel]: + """Get all folders for an organization with pagination and optional search.""" + async with self.Session() as session: + stmt = ( + select(FolderModel).filter_by(organization_id=organization_id).filter(FolderModel.deleted_at.is_(None)) + ) + + if search_query: + search_pattern = f"%{search_query}%" + stmt = stmt.filter( + or_( + FolderModel.title.ilike(search_pattern), + FolderModel.description.ilike(search_pattern), + ) + ) + + stmt = stmt.order_by(FolderModel.modified_at.desc()) + stmt = stmt.offset((page - 1) * page_size).limit(page_size) + + result = await session.execute(stmt) + return list(result.scalars().all()) + + @db_operation("get_folder") + async def get_folder( + self, + folder_id: str, + organization_id: str, + ) -> FolderModel | None: + """Get a folder by ID.""" + async with self.Session() as session: + stmt = ( + select(FolderModel) + .filter_by(folder_id=folder_id, organization_id=organization_id) + .filter(FolderModel.deleted_at.is_(None)) + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + @db_operation("update_folder") + async def update_folder( + self, + folder_id: str, + organization_id: str, + title: str | None = None, + description: str | None = None, + ) -> FolderModel | None: + """Update a folder's title or description.""" + async with self.Session() as session: + stmt = ( + select(FolderModel) + .filter_by(folder_id=folder_id, organization_id=organization_id) + .filter(FolderModel.deleted_at.is_(None)) + ) + result = await session.execute(stmt) + folder = result.scalar_one_or_none() + if not folder: + return None + + if title is not None: + folder.title = title + if description is not None: + folder.description = description + + folder.modified_at 
= datetime.now(timezone.utc) + await session.commit() + await session.refresh(folder) + return folder + + @db_operation("get_workflow_permanent_ids_in_folder") + async def get_workflow_permanent_ids_in_folder( + self, + folder_id: str, + organization_id: str, + ) -> list[str]: + """Get workflow permanent IDs (latest versions only) in a folder.""" + async with self.Session() as session: + # Subquery to get the latest version for each workflow + subquery = ( + select( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + + # Get workflow_permanent_ids where the latest version is in this folder + stmt = ( + select(WorkflowModel.workflow_permanent_id) + .join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + .where(WorkflowModel.folder_id == folder_id) + ) + result = await session.execute(stmt) + return list(result.scalars().all()) + + @db_operation("soft_delete_folder") + async def soft_delete_folder( + self, + folder_id: str, + organization_id: str, + delete_workflows: bool = False, + ) -> bool: + """Soft delete a folder. Optionally delete all workflows in the folder.""" + async with self.Session() as session: + # Check if folder exists + folder_stmt = ( + select(FolderModel) + .filter_by(folder_id=folder_id, organization_id=organization_id) + .filter(FolderModel.deleted_at.is_(None)) + ) + folder_result = await session.execute(folder_stmt) + folder = folder_result.scalar_one_or_none() + if not folder: + return False + + # If delete_workflows is True, delete all workflows in the folder + if delete_workflows: + # Get workflow permanent IDs in the folder (inline logic) + subquery = ( + select( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + + workflow_permanent_ids_stmt = ( + select(WorkflowModel.workflow_permanent_id) + .join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + .where(WorkflowModel.folder_id == folder_id) + ) + result = await session.execute(workflow_permanent_ids_stmt) + workflow_permanent_ids = list(result.scalars().all()) + + # Soft delete all workflows with these permanent IDs in a single bulk update + if workflow_permanent_ids: + update_workflows_query = ( + update(WorkflowModel) + .where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids)) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .values(deleted_at=datetime.now(timezone.utc)) + ) + await session.execute(update_workflows_query) + else: + # Just remove folder_id from all workflows in this folder + update_workflows_query = ( + update(WorkflowModel) + .where(WorkflowModel.folder_id == folder_id) + .where(WorkflowModel.organization_id == 
organization_id) + .values(folder_id=None, modified_at=datetime.now(timezone.utc)) + ) + await session.execute(update_workflows_query) + + # Soft delete the folder + folder.deleted_at = datetime.now(timezone.utc) + await session.commit() + return True + + @db_operation("get_folder_workflow_count") + async def get_folder_workflow_count( + self, + folder_id: str, + organization_id: str, + ) -> int: + """Get the count of workflows (latest versions only) in a folder.""" + async with self.Session() as session: + # Subquery to get the latest version for each workflow (same pattern as get_workflows_by_organization_id) + subquery = ( + select( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + + # Count workflows where the latest version is in this folder + stmt = ( + select(func.count(WorkflowModel.workflow_permanent_id)) + .join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + .where(WorkflowModel.folder_id == folder_id) + ) + result = await session.execute(stmt) + return result.scalar_one() + + @db_operation("get_folder_workflow_counts_batch") + async def get_folder_workflow_counts_batch( + self, + folder_ids: list[str], + organization_id: str, + ) -> dict[str, int]: + """Get workflow counts for multiple folders in a single query.""" + async with self.Session() as session: + # Subquery to get the latest version for each workflow + subquery = ( + select( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + + # Count workflows grouped by folder_id + stmt = ( + select( + WorkflowModel.folder_id, + func.count(WorkflowModel.workflow_permanent_id).label("count"), + ) + .join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + .where(WorkflowModel.folder_id.in_(folder_ids)) + .group_by(WorkflowModel.folder_id) + ) + result = await session.execute(stmt) + rows = result.all() + + # Convert to dict; folders with no workflows will be absent from the result + return {row.folder_id: row.count for row in rows} + + @db_operation("update_workflow_folder") + async def update_workflow_folder( + self, + workflow_permanent_id: str, + organization_id: str, + folder_id: str | None, + ) -> Workflow | None: + """Update folder assignment for the latest version of a workflow.""" + # Get the latest version of the workflow + if self._workflow_reader is None: + raise RuntimeError("workflow_reader dependency not set") + latest_workflow = await self._workflow_reader.get_workflow_by_permanent_id( + workflow_permanent_id=workflow_permanent_id, + organization_id=organization_id, + ) + + if not latest_workflow: + return None + + async with self.Session() as session: + # Validate folder exists in-org if folder_id is provided + if 
folder_id: + stmt = ( + select(FolderModel.folder_id) + .where(FolderModel.folder_id == folder_id) + .where(FolderModel.organization_id == organization_id) + .where(FolderModel.deleted_at.is_(None)) + ) + if (await session.scalar(stmt)) is None: + raise ValueError(f"Folder {folder_id} not found") + + workflow_model = await session.get(WorkflowModel, latest_workflow.workflow_id) + if workflow_model: + workflow_model.folder_id = folder_id + workflow_model.modified_at = datetime.now(timezone.utc) + + # Update folder's modified_at in the same transaction + if folder_id: + folder_model = await session.get(FolderModel, folder_id) + if folder_model: + folder_model.modified_at = datetime.now(timezone.utc) + + await session.commit() + await session.refresh(workflow_model) + + return convert_to_workflow(workflow_model, self.debug_enabled) + return None diff --git a/skyvern/forge/sdk/db/repositories/observer.py b/skyvern/forge/sdk/db/repositories/observer.py new file mode 100644 index 000000000..11373e9b9 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/observer.py @@ -0,0 +1,588 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Callable + +from sqlalchemy import and_, delete, select +from sqlalchemy.exc import SQLAlchemyError + +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_alchemy_db import read_retry +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import NotFoundError + +if TYPE_CHECKING: + from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory + +from skyvern.forge.sdk.db.models import ( + TaskV2Model, + ThoughtModel, + WorkflowRunBlockModel, +) +from skyvern.forge.sdk.db.protocols import TaskReader +from skyvern.forge.sdk.db.utils import ( + convert_to_task_v2, + convert_to_workflow_run_block, + serialize_proxy_location, +) +from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType +from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock +from skyvern.schemas.runs import ProxyLocationInput, RunEngine +from skyvern.schemas.workflows import BlockStatus, BlockType + + +class ObserverRepository(BaseRepository): + """Database operations for observer tasks (TaskV2), thoughts, and workflow run blocks.""" + + def __init__( + self, + session_factory: _SessionFactory, + debug_enabled: bool = False, + is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None, + task_reader: TaskReader | None = None, + ) -> None: + super().__init__(session_factory, debug_enabled, is_retryable_error_fn) + self._task_reader = task_reader + + @read_retry() + @db_operation("get_task_v2", log_errors=False) + async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None: + async with self.Session() as session: + if task_v2 := ( + await session.scalars( + select(TaskV2Model) + .filter_by(observer_cruise_id=task_v2_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled) + return None + + @db_operation("delete_thoughts") + async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None: + async with self.Session() as session: + stmt = delete(ThoughtModel).where( + and_( + ThoughtModel.observer_cruise_id == task_v2_id, + ThoughtModel.organization_id == organization_id, + ) + ) + await session.execute(stmt) + await session.commit() + + 
@db_operation("get_task_v2_by_workflow_run_id") + async def get_task_v2_by_workflow_run_id( + self, + workflow_run_id: str, + organization_id: str | None = None, + ) -> TaskV2 | None: + async with self.Session() as session: + if task_v2 := ( + await session.scalars( + select(TaskV2Model) + .filter_by(organization_id=organization_id) + .filter_by(workflow_run_id=workflow_run_id) + ) + ).first(): + return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled) + return None + + @db_operation("get_thought") + async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None: + async with self.Session() as session: + if thought := ( + await session.scalars( + select(ThoughtModel) + .filter_by(observer_thought_id=thought_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + return Thought.model_validate(thought) + return None + + @db_operation("get_thoughts") + async def get_thoughts( + self, + *, + task_v2_id: str, + thought_types: list[ThoughtType], + organization_id: str, + ) -> list[Thought]: + async with self.Session() as session: + query = ( + select(ThoughtModel) + .filter_by(observer_cruise_id=task_v2_id) + .filter_by(organization_id=organization_id) + .order_by(ThoughtModel.created_at) + ) + if thought_types: + query = query.filter(ThoughtModel.observer_thought_type.in_(thought_types)) + thoughts = (await session.scalars(query)).all() + return [Thought.model_validate(thought) for thought in thoughts] + + @db_operation("create_task_v2") + async def create_task_v2( + self, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_permanent_id: str | None = None, + prompt: str | None = None, + url: str | None = None, + organization_id: str | None = None, + proxy_location: ProxyLocationInput = None, + totp_identifier: str | None = None, + totp_verification_url: str | None = None, + webhook_callback_url: str | None = None, + extracted_information_schema: dict | list | str | None = None, + error_code_mapping: dict | None = None, + model: dict[str, Any] | None = None, + max_screenshot_scrolling_times: int | None = None, + extra_http_headers: dict[str, str] | None = None, + browser_address: str | None = None, + run_with: str | None = None, + ) -> TaskV2: + async with self.Session() as session: + new_task_v2 = TaskV2Model( + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_permanent_id=workflow_permanent_id, + prompt=prompt, + url=url, + proxy_location=serialize_proxy_location(proxy_location), + totp_identifier=totp_identifier, + totp_verification_url=totp_verification_url, + webhook_callback_url=webhook_callback_url, + extracted_information_schema=extracted_information_schema, + error_code_mapping=error_code_mapping, + organization_id=organization_id, + model=model, + max_screenshot_scrolling_times=max_screenshot_scrolling_times, + extra_http_headers=extra_http_headers, + browser_address=browser_address, + run_with=run_with, + ) + session.add(new_task_v2) + await session.commit() + await session.refresh(new_task_v2) + return convert_to_task_v2(new_task_v2, debug_enabled=self.debug_enabled) + + @db_operation("create_thought") + async def create_thought( + self, + task_v2_id: str, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_permanent_id: str | None = None, + workflow_run_block_id: str | None = None, + user_input: str | None = None, + observation: str | None = None, + thought: str | None = None, + answer: str | None = None, + thought_scenario: str | None = 
None, + thought_type: str = ThoughtType.plan, + output: dict[str, Any] | None = None, + input_token_count: int | None = None, + output_token_count: int | None = None, + reasoning_token_count: int | None = None, + cached_token_count: int | None = None, + thought_cost: float | None = None, + organization_id: str | None = None, + ) -> Thought: + async with self.Session() as session: + new_thought = ThoughtModel( + observer_cruise_id=task_v2_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + workflow_permanent_id=workflow_permanent_id, + workflow_run_block_id=workflow_run_block_id, + user_input=user_input, + observation=observation, + thought=thought, + answer=answer, + observer_thought_scenario=thought_scenario, + observer_thought_type=thought_type, + output=output, + input_token_count=input_token_count, + output_token_count=output_token_count, + reasoning_token_count=reasoning_token_count, + cached_token_count=cached_token_count, + thought_cost=thought_cost, + organization_id=organization_id, + ) + session.add(new_thought) + await session.commit() + await session.refresh(new_thought) + return Thought.model_validate(new_thought) + + @db_operation("update_thought") + async def update_thought( + self, + thought_id: str, + workflow_run_block_id: str | None = None, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_permanent_id: str | None = None, + observation: str | None = None, + thought: str | None = None, + answer: str | None = None, + output: dict[str, Any] | None = None, + input_token_count: int | None = None, + output_token_count: int | None = None, + reasoning_token_count: int | None = None, + cached_token_count: int | None = None, + thought_cost: float | None = None, + organization_id: str | None = None, + ) -> Thought: + async with self.Session() as session: + thought_obj = ( + await session.scalars( + select(ThoughtModel) + .filter_by(observer_thought_id=thought_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if thought_obj: + if workflow_run_block_id: + thought_obj.workflow_run_block_id = workflow_run_block_id + if workflow_run_id: + thought_obj.workflow_run_id = workflow_run_id + if workflow_id: + thought_obj.workflow_id = workflow_id + if workflow_permanent_id: + thought_obj.workflow_permanent_id = workflow_permanent_id + if observation: + thought_obj.observation = observation + if thought: + thought_obj.thought = thought + if answer: + thought_obj.answer = answer + if output: + thought_obj.output = output + if input_token_count: + thought_obj.input_token_count = input_token_count + if output_token_count: + thought_obj.output_token_count = output_token_count + if reasoning_token_count: + thought_obj.reasoning_token_count = reasoning_token_count + if cached_token_count: + thought_obj.cached_token_count = cached_token_count + if thought_cost: + thought_obj.thought_cost = thought_cost + await session.commit() + await session.refresh(thought_obj) + return Thought.model_validate(thought_obj) + raise NotFoundError(f"Thought {thought_id}") + + @db_operation("update_task_v2") + async def update_task_v2( + self, + task_v2_id: str, + status: TaskV2Status | None = None, + workflow_run_id: str | None = None, + workflow_id: str | None = None, + workflow_permanent_id: str | None = None, + url: str | None = None, + prompt: str | None = None, + summary: str | None = None, + output: dict[str, Any] | None = None, + organization_id: str | None = None, + webhook_failure_reason: str | None = None, + failure_category: list[dict[str, 
Any]] | None = None, + ) -> TaskV2: + async with self.Session() as session: + task_v2 = ( + await session.scalars( + select(TaskV2Model) + .filter_by(observer_cruise_id=task_v2_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if task_v2: + if status: + task_v2.status = status + if status == TaskV2Status.queued and task_v2.queued_at is None: + task_v2.queued_at = datetime.now(timezone.utc) + if status == TaskV2Status.running and task_v2.started_at is None: + task_v2.started_at = datetime.now(timezone.utc) + if status.is_final() and task_v2.finished_at is None: + task_v2.finished_at = datetime.now(timezone.utc) + if workflow_run_id: + task_v2.workflow_run_id = workflow_run_id + if workflow_id: + task_v2.workflow_id = workflow_id + if workflow_permanent_id: + task_v2.workflow_permanent_id = workflow_permanent_id + if url: + task_v2.url = url + if prompt: + task_v2.prompt = prompt + if summary: + task_v2.summary = summary + if output: + task_v2.output = output + if webhook_failure_reason is not None: + task_v2.webhook_failure_reason = webhook_failure_reason + if failure_category is not None: + task_v2.failure_category = failure_category + await session.commit() + await session.refresh(task_v2) + return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled) + raise NotFoundError(f"TaskV2 {task_v2_id} not found") + + @db_operation("create_workflow_run_block") + async def create_workflow_run_block( + self, + workflow_run_id: str, + parent_workflow_run_block_id: str | None = None, + organization_id: str | None = None, + task_id: str | None = None, + label: str | None = None, + block_type: BlockType | None = None, + status: BlockStatus = BlockStatus.running, + output: dict | list | str | None = None, + continue_on_failure: bool = False, + engine: RunEngine | None = None, + current_value: str | None = None, + current_index: int | None = None, + ) -> WorkflowRunBlock: + async with self.Session() as session: + new_workflow_run_block = WorkflowRunBlockModel( + workflow_run_id=workflow_run_id, + parent_workflow_run_block_id=parent_workflow_run_block_id, + organization_id=organization_id, + task_id=task_id, + label=label, + block_type=block_type, + status=status, + output=output, + continue_on_failure=continue_on_failure, + engine=engine, + current_value=current_value, + current_index=current_index, + ) + session.add(new_workflow_run_block) + await session.commit() + await session.refresh(new_workflow_run_block) + + task = None + if task_id: + if self._task_reader is None: + raise RuntimeError("task_reader dependency not set") + task = await self._task_reader.get_task(task_id, organization_id=organization_id) + return convert_to_workflow_run_block(new_workflow_run_block, task=task) + + @db_operation("delete_workflow_run_blocks") + async def delete_workflow_run_blocks(self, workflow_run_id: str, organization_id: str | None = None) -> None: + async with self.Session() as session: + stmt = delete(WorkflowRunBlockModel).where( + and_( + WorkflowRunBlockModel.workflow_run_id == workflow_run_id, + WorkflowRunBlockModel.organization_id == organization_id, + ) + ) + await session.execute(stmt) + await session.commit() + + @db_operation("update_workflow_run_block") + async def update_workflow_run_block( + self, + workflow_run_block_id: str, + organization_id: str | None = None, + status: BlockStatus | None = None, + output: dict | list | str | None = None, + failure_reason: str | None = None, + task_id: str | None = None, + loop_values: list | None = None, + current_value: str | None = 
None, + current_index: int | None = None, + recipients: list[str] | None = None, + attachments: list[str] | None = None, + subject: str | None = None, + body: str | None = None, + prompt: str | None = None, + wait_sec: int | None = None, + description: str | None = None, + block_workflow_run_id: str | None = None, + engine: str | None = None, + # HTTP request block parameters + http_request_method: str | None = None, + http_request_url: str | None = None, + http_request_headers: dict[str, str] | None = None, + http_request_body: dict[str, Any] | None = None, + http_request_parameters: dict[str, Any] | None = None, + http_request_timeout: int | None = None, + http_request_follow_redirects: bool | None = None, + ai_fallback_triggered: bool | None = None, + # block-level error codes (e.g. ["FILE_PARSER_ERROR"]) + error_codes: list[str] | None = None, + # human interaction block + instructions: str | None = None, + positive_descriptor: str | None = None, + negative_descriptor: str | None = None, + # conditional block + executed_branch_id: str | None = None, + executed_branch_expression: str | None = None, + executed_branch_result: bool | None = None, + executed_branch_next_block: str | None = None, + ) -> WorkflowRunBlock: + async with self.Session() as session: + workflow_run_block = ( + await session.scalars( + select(WorkflowRunBlockModel) + .filter_by(workflow_run_block_id=workflow_run_block_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if workflow_run_block: + if status: + workflow_run_block.status = status + if output: + workflow_run_block.output = output + if task_id: + workflow_run_block.task_id = task_id + if failure_reason: + workflow_run_block.failure_reason = failure_reason + # Use `is not None` instead of truthiness checks so that falsy + # values like current_index=0, empty loop_values=[], or + # current_value="" are correctly persisted. Without this, + # the first loop iteration (index 0) loses its metadata. 
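+                # A minimal illustration (added for clarity; the call below is
+                # hypothetical, not part of this change): with truthiness
+                # checks, recording the first loop iteration would silently do
+                # nothing:
+                #
+                #     await db.update_workflow_run_block(block_id, current_index=0, loop_values=[])
+                #     # `if current_index:` is False for 0; `if loop_values:` is False for []
+                #
+                # `is not None` keeps None as the "leave unchanged" sentinel
+                # while still persisting 0, [], and "".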
+ if loop_values is not None: + workflow_run_block.loop_values = loop_values + if current_value is not None: + workflow_run_block.current_value = current_value + if current_index is not None: + workflow_run_block.current_index = current_index + if recipients: + workflow_run_block.recipients = recipients + if attachments: + workflow_run_block.attachments = attachments + if subject: + workflow_run_block.subject = subject + if body: + workflow_run_block.body = body + if prompt: + workflow_run_block.prompt = prompt + if wait_sec: + workflow_run_block.wait_sec = wait_sec + if description: + workflow_run_block.description = description + if block_workflow_run_id: + workflow_run_block.block_workflow_run_id = block_workflow_run_id + if engine: + workflow_run_block.engine = engine + # HTTP request block fields + if http_request_method: + workflow_run_block.http_request_method = http_request_method + if http_request_url: + workflow_run_block.http_request_url = http_request_url + if http_request_headers: + workflow_run_block.http_request_headers = http_request_headers + if http_request_body: + workflow_run_block.http_request_body = http_request_body + if http_request_parameters: + workflow_run_block.http_request_parameters = http_request_parameters + if http_request_timeout: + workflow_run_block.http_request_timeout = http_request_timeout + if http_request_follow_redirects is not None: + workflow_run_block.http_request_follow_redirects = http_request_follow_redirects + if ai_fallback_triggered is not None: + workflow_run_block.script_run = {"ai_fallback_triggered": ai_fallback_triggered} + if error_codes is not None: + workflow_run_block.error_codes = error_codes + # human interaction block fields + if instructions: + workflow_run_block.instructions = instructions + if positive_descriptor: + workflow_run_block.positive_descriptor = positive_descriptor + if negative_descriptor: + workflow_run_block.negative_descriptor = negative_descriptor + # conditional block fields + if executed_branch_id: + workflow_run_block.executed_branch_id = executed_branch_id + if executed_branch_expression is not None: + workflow_run_block.executed_branch_expression = executed_branch_expression + if executed_branch_result is not None: + workflow_run_block.executed_branch_result = executed_branch_result + if executed_branch_next_block is not None: + workflow_run_block.executed_branch_next_block = executed_branch_next_block + await session.commit() + await session.refresh(workflow_run_block) + else: + raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found") + task = None + task_id = workflow_run_block.task_id + if task_id: + if self._task_reader is None: + raise RuntimeError("task_reader dependency not set") + task = await self._task_reader.get_task(task_id, organization_id=workflow_run_block.organization_id) + return convert_to_workflow_run_block(workflow_run_block, task=task) + + @db_operation("get_workflow_run_block") + async def get_workflow_run_block( + self, + workflow_run_block_id: str, + organization_id: str | None = None, + ) -> WorkflowRunBlock: + async with self.Session() as session: + workflow_run_block = ( + await session.scalars( + select(WorkflowRunBlockModel) + .filter_by(workflow_run_block_id=workflow_run_block_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if workflow_run_block: + task = None + task_id = workflow_run_block.task_id + if task_id: + if self._task_reader is None: + raise RuntimeError("task_reader dependency not set") + task = await 
self._task_reader.get_task(task_id, organization_id=organization_id) + return convert_to_workflow_run_block(workflow_run_block, task=task) + raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found") + + @db_operation("get_workflow_run_block_by_task_id") + async def get_workflow_run_block_by_task_id( + self, + task_id: str, + organization_id: str | None = None, + ) -> WorkflowRunBlock: + async with self.Session() as session: + workflow_run_block = ( + await session.scalars( + select(WorkflowRunBlockModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) + ) + ).first() + if workflow_run_block: + task = None + task_id = workflow_run_block.task_id + if task_id: + if self._task_reader is None: + raise RuntimeError("task_reader dependency not set") + task = await self._task_reader.get_task(task_id, organization_id=organization_id) + return convert_to_workflow_run_block(workflow_run_block, task=task) + raise NotFoundError(f"WorkflowRunBlock not found by {task_id}") + + @db_operation("get_workflow_run_blocks") + async def get_workflow_run_blocks( + self, + workflow_run_id: str, + organization_id: str | None = None, + ) -> list[WorkflowRunBlock]: + async with self.Session() as session: + workflow_run_blocks = ( + await session.scalars( + select(WorkflowRunBlockModel) + .filter_by(workflow_run_id=workflow_run_id) + .filter_by(organization_id=organization_id) + .order_by(WorkflowRunBlockModel.created_at.desc()) + ) + ).all() + if self._task_reader is None: + raise RuntimeError("task_reader dependency not set") + tasks = await self._task_reader.get_tasks_by_workflow_run_id(workflow_run_id) + tasks_dict = {task.task_id: task for task in tasks} + return [ + convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id)) + for workflow_run_block in workflow_run_blocks + ] diff --git a/skyvern/forge/sdk/db/repositories/organizations.py b/skyvern/forge/sdk/db/repositories/organizations.py new file mode 100644 index 000000000..df9168599 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/organizations.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Literal, overload + +from sqlalchemy import select, update + +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_alchemy_db import read_retry +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import ( + OrganizationAuthTokenModel, + OrganizationModel, + TaskModel, + WorkflowRunModel, +) +from skyvern.forge.sdk.db.utils import ( + convert_to_organization, + convert_to_organization_auth_token, +) +from skyvern.forge.sdk.encrypt import encryptor +from skyvern.forge.sdk.encrypt.base import EncryptMethod +from skyvern.forge.sdk.schemas.organizations import ( + AzureClientSecretCredential, + AzureOrganizationAuthToken, + BitwardenCredential, + BitwardenOrganizationAuthToken, + Organization, + OrganizationAuthToken, +) +from skyvern.forge.sdk.schemas.tasks import TaskStatus +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus + + +class OrganizationsRepository(BaseRepository): + """Database operations for organization and auth-token management.""" + + @read_retry() + @db_operation("get_active_verification_requests", log_errors=False) + async def 
get_active_verification_requests(self, organization_id: str) -> list[dict]: + """Return active 2FA verification requests for an organization. + + Queries both tasks and workflow runs where waiting_for_verification_code=True. + Used to provide initial state when a WebSocket notification client connects. + """ + results: list[dict] = [] + async with self.Session() as session: + # Tasks waiting for verification (exclude finalized tasks) + finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()] + task_rows = ( + await session.scalars( + select(TaskModel) + .filter_by(organization_id=organization_id) + .filter_by(waiting_for_verification_code=True) + .filter_by(workflow_run_id=None) + .filter(TaskModel.status.not_in(finalized_task_statuses)) + .filter(TaskModel.created_at > datetime.now(timezone.utc) - timedelta(hours=1)) + ) + ).all() + for t in task_rows: + results.append( + { + "task_id": t.task_id, + "workflow_run_id": None, + "verification_code_identifier": t.verification_code_identifier, + "verification_code_polling_started_at": ( + t.verification_code_polling_started_at.isoformat() + if t.verification_code_polling_started_at + else None + ), + } + ) + # Workflow runs waiting for verification (exclude finalized runs) + finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()] + wr_rows = ( + await session.scalars( + select(WorkflowRunModel) + .filter_by(organization_id=organization_id) + .filter_by(waiting_for_verification_code=True) + .filter(WorkflowRunModel.status.not_in(finalized_wr_statuses)) + .filter(WorkflowRunModel.created_at > datetime.now(timezone.utc) - timedelta(hours=1)) + ) + ).all() + for wr in wr_rows: + results.append( + { + "task_id": None, + "workflow_run_id": wr.workflow_run_id, + "verification_code_identifier": wr.verification_code_identifier, + "verification_code_polling_started_at": ( + wr.verification_code_polling_started_at.isoformat() + if wr.verification_code_polling_started_at + else None + ), + } + ) + return results + + @db_operation("get_all_organizations") + async def get_all_organizations(self) -> list[Organization]: + async with self.Session() as session: + organizations = (await session.scalars(select(OrganizationModel))).all() + return [convert_to_organization(organization) for organization in organizations] + + @db_operation("get_organization") + async def get_organization(self, organization_id: str) -> Organization | None: + async with self.Session() as session: + if organization := ( + await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id)) + ).first(): + return convert_to_organization(organization) + else: + return None + + @db_operation("get_organization_by_domain") + async def get_organization_by_domain(self, domain: str) -> Organization | None: + async with self.Session() as session: + if organization := (await session.scalars(select(OrganizationModel).filter_by(domain=domain))).first(): + return convert_to_organization(organization) + return None + + @db_operation("create_organization") + async def create_organization( + self, + organization_name: str, + webhook_callback_url: str | None = None, + max_steps_per_run: int | None = None, + max_retries_per_step: int | None = None, + domain: str | None = None, + ) -> Organization: + async with self.Session() as session: + org = OrganizationModel( + organization_name=organization_name, + webhook_callback_url=webhook_callback_url, + max_steps_per_run=max_steps_per_run, + max_retries_per_step=max_retries_per_step, + domain=domain, + 
) + session.add(org) + await session.commit() + await session.refresh(org) + + return convert_to_organization(org) + + @db_operation("update_organization") + async def update_organization( + self, + organization_id: str, + organization_name: str | None = None, + webhook_callback_url: str | None = None, + max_steps_per_run: int | None = None, + max_retries_per_step: int | None = None, + ) -> Organization: + async with self.Session() as session: + organization = ( + await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id)) + ).first() + if not organization: + raise NotFoundError + if organization_name: + organization.organization_name = organization_name + if webhook_callback_url: + organization.webhook_callback_url = webhook_callback_url + if max_steps_per_run: + organization.max_steps_per_run = max_steps_per_run + if max_retries_per_step: + organization.max_retries_per_step = max_retries_per_step + await session.commit() + await session.refresh(organization) + return Organization.model_validate(organization) + + @overload + async def get_valid_org_auth_token( + self, + organization_id: str, + token_type: Literal["api", "onepassword_service_account", "custom_credential_service"], + ) -> OrganizationAuthToken | None: ... + + @overload + async def get_valid_org_auth_token( # type: ignore + self, + organization_id: str, + token_type: Literal["azure_client_secret_credential"], + ) -> AzureOrganizationAuthToken | None: ... + + @overload + async def get_valid_org_auth_token( # type: ignore + self, + organization_id: str, + token_type: Literal["bitwarden_credential"], + ) -> BitwardenOrganizationAuthToken | None: ... + + @db_operation("get_valid_org_auth_token") + async def get_valid_org_auth_token( + self, + organization_id: str, + token_type: Literal[ + "api", + "onepassword_service_account", + "azure_client_secret_credential", + "bitwarden_credential", + "custom_credential_service", + ], + ) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken | None: + async with self.Session() as session: + if token := ( + await session.scalars( + select(OrganizationAuthTokenModel) + .filter_by(organization_id=organization_id) + .filter_by(token_type=token_type) + .filter_by(valid=True) + .order_by(OrganizationAuthTokenModel.created_at.desc()) + ) + ).first(): + return await convert_to_organization_auth_token(token, token_type) + else: + return None + + @db_operation("replace_org_auth_token") + async def replace_org_auth_token( + self, + organization_id: str, + token_type: OrganizationAuthTokenType, + token: str | AzureClientSecretCredential | BitwardenCredential, + encrypted_method: EncryptMethod | None = None, + ) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken: + """Atomically invalidate existing tokens and create a new one in a single transaction.""" + if token_type is OrganizationAuthTokenType.azure_client_secret_credential: + if not isinstance(token, AzureClientSecretCredential): + raise TypeError("Expected AzureClientSecretCredential for this token_type") + plaintext_token = token.model_dump_json() + elif token_type is OrganizationAuthTokenType.bitwarden_credential: + if not isinstance(token, BitwardenCredential): + raise TypeError("Expected BitwardenCredential for this token_type") + plaintext_token = token.model_dump_json() + else: + if not isinstance(token, str): + raise TypeError("Expected str token for this token_type") + plaintext_token = token + + encrypted_token = "" + if encrypted_method 
is not None: + encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method) + plaintext_token = "" + + async with self.Session() as session: + # Invalidate existing tokens + await session.execute( + update(OrganizationAuthTokenModel) + .filter_by(organization_id=organization_id) + .filter_by(token_type=token_type) + .filter_by(valid=True) + .values(valid=False) + ) + # Create new token + auth_token = OrganizationAuthTokenModel( + organization_id=organization_id, + token_type=token_type, + token=plaintext_token, + encrypted_token=encrypted_token, + encrypted_method=encrypted_method.value if encrypted_method is not None else "", + ) + session.add(auth_token) + await session.commit() + await session.refresh(auth_token) + + return await convert_to_organization_auth_token(auth_token, token_type) + + @db_operation("get_valid_org_auth_tokens") + async def get_valid_org_auth_tokens( + self, + organization_id: str, + token_type: OrganizationAuthTokenType, + ) -> list[OrganizationAuthToken]: + async with self.Session() as session: + tokens = ( + await session.scalars( + select(OrganizationAuthTokenModel) + .filter_by(organization_id=organization_id) + .filter_by(token_type=token_type) + .filter_by(valid=True) + .order_by(OrganizationAuthTokenModel.created_at.desc()) + ) + ).all() + return [await convert_to_organization_auth_token(token, token_type) for token in tokens] + + @db_operation("validate_org_auth_token") + async def validate_org_auth_token( + self, + organization_id: str, + token_type: OrganizationAuthTokenType, + token: str, + valid: bool | None = True, + encrypted_method: EncryptMethod | None = None, + ) -> OrganizationAuthToken | None: + encrypted_token = "" + if encrypted_method is not None: + encrypted_token = await encryptor.encrypt(token, encrypted_method) + + async with self.Session() as session: + query = ( + select(OrganizationAuthTokenModel) + .filter_by(organization_id=organization_id) + .filter_by(token_type=token_type) + ) + if encrypted_token: + query = query.filter_by(encrypted_token=encrypted_token) + else: + query = query.filter_by(token=token) + if valid is not None: + query = query.filter_by(valid=valid) + if token_obj := (await session.scalars(query)).first(): + return await convert_to_organization_auth_token(token_obj, token_type) + else: + return None + + @db_operation("create_org_auth_token") + async def create_org_auth_token( + self, + organization_id: str, + token_type: OrganizationAuthTokenType, + token: str | AzureClientSecretCredential | BitwardenCredential, + encrypted_method: EncryptMethod | None = None, + ) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken: + if token_type is OrganizationAuthTokenType.azure_client_secret_credential: + if not isinstance(token, AzureClientSecretCredential): + raise TypeError("Expected AzureClientSecretCredential for this token_type") + plaintext_token = token.model_dump_json() + elif token_type is OrganizationAuthTokenType.bitwarden_credential: + if not isinstance(token, BitwardenCredential): + raise TypeError("Expected BitwardenCredential for this token_type") + plaintext_token = token.model_dump_json() + else: + if not isinstance(token, str): + raise TypeError("Expected str token for this token_type") + plaintext_token = token + + encrypted_token = "" + + if encrypted_method is not None: + encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method) + plaintext_token = "" + + async with self.Session() as session: + auth_token = OrganizationAuthTokenModel( + 
organization_id=organization_id, + token_type=token_type, + token=plaintext_token, + encrypted_token=encrypted_token, + encrypted_method=encrypted_method.value if encrypted_method is not None else "", + ) + session.add(auth_token) + await session.commit() + await session.refresh(auth_token) + + return await convert_to_organization_auth_token(auth_token, token_type) + + @db_operation("invalidate_org_auth_tokens") + async def invalidate_org_auth_tokens( + self, + organization_id: str, + token_type: OrganizationAuthTokenType, + ) -> None: + """Invalidate all existing tokens of a specific type for an organization.""" + async with self.Session() as session: + await session.execute( + update(OrganizationAuthTokenModel) + .filter_by(organization_id=organization_id) + .filter_by(token_type=token_type) + .filter_by(valid=True) + .values(valid=False) + ) + await session.commit() diff --git a/skyvern/forge/sdk/db/repositories/otp.py b/skyvern/forge/sdk/db/repositories/otp.py new file mode 100644 index 000000000..1586babb6 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/otp.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from sqlalchemy import and_, asc, select + +from skyvern.config import settings +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.models import TOTPCodeModel +from skyvern.forge.sdk.schemas.totp_codes import OTPType, TOTPCode + + +class OTPRepository(BaseRepository): + """Database operations for OTP/TOTP management.""" + + @db_operation("get_otp_codes") + async def get_otp_codes( + self, + organization_id: str, + totp_identifier: str, + valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES, + otp_type: OTPType | None = None, + workflow_run_id: str | None = None, + limit: int | None = None, + ) -> list[TOTPCode]: + """ + 1. filter by: + - organization_id + - totp_identifier + - workflow_run_id (optional) + 2. make sure created_at is within the valid lifespan + 3. sort by task_id/workflow_id/workflow_run_id nullslast and created_at desc + 4. apply an optional limit at the DB layer + """ + all_null = and_( + TOTPCodeModel.task_id.is_(None), + TOTPCodeModel.workflow_id.is_(None), + TOTPCodeModel.workflow_run_id.is_(None), + ) + async with self.Session() as session: + query = ( + select(TOTPCodeModel) + .filter_by(organization_id=organization_id) + .filter_by(totp_identifier=totp_identifier) + .filter( + TOTPCodeModel.created_at > datetime.now(timezone.utc) - timedelta(minutes=valid_lifespan_minutes) + ) + ) + if otp_type: + query = query.filter(TOTPCodeModel.otp_type == otp_type) + if workflow_run_id is not None: + query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id) + query = query.order_by(asc(all_null), TOTPCodeModel.created_at.desc()) + if limit is not None: + query = query.limit(limit) + totp_codes = (await session.scalars(query)).all() + return [TOTPCode.model_validate(code) for code in totp_codes] + + @db_operation("get_otp_codes_by_run") + async def get_otp_codes_by_run( + self, + organization_id: str, + task_id: str | None = None, + workflow_run_id: str | None = None, + valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES, + limit: int = 1, + ) -> list[TOTPCode]: + """Get OTP codes matching a specific task or workflow run (no totp_identifier required). + + Used when the agent detects a 2FA page but no TOTP credentials are pre-configured. 
+ The user submits codes manually via the UI, and this method finds them by run context. + """ + if not workflow_run_id and not task_id: + return [] + async with self.Session() as session: + query = ( + select(TOTPCodeModel) + .filter_by(organization_id=organization_id) + .filter( + TOTPCodeModel.created_at > datetime.now(timezone.utc) - timedelta(minutes=valid_lifespan_minutes) + ) + ) + if workflow_run_id: + query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id) + elif task_id: + query = query.filter(TOTPCodeModel.task_id == task_id) + query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit) + results = (await session.scalars(query)).all() + return [TOTPCode.model_validate(r) for r in results] + + @db_operation("get_recent_otp_codes") + async def get_recent_otp_codes( + self, + organization_id: str, + limit: int = 50, + valid_lifespan_minutes: int | None = None, + otp_type: OTPType | None = None, + workflow_run_id: str | None = None, + totp_identifier: str | None = None, + ) -> list[TOTPCode]: + """ + Return recent otp codes for an organization ordered by newest first with optional + workflow_run_id filtering. + """ + async with self.Session() as session: + query = select(TOTPCodeModel).filter_by(organization_id=organization_id) + + if valid_lifespan_minutes is not None: + query = query.filter( + TOTPCodeModel.created_at > datetime.now(timezone.utc) - timedelta(minutes=valid_lifespan_minutes) + ) + + if otp_type: + query = query.filter(TOTPCodeModel.otp_type == otp_type) + if workflow_run_id is not None: + query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id) + if totp_identifier: + query = query.filter(TOTPCodeModel.totp_identifier == totp_identifier) + query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit) + totp_codes = (await session.scalars(query)).all() + return [TOTPCode.model_validate(totp_code) for totp_code in totp_codes] + + @db_operation("create_otp_code") + async def create_otp_code( + self, + organization_id: str, + totp_identifier: str, + content: str, + code: str, + otp_type: OTPType, + task_id: str | None = None, + workflow_id: str | None = None, + workflow_run_id: str | None = None, + source: str | None = None, + expired_at: datetime | None = None, + ) -> TOTPCode: + async with self.Session() as session: + new_totp_code = TOTPCodeModel( + organization_id=organization_id, + totp_identifier=totp_identifier, + content=content, + code=code, + task_id=task_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + source=source, + expired_at=expired_at, + otp_type=otp_type, + ) + session.add(new_totp_code) + await session.commit() + await session.refresh(new_totp_code) + return TOTPCode.model_validate(new_totp_code) diff --git a/skyvern/forge/sdk/db/repositories/schedules.py b/skyvern/forge/sdk/db/repositories/schedules.py new file mode 100644 index 000000000..f13d404cc --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/schedules.py @@ -0,0 +1,602 @@ +"""Database operations for workflow schedules.""" + +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, Callable + +import structlog +from sqlalchemy import exists, func, or_, select, text, update +from sqlalchemy.exc import SQLAlchemyError + +from skyvern.forge.sdk.db._error_handling import db_operation, register_passthrough_exception +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError +from 
skyvern.forge.sdk.db.models import WorkflowModel, WorkflowRunModel, WorkflowScheduleModel +from skyvern.forge.sdk.db.utils import convert_to_workflow_schedule +from skyvern.forge.sdk.schemas.workflow_schedules import OrganizationScheduleItem, WorkflowSchedule +from skyvern.forge.sdk.workflow.schedules import compute_next_run + +if TYPE_CHECKING: + from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory + +from . import _UNSET + +LOG = structlog.get_logger() + +register_passthrough_exception(ScheduleLimitExceededError) + + +class SchedulesRepository(BaseRepository): + """Database operations for workflow schedules.""" + + def __init__( + self, + session_factory: _SessionFactory, + debug_enabled: bool = False, + is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None, + sqlite_schedule_lock: asyncio.Lock | None = None, + ) -> None: + super().__init__(session_factory, debug_enabled, is_retryable_error_fn) + self._sqlite_schedule_lock = sqlite_schedule_lock + + @db_operation("create_workflow_schedule") + async def create_workflow_schedule( + self, + organization_id: str, + workflow_permanent_id: str, + cron_expression: str, + timezone: str, + enabled: bool, + parameters: dict[str, Any] | None = None, + temporal_schedule_id: str | None = None, + name: str | None = None, + description: str | None = None, + ) -> WorkflowSchedule: + async with self.Session() as session: + workflow_schedule = WorkflowScheduleModel( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + cron_expression=cron_expression, + timezone=timezone, + enabled=enabled, + parameters=parameters, + temporal_schedule_id=temporal_schedule_id, + name=name, + description=description, + ) + session.add(workflow_schedule) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("create_workflow_schedule_with_limit") + async def create_workflow_schedule_with_limit( + self, + organization_id: str, + workflow_permanent_id: str, + max_schedules: int | None, + cron_expression: str, + timezone: str, + enabled: bool, + parameters: dict[str, Any] | None = None, + name: str | None = None, + description: str | None = None, + ) -> tuple[WorkflowSchedule, int]: + """Create a schedule atomically with limit enforcement. + + On PostgreSQL, uses an advisory lock to serialize concurrent creates for + the same workflow, preventing TOCTOU races on the schedule count. + + On SQLite, uses an asyncio.Lock (set on AgentDB.__init__) since SQLite + is single-writer and has no advisory lock support. + + Returns (created_schedule, count_before_insert). + Raises ScheduleLimitExceededError if count >= max_schedules. + """ + # SQLite: serialize via Python lock (no advisory locks available). + # The lock is held across the count-check + insert to prevent TOCTOU. 
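+        # Hypothetical interleaving of the race this guards against, assuming
+        # max_schedules=5 and a current count of 4:
+        #
+        #     request A: SELECT count(*) -> 4 (< 5, ok)
+        #     request B: SELECT count(*) -> 4 (< 5, ok)
+        #     request A: INSERT            -> count is now 5
+        #     request B: INSERT            -> count is now 6, limit broken
+        #
+        # Holding the lock (this asyncio.Lock on SQLite, pg_advisory_xact_lock
+        # on PostgreSQL) across the count-check and the INSERT forces B to
+        # wait, re-read count=5, and raise ScheduleLimitExceededError instead.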
+ if self._sqlite_schedule_lock is not None: + async with self._sqlite_schedule_lock: + return await self._create_schedule_with_limit_inner( + organization_id, + workflow_permanent_id, + max_schedules, + cron_expression, + timezone, + enabled, + parameters, + name, + description, + use_advisory_lock=False, + ) + return await self._create_schedule_with_limit_inner( + organization_id, + workflow_permanent_id, + max_schedules, + cron_expression, + timezone, + enabled, + parameters, + name, + description, + use_advisory_lock=True, + ) + + # Intentionally not decorated with @db_operation — errors are caught by the + # outer create_workflow_schedule_with_limit which owns the operation name. + async def _create_schedule_with_limit_inner( + self, + organization_id: str, + workflow_permanent_id: str, + max_schedules: int | None, + cron_expression: str, + timezone: str, + enabled: bool, + parameters: dict[str, Any] | None, + name: str | None, + description: str | None, + *, + use_advisory_lock: bool, + ) -> tuple[WorkflowSchedule, int]: + async with self.Session() as session: + if use_advisory_lock: + lock_key = f"schedule:{organization_id}:{workflow_permanent_id}" + await session.execute( + text("SELECT pg_advisory_xact_lock(hashtext(:key))"), + {"key": lock_key}, + ) + + count = ( + await session.execute( + select(func.count()).where( + WorkflowScheduleModel.organization_id == organization_id, + WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScheduleModel.deleted_at.is_(None), + ) + ) + ).scalar_one() + + if max_schedules is not None and count >= max_schedules: + raise ScheduleLimitExceededError( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + current_count=count, + max_allowed=max_schedules, + ) + + workflow_schedule = WorkflowScheduleModel( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + cron_expression=cron_expression, + timezone=timezone, + enabled=enabled, + parameters=parameters, + name=name, + description=description, + ) + session.add(workflow_schedule) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled), count + + @db_operation("set_temporal_schedule_id") + async def set_temporal_schedule_id( + self, + workflow_schedule_id: str, + organization_id: str, + temporal_schedule_id: str, + ) -> WorkflowSchedule | None: + async with self.Session() as session: + workflow_schedule = ( + await session.scalars( + select(WorkflowScheduleModel).filter_by( + workflow_schedule_id=workflow_schedule_id, + organization_id=organization_id, + deleted_at=None, + ) + ) + ).first() + + if not workflow_schedule: + return None + + workflow_schedule.temporal_schedule_id = temporal_schedule_id + workflow_schedule.modified_at = datetime.now(UTC) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("update_workflow_schedule") + async def update_workflow_schedule( + self, + workflow_schedule_id: str, + organization_id: str, + cron_expression: str, + timezone: str, + enabled: bool, + parameters: dict[str, Any] | None = None, + temporal_schedule_id: str | None | object = _UNSET, + name: str | None | object = _UNSET, + description: str | None | object = _UNSET, + ) -> WorkflowSchedule | None: + async with self.Session() as session: + workflow_schedule = ( + await session.scalars( + 
select(WorkflowScheduleModel).filter_by( + workflow_schedule_id=workflow_schedule_id, + organization_id=organization_id, + deleted_at=None, + ) + ) + ).first() + + if not workflow_schedule: + return None + + workflow_schedule.cron_expression = cron_expression + workflow_schedule.timezone = timezone + workflow_schedule.enabled = enabled + workflow_schedule.parameters = parameters + if temporal_schedule_id is not _UNSET: + workflow_schedule.temporal_schedule_id = temporal_schedule_id + if name is not _UNSET: + workflow_schedule.name = name + if description is not _UNSET: + workflow_schedule.description = description + workflow_schedule.modified_at = datetime.now(UTC) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("get_workflow_schedule_by_id") + async def get_workflow_schedule_by_id( + self, + workflow_schedule_id: str, + organization_id: str, + ) -> WorkflowSchedule | None: + async with self.Session() as session: + workflow_schedule = ( + await session.scalars( + select(WorkflowScheduleModel).filter_by( + workflow_schedule_id=workflow_schedule_id, + organization_id=organization_id, + deleted_at=None, + ) + ) + ).first() + if not workflow_schedule: + return None + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("get_workflow_schedules") + async def get_workflow_schedules( + self, + workflow_permanent_id: str, + organization_id: str, + ) -> list[WorkflowSchedule]: + async with self.Session() as session: + rows = ( + await session.scalars( + select(WorkflowScheduleModel).filter_by( + workflow_permanent_id=workflow_permanent_id, + organization_id=organization_id, + deleted_at=None, + ) + ) + ).all() + return [convert_to_workflow_schedule(r, self.debug_enabled) for r in rows] + + @db_operation("get_all_enabled_schedules") + async def get_all_enabled_schedules( + self, + organization_id: str | None = None, + ) -> list[WorkflowSchedule]: + """Fetch all enabled, non-deleted schedules, optionally filtered by org.""" + async with self.Session() as session: + stmt = select(WorkflowScheduleModel).where( + WorkflowScheduleModel.enabled.is_(True), + WorkflowScheduleModel.deleted_at.is_(None), + ) + if organization_id: + stmt = stmt.where(WorkflowScheduleModel.organization_id == organization_id) + rows = (await session.scalars(stmt)).all() + return [convert_to_workflow_schedule(r, self.debug_enabled) for r in rows] + + @db_operation("has_schedule_fired_since") + async def has_schedule_fired_since( + self, + workflow_schedule_id: str, + since: datetime, + ) -> bool: + """Check if a workflow_run exists for the given schedule since a timestamp.""" + async with self.Session() as session: + row = ( + await session.execute( + select( + exists().where( + WorkflowRunModel.workflow_schedule_id == workflow_schedule_id, + WorkflowRunModel.created_at >= since, + ) + ) + ) + ).scalar() + return bool(row) + + @db_operation("update_workflow_schedule_enabled") + async def update_workflow_schedule_enabled( + self, + workflow_schedule_id: str, + organization_id: str, + enabled: bool, + ) -> WorkflowSchedule | None: + async with self.Session() as session: + workflow_schedule = ( + await session.scalars( + select(WorkflowScheduleModel).filter_by( + workflow_schedule_id=workflow_schedule_id, + organization_id=organization_id, + deleted_at=None, + ) + ) + ).first() + if not workflow_schedule: + return None + workflow_schedule.enabled = enabled + 
workflow_schedule.modified_at = datetime.now(UTC) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("delete_workflow_schedule") + async def delete_workflow_schedule( + self, + workflow_schedule_id: str, + organization_id: str, + ) -> WorkflowSchedule | None: + async with self.Session() as session: + workflow_schedule = ( + await session.scalars( + select(WorkflowScheduleModel).filter_by( + workflow_schedule_id=workflow_schedule_id, + organization_id=organization_id, + deleted_at=None, + ) + ) + ).first() + if not workflow_schedule: + return None + + workflow_schedule.deleted_at = datetime.now(UTC) + workflow_schedule.modified_at = datetime.now(UTC) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("restore_workflow_schedule") + async def restore_workflow_schedule( + self, + workflow_schedule_id: str, + organization_id: str, + ) -> WorkflowSchedule | None: + async with self.Session() as session: + workflow_schedule = ( + await session.scalars( + select(WorkflowScheduleModel) + .filter_by( + workflow_schedule_id=workflow_schedule_id, + organization_id=organization_id, + ) + .filter(WorkflowScheduleModel.deleted_at.isnot(None)) + ) + ).first() + if not workflow_schedule: + return None + + workflow_schedule.deleted_at = None + workflow_schedule.modified_at = datetime.now(UTC) + await session.commit() + await session.refresh(workflow_schedule) + return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled) + + @db_operation("count_workflow_schedules") + async def count_workflow_schedules( + self, + organization_id: str, + workflow_permanent_id: str, + ) -> int: + async with self.Session() as session: + result = await session.execute( + select(func.count()).where( + WorkflowScheduleModel.organization_id == organization_id, + WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScheduleModel.deleted_at.is_(None), + ) + ) + return result.scalar_one() + + @db_operation("list_organization_schedules") + async def list_organization_schedules( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + enabled_filter: bool | None = None, + search: str | None = None, + ) -> tuple[list[OrganizationScheduleItem], int]: + """ + List all schedules for an organization, joined with workflow titles. + Returns (schedules, total_count). 
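+
+        Illustrative call (argument values are made up; `db.schedules` assumes
+        the repository is exposed as an attribute on the DB object):
+
+            schedules, total = await db.schedules.list_organization_schedules(
+                organization_id="o_123",
+                page=1,  # 1-based at the API surface; converted to a 0-based offset
+                page_size=10,
+                enabled_filter=True,
+                search="daily",
+            )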
+ """ + if page < 1: + raise ValueError(f"Page must be greater than 0, got {page}") + db_page = page - 1 + async with self.Session() as session: + # Subquery to get the latest version title per workflow_permanent_id + latest_version_sq = ( + select( + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by(WorkflowModel.workflow_permanent_id) + .subquery() + ) + + workflow_title_sq = ( + select( + WorkflowModel.workflow_permanent_id, + WorkflowModel.title, + ) + .join( + latest_version_sq, + (WorkflowModel.workflow_permanent_id == latest_version_sq.c.workflow_permanent_id) + & (WorkflowModel.version == latest_version_sq.c.max_version), + ) + .subquery() + ) + + # Base query: schedules joined with workflow titles + base_filter = ( + select(WorkflowScheduleModel, workflow_title_sq.c.title.label("workflow_title")) + .outerjoin( + workflow_title_sq, + WorkflowScheduleModel.workflow_permanent_id == workflow_title_sq.c.workflow_permanent_id, + ) + .where(WorkflowScheduleModel.organization_id == organization_id) + .where(WorkflowScheduleModel.deleted_at.is_(None)) + ) + + if enabled_filter is not None: + base_filter = base_filter.where(WorkflowScheduleModel.enabled == enabled_filter) + + if search: + base_filter = base_filter.where( + or_( + workflow_title_sq.c.title.icontains(search, autoescape=True), + WorkflowScheduleModel.name.icontains(search, autoescape=True), + ) + ) + + # Count query + count_query = select(func.count()).select_from(base_filter.subquery()) + total_count = (await session.execute(count_query)).scalar_one() + + # Data query with pagination + data_query = ( + base_filter.order_by(WorkflowScheduleModel.created_at.desc()) + .limit(page_size) + .offset(db_page * page_size) + ) + rows = (await session.execute(data_query)).all() + + # Materialize row data while session is open + raw_schedules = [] + for row in rows: + schedule_model = row[0] + raw_schedules.append( + ( + schedule_model.workflow_schedule_id, + schedule_model.organization_id, + schedule_model.workflow_permanent_id, + row[1] or "Untitled Workflow", + schedule_model.cron_expression, + schedule_model.timezone, + schedule_model.enabled, + schedule_model.parameters, + schedule_model.name, + schedule_model.description, + schedule_model.created_at, + schedule_model.modified_at, + ) + ) + + # Compute next_run outside session scope (pure CPU, no DB needed) + schedules: list[OrganizationScheduleItem] = [] + for ( + ws_id, + org_id, + wpid, + title, + cron_expr, + tz, + enabled, + params, + name, + description, + created, + modified, + ) in raw_schedules: + next_run = None + if enabled: + try: + next_run = compute_next_run(cron_expr, tz) + except Exception: + LOG.warning( + "Failed to compute next_run for schedule", + workflow_schedule_id=ws_id, + exc_info=True, + ) + + schedules.append( + OrganizationScheduleItem( + workflow_schedule_id=ws_id, + organization_id=org_id, + workflow_permanent_id=wpid, + workflow_title=title, + cron_expression=cron_expr, + timezone=tz, + enabled=enabled, + parameters=params, + name=name, + description=description, + next_run=next_run, + created_at=created, + modified_at=modified, + ) + ) + + return schedules, total_count + + @db_operation("soft_delete_orphaned_schedules") + async def soft_delete_orphaned_schedules(self, limit: int = 500) -> list[tuple[str, str]]: + """Soft-delete orphaned schedules and return their identities. 
+ + Uses a single UPDATE ... RETURNING statement so orphan detection and + soft-deletion happen atomically in one DB round-trip. + """ + async with self.Session() as session: + active_workflow_exists = ( + select(WorkflowModel.workflow_permanent_id) + .where(WorkflowModel.workflow_permanent_id == WorkflowScheduleModel.workflow_permanent_id) + .where(WorkflowModel.deleted_at.is_(None)) + .correlate(WorkflowScheduleModel) + .exists() + ) + orphaned_schedules = ( + select( + WorkflowScheduleModel.workflow_schedule_id.label("workflow_schedule_id"), + WorkflowScheduleModel.workflow_permanent_id.label("workflow_permanent_id"), + ) + .where(WorkflowScheduleModel.deleted_at.is_(None)) + .where(~active_workflow_exists) + .limit(limit) + .cte("orphaned_schedules") + ) + update_query = ( + update(WorkflowScheduleModel) + .where( + WorkflowScheduleModel.workflow_schedule_id.in_(select(orphaned_schedules.c.workflow_schedule_id)) + ) + .where(WorkflowScheduleModel.deleted_at.is_(None)) + .values(deleted_at=datetime.now(UTC)) + .returning( + WorkflowScheduleModel.workflow_schedule_id, + WorkflowScheduleModel.workflow_permanent_id, + ) + ) + result = await session.execute(update_query) + await session.commit() + return [(row[0], row[1]) for row in result.all()] diff --git a/skyvern/forge/sdk/db/repositories/scripts.py b/skyvern/forge/sdk/db/repositories/scripts.py new file mode 100644 index 000000000..e1673c580 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/scripts.py @@ -0,0 +1,1206 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import structlog +from sqlalchemy import and_, delete, distinct, func, select, update +from sqlalchemy.dialects.postgresql import insert + +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import ( + ScriptBlockModel, + ScriptBranchHitModel, + ScriptFallbackEpisodeModel, + ScriptFileModel, + ScriptModel, + WorkflowRunModel, + WorkflowScriptModel, +) +from skyvern.forge.sdk.db.utils import ( + convert_to_script, + convert_to_script_block, + convert_to_script_file, +) +from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text +from skyvern.schemas.scripts import ( + Script, + ScriptBlock, + ScriptBranchHit, + ScriptFallbackEpisode, + ScriptFile, + ScriptStatus, + WorkflowScript, +) + +LOG = structlog.get_logger() + + +class ScriptsRepository(BaseRepository): + """Database operations for scripts, script files, script blocks, workflow scripts, and fallback episodes.""" + + @db_operation("create_script") + async def create_script( + self, + organization_id: str, + run_id: str | None = None, + script_id: str | None = None, + version: int | None = None, + ) -> Script: + async with self.Session() as session: + script = ScriptModel( + organization_id=organization_id, + run_id=run_id, + ) + if script_id: + script.script_id = script_id + if version: + script.version = version + session.add(script) + await session.commit() + await session.refresh(script) + return convert_to_script(script) + + @db_operation("get_scripts") + async def get_scripts( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + ) -> list[Script]: + async with self.Session() as session: + # Calculate offset for pagination + offset = (page - 1) * page_size + + # Subquery to get the latest version of each script + latest_versions_subquery = ( + 
select(ScriptModel.script_id, func.max(ScriptModel.version).label("latest_version")) + .filter_by(organization_id=organization_id) + .filter(ScriptModel.deleted_at.is_(None)) + .group_by(ScriptModel.script_id) + .subquery() + ) + + # Main query to get scripts with their latest versions + get_scripts_query = ( + select(ScriptModel) + .join( + latest_versions_subquery, + and_( + ScriptModel.script_id == latest_versions_subquery.c.script_id, + ScriptModel.version == latest_versions_subquery.c.latest_version, + ), + ) + .filter_by(organization_id=organization_id) + .filter(ScriptModel.deleted_at.is_(None)) + .order_by(ScriptModel.created_at.desc()) + .limit(page_size) + .offset(offset) + ) + scripts = (await session.scalars(get_scripts_query)).all() + return [convert_to_script(script) for script in scripts] + + @db_operation("get_script") + async def get_script( + self, + script_id: str, + organization_id: str, + version: int | None = None, + ) -> Script | None: + """Get a specific script by ID and optionally by version.""" + async with self.Session() as session: + get_script_query = ( + select(ScriptModel) + .filter_by(script_id=script_id) + .filter_by(organization_id=organization_id) + .filter(ScriptModel.deleted_at.is_(None)) + ) + + if version is not None: + get_script_query = get_script_query.filter_by(version=version) + else: + # Get the latest version + get_script_query = get_script_query.order_by(ScriptModel.version.desc()).limit(1) + + if script := (await session.scalars(get_script_query)).first(): + return convert_to_script(script) + return None + + @db_operation("get_script_revision") + async def get_script_revision(self, script_revision_id: str, organization_id: str) -> Script | None: + async with self.Session() as session: + script = ( + await session.scalars( + select(ScriptModel) + .filter_by(script_revision_id=script_revision_id) + .filter_by(organization_id=organization_id) + ) + ).first() + return convert_to_script(script) if script else None + + @db_operation("get_latest_script_version") + async def get_latest_script_version(self, script_id: str, organization_id: str) -> Script | None: + """Get the latest version of a script by script_id.""" + async with self.Session() as session: + script = ( + await session.scalars( + select(ScriptModel) + .filter_by(script_id=script_id, organization_id=organization_id) + .filter(ScriptModel.deleted_at.is_(None)) + .order_by(ScriptModel.version.desc()) + .limit(1) + ) + ).first() + return convert_to_script(script) if script else None + + @db_operation("get_script_versions") + async def get_script_versions( + self, + script_id: str, + organization_id: str, + ) -> list[Script]: + """Get all versions of a script, ordered by version DESC.""" + async with self.Session() as session: + query = ( + select(ScriptModel) + .filter( + ScriptModel.script_id == script_id, + ScriptModel.organization_id == organization_id, + ScriptModel.deleted_at.is_(None), + ) + .order_by(ScriptModel.version.desc()) + ) + result = await session.scalars(query) + return [convert_to_script(row) for row in result.all()] + + @db_operation("get_script_version_stats") + async def get_script_version_stats( + self, + organization_id: str, + script_ids: list[str], + ) -> dict[str, tuple[int, int]]: + """Return {script_id: (latest_version, version_count)} for the given script IDs.""" + if not script_ids: + return {} + async with self.Session() as session: + query = ( + select( + ScriptModel.script_id, + # max(version) must include soft-deleted rows so next-version + # assignment 
doesn't collide with the unique constraint. + func.max(ScriptModel.version), + # version_count only counts live rows (for display). + func.count(ScriptModel.script_revision_id).filter( + ScriptModel.deleted_at.is_(None), + ), + ) + .filter( + ScriptModel.organization_id == organization_id, + ScriptModel.script_id.in_(script_ids), + ) + .group_by(ScriptModel.script_id) + ) + rows = (await session.execute(query)).all() + return {row[0]: (row[1], row[2]) for row in rows} + + @db_operation("soft_delete_script_by_revision") + async def soft_delete_script_by_revision(self, script_revision_id: str, organization_id: str) -> None: + async with self.Session() as session: + await session.execute( + update(ScriptModel) + .filter_by(script_revision_id=script_revision_id) + .filter_by(organization_id=organization_id) + .values(deleted_at=datetime.now(timezone.utc)) + ) + await session.commit() + + @db_operation("create_script_file") + async def create_script_file( + self, + script_revision_id: str, + script_id: str, + organization_id: str, + file_path: str, + file_name: str, + file_type: str, + content_hash: str | None = None, + file_size: int | None = None, + mime_type: str | None = None, + encoding: str = "utf-8", + artifact_id: str | None = None, + ) -> ScriptFile: + """Create a script file.""" + async with self.Session() as session: + script_file = ScriptFileModel( + script_revision_id=script_revision_id, + script_id=script_id, + organization_id=organization_id, + file_path=file_path, + file_name=file_name, + file_type=file_type, + content_hash=content_hash, + file_size=file_size, + mime_type=mime_type, + encoding=encoding, + artifact_id=artifact_id, + ) + session.add(script_file) + await session.commit() + await session.refresh(script_file) + return convert_to_script_file(script_file) + + @db_operation("create_script_block") + async def create_script_block( + self, + script_revision_id: str, + script_id: str, + organization_id: str, + script_block_label: str, + script_file_id: str | None = None, + run_signature: str | None = None, + workflow_run_id: str | None = None, + workflow_run_block_id: str | None = None, + input_fields: list[str] | None = None, + requires_agent: bool = False, + ) -> ScriptBlock: + """Create a script block.""" + async with self.Session() as session: + script_block = ScriptBlockModel( + script_revision_id=script_revision_id, + script_id=script_id, + organization_id=organization_id, + script_block_label=script_block_label, + script_file_id=script_file_id, + run_signature=run_signature, + workflow_run_id=workflow_run_id, + workflow_run_block_id=workflow_run_block_id, + input_fields=input_fields, + requires_agent=requires_agent, + ) + session.add(script_block) + await session.commit() + await session.refresh(script_block) + return convert_to_script_block(script_block) + + @db_operation("update_script_block") + async def update_script_block( + self, + script_block_id: str, + organization_id: str, + script_file_id: str | None = None, + run_signature: str | None = None, + workflow_run_id: str | None = None, + workflow_run_block_id: str | None = None, + clear_run_signature: bool = False, + input_fields: list[str] | None = None, + requires_agent: bool | None = None, + ) -> ScriptBlock: + async with self.Session() as session: + script_block = ( + await session.scalars( + select(ScriptBlockModel) + .filter_by(script_block_id=script_block_id) + .filter_by(organization_id=organization_id) + ) + ).first() + if script_block: + if script_file_id is not None: + script_block.script_file_id = 
script_file_id + if clear_run_signature: + script_block.run_signature = None + elif run_signature is not None: + script_block.run_signature = run_signature + if workflow_run_id is not None: + script_block.workflow_run_id = workflow_run_id + if workflow_run_block_id is not None: + script_block.workflow_run_block_id = workflow_run_block_id + if input_fields is not None: + script_block.input_fields = input_fields + if requires_agent is not None: + script_block.requires_agent = requires_agent + await session.commit() + await session.refresh(script_block) + return convert_to_script_block(script_block) + else: + raise NotFoundError("Script block not found") + + @db_operation("get_script_files") + async def get_script_files(self, script_revision_id: str, organization_id: str) -> list[ScriptFile]: + async with self.Session() as session: + script_files = ( + await session.scalars( + select(ScriptFileModel) + .filter_by(script_revision_id=script_revision_id) + .filter_by(organization_id=organization_id) + ) + ).all() + return [convert_to_script_file(script_file) for script_file in script_files] + + @db_operation("get_script_file_by_id") + async def get_script_file_by_id( + self, + script_revision_id: str, + file_id: str, + organization_id: str, + ) -> ScriptFile | None: + async with self.Session() as session: + script_file = ( + await session.scalars( + select(ScriptFileModel) + .filter_by(script_revision_id=script_revision_id) + .filter_by(file_id=file_id) + .filter_by(organization_id=organization_id) + ) + ).first() + + return convert_to_script_file(script_file) if script_file else None + + @db_operation("get_script_file_by_path") + async def get_script_file_by_path( + self, + script_revision_id: str, + file_path: str, + organization_id: str, + ) -> ScriptFile | None: + async with self.Session() as session: + script_file = ( + await session.scalars( + select(ScriptFileModel) + .filter_by(script_revision_id=script_revision_id) + .filter_by(file_path=file_path) + .filter_by(organization_id=organization_id) + ) + ).first() + return convert_to_script_file(script_file) if script_file else None + + @db_operation("update_script_file") + async def update_script_file( + self, + script_file_id: str, + organization_id: str, + artifact_id: str | None = None, + ) -> ScriptFile: + async with self.Session() as session: + script_file = ( + await session.scalars( + select(ScriptFileModel).filter_by(file_id=script_file_id).filter_by(organization_id=organization_id) + ) + ).first() + if script_file: + if artifact_id: + script_file.artifact_id = artifact_id + await session.commit() + await session.refresh(script_file) + return convert_to_script_file(script_file) + else: + raise NotFoundError("Script file not found") + + @db_operation("get_script_block") + async def get_script_block( + self, + script_block_id: str, + organization_id: str, + ) -> ScriptBlock | None: + async with self.Session() as session: + record = ( + await session.scalars( + select(ScriptBlockModel) + .filter_by(script_block_id=script_block_id) + .filter_by(organization_id=organization_id) + ) + ).first() + return convert_to_script_block(record) if record else None + + @db_operation("get_script_block_by_label") + async def get_script_block_by_label( + self, + organization_id: str, + script_revision_id: str, + script_block_label: str, + ) -> ScriptBlock | None: + async with self.Session() as session: + record = ( + await session.scalars( + select(ScriptBlockModel) + .filter_by(script_revision_id=script_revision_id) + 
.filter_by(script_block_label=script_block_label) + .filter_by(organization_id=organization_id) + ) + ).first() + return convert_to_script_block(record) if record else None + + @db_operation("get_script_blocks_by_script_revision_id") + async def get_script_blocks_by_script_revision_id( + self, + script_revision_id: str, + organization_id: str, + ) -> list[ScriptBlock]: + async with self.Session() as session: + records = ( + await session.scalars( + select(ScriptBlockModel) + .filter_by(script_revision_id=script_revision_id) + .filter_by(organization_id=organization_id) + .order_by(ScriptBlockModel.created_at.asc()) + ) + ).all() + return [convert_to_script_block(record) for record in records] + + @db_operation("create_workflow_script") + async def create_workflow_script( + self, + *, + organization_id: str, + script_id: str, + workflow_permanent_id: str, + cache_key: str, + cache_key_value: str, + workflow_id: str | None = None, + workflow_run_id: str | None = None, + status: ScriptStatus = ScriptStatus.published, + ) -> None: + """Create a workflow->script cache mapping entry.""" + async with self.Session() as session: + record = WorkflowScriptModel( + organization_id=organization_id, + script_id=script_id, + workflow_permanent_id=workflow_permanent_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + cache_key=cache_key, + cache_key_value=cache_key_value, + status=status, + ) + session.add(record) + await session.commit() + + @db_operation("get_workflow_script") + async def get_workflow_script( + self, + organization_id: str, + workflow_permanent_id: str, + workflow_run_id: str, + statuses: list[ScriptStatus] | None = None, + ) -> WorkflowScript | None: + async with self.Session() as session: + query = ( + select(WorkflowScriptModel) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(workflow_run_id=workflow_run_id) + ) + if statuses: + query = query.filter(WorkflowScriptModel.status.in_(statuses)) + workflow_script_model = (await session.scalars(query)).first() + return WorkflowScript.model_validate(workflow_script_model) if workflow_script_model else None + + @db_operation("get_workflow_script_by_cache_key_value") + async def get_workflow_script_by_cache_key_value( + self, + *, + organization_id: str, + workflow_permanent_id: str, + cache_key_value: str, + workflow_run_id: str | None = None, + cache_key: str | None = None, + statuses: list[ScriptStatus] | None = None, + ) -> Script | None: + """Get latest script version linked to a workflow by a specific cache_key_value.""" + async with self.Session() as session: + # Build the query: join workflow_scripts with scripts + # Join on both script_id and organization_id to leverage uc_org_script_version index + query = ( + select(ScriptModel) + .join( + WorkflowScriptModel, + and_( + ScriptModel.organization_id == WorkflowScriptModel.organization_id, + ScriptModel.script_id == WorkflowScriptModel.script_id, + ), + ) + .where( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScriptModel.cache_key_value == cache_key_value, + WorkflowScriptModel.deleted_at.is_(None), + ) + ) + + if workflow_run_id: + query = query.where(WorkflowScriptModel.workflow_run_id == workflow_run_id) + + if cache_key is not None: + query = query.where(WorkflowScriptModel.cache_key == cache_key) + + if statuses is not None and len(statuses) > 0: + query = 
query.where(WorkflowScriptModel.status.in_(statuses)) + + query = query.order_by(ScriptModel.created_at.desc(), ScriptModel.version.desc()).limit(1) + + script = (await session.scalars(query)).first() + return convert_to_script(script) if script else None + + @db_operation("get_workflow_cache_key_count") + async def get_workflow_cache_key_count( + self, + organization_id: str, + workflow_permanent_id: str, + cache_key: str, + filter: str | None = None, + ) -> int: + async with self.Session() as session: + query = ( + select(func.count()) + .select_from(WorkflowScriptModel) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(cache_key=cache_key) + .filter_by(deleted_at=None) + .filter_by(status="published") + ) + + if filter: + query = query.filter(WorkflowScriptModel.cache_key_value.contains(filter)) + + return (await session.execute(query)).scalar_one() + + @db_operation("get_workflow_cache_key_values") + async def get_workflow_cache_key_values( + self, + organization_id: str, + workflow_permanent_id: str, + cache_key: str, + page: int = 1, + page_size: int = 100, + filter: str | None = None, + ) -> list[str]: + async with self.Session() as session: + query = ( + select(WorkflowScriptModel.cache_key_value) + .order_by(WorkflowScriptModel.cache_key_value.asc()) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(cache_key=cache_key) + .filter_by(deleted_at=None) + .filter_by(status="published") + .offset((page - 1) * page_size) + .limit(page_size) + ) + + if filter: + query = query.filter(WorkflowScriptModel.cache_key_value.contains(filter)) + + return (await session.scalars(query)).all() + + @db_operation("delete_workflow_cache_key_value") + async def delete_workflow_cache_key_value( + self, + organization_id: str, + workflow_permanent_id: str, + cache_key_value: str, + ) -> bool: + """ + Soft delete workflow cache key values by setting deleted_at timestamp. + + Returns True if any records were deleted, False otherwise. + """ + async with self.Session() as session: + stmt = ( + update(WorkflowScriptModel) + .where( + and_( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScriptModel.cache_key_value == cache_key_value, + WorkflowScriptModel.deleted_at.is_(None), + ) + ) + .values(deleted_at=datetime.now(timezone.utc)) + ) + + result = await session.execute(stmt) + await session.commit() + + return result.rowcount > 0 + + @db_operation("delete_workflow_scripts_by_permanent_id") + async def delete_workflow_scripts_by_permanent_id( + self, + organization_id: str, + workflow_permanent_id: str, + statuses: list[ScriptStatus] | None = None, + script_ids: list[str] | None = None, + ) -> int: + """ + Soft delete workflow scripts for a workflow permanent id by setting the deleted_at timestamp, optionally narrowed by statuses and script ids. + + Returns the number of records soft-deleted. 
+ """ + async with self.Session() as session: + stmt = ( + update(WorkflowScriptModel) + .where( + and_( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScriptModel.deleted_at.is_(None), + ) + ) + .values(deleted_at=datetime.now(timezone.utc)) + ) + + if statuses: + stmt = stmt.where(WorkflowScriptModel.status.in_([s.value for s in statuses])) + + if script_ids: + stmt = stmt.where(WorkflowScriptModel.script_id.in_(script_ids)) + + result = await session.execute(stmt) + await session.commit() + + return result.rowcount + + @db_operation("get_workflow_scripts_by_permanent_id") + async def get_workflow_scripts_by_permanent_id( + self, + organization_id: str, + workflow_permanent_id: str, + statuses: list[ScriptStatus] | None = None, + ) -> list[WorkflowScriptModel]: + async with self.Session() as session: + query = ( + select(WorkflowScriptModel) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(deleted_at=None) + ) + + if statuses: + query = query.filter(WorkflowScriptModel.status.in_([s.value for s in statuses])) + + query = query.order_by(WorkflowScriptModel.modified_at.desc()) + return (await session.scalars(query)).all() + + # -- Script Run / Stats ------------------------------------------------ + + @db_operation("get_workflow_runs_for_script") + async def get_workflow_runs_for_script( + self, + organization_id: str, + script_id: str, + page_size: int = 50, + created_after: datetime | None = None, + created_before: datetime | None = None, + ) -> tuple[list[WorkflowRunModel], int, dict[str, int]]: + """Get workflow runs associated with a script, with total count and status counts. + + Returns (runs, total_count, status_counts) where runs is limited by page_size, + total_count is derived from the status_counts GROUP BY, and status_counts is a + GROUP BY aggregation of statuses across all runs. + + If created_after/created_before are provided, filters by the workflow_script + entry's created_at (not the run's created_at), scoping to the version that + was active in that time window. + """ + async with self.Session() as session: + # Subquery: distinct run IDs for this script + run_ids_subquery = ( + select(distinct(WorkflowScriptModel.workflow_run_id)) + .filter_by(organization_id=organization_id, script_id=script_id) + .filter(WorkflowScriptModel.deleted_at.is_(None)) + .filter(WorkflowScriptModel.workflow_run_id.isnot(None)) + ) + + # Time-window filters scope by workflow_script creation time, + # which aligns with the script version that was created/used. 
+ if created_after is not None: + run_ids_subquery = run_ids_subquery.filter( + WorkflowScriptModel.created_at >= created_after, + ) + if created_before is not None: + run_ids_subquery = run_ids_subquery.filter( + WorkflowScriptModel.created_at < created_before, + ) + + # Base filter for workflow runs + base_filters = [ + WorkflowRunModel.workflow_run_id.in_(run_ids_subquery), + WorkflowRunModel.organization_id == organization_id, + ] + + # Count statuses via GROUP BY (also gives us total_count) + status_query = ( + select(WorkflowRunModel.status, func.count()).filter(*base_filters).group_by(WorkflowRunModel.status) + ) + status_counts = {(s or "unknown"): c for s, c in (await session.execute(status_query)).all()} + total_count = sum(status_counts.values()) + + if total_count == 0: + return [], 0, {} + + # Get the actual workflow runs (paginated) + runs_query = ( + select(WorkflowRunModel) + .filter(*base_filters) + .order_by(WorkflowRunModel.created_at.desc()) + .limit(page_size) + ) + runs = list((await session.scalars(runs_query)).all()) + + return runs, total_count, status_counts + + @db_operation("get_script_run_stats") + async def get_script_run_stats( + self, + organization_id: str, + script_ids: list[str], + ) -> dict[str, tuple[float | None, int]]: + """Get success rate and total run count for each script_id. + + Both metrics are computed from the same population (workflow_scripts joined + to workflow_runs), so they are always consistent. + + Returns a dict mapping script_id -> (success_rate, total_runs) where + success_rate is 0.0-1.0 or None if no runs. + """ + if not script_ids: + return {} + async with self.Session() as session: + # Join workflow_scripts -> workflow_runs, group by script_id and status + query = ( + select( + WorkflowScriptModel.script_id, + WorkflowRunModel.status, + func.count(distinct(WorkflowRunModel.workflow_run_id)), + ) + .join( + WorkflowRunModel, + WorkflowScriptModel.workflow_run_id == WorkflowRunModel.workflow_run_id, + ) + .filter( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.script_id.in_(script_ids), + WorkflowScriptModel.deleted_at.is_(None), + WorkflowScriptModel.workflow_run_id.isnot(None), + WorkflowRunModel.organization_id == organization_id, + ) + .group_by(WorkflowScriptModel.script_id, WorkflowRunModel.status) + ) + rows = (await session.execute(query)).all() + + # Aggregate per script_id + totals: dict[str, int] = {} + completed: dict[str, int] = {} + for sid, status, count in rows: + totals[sid] = totals.get(sid, 0) + count + if status == "completed": + completed[sid] = completed.get(sid, 0) + count + + return { + sid: ( + (completed.get(sid, 0) / totals[sid]) if totals.get(sid) else None, + totals.get(sid, 0), + ) + for sid in script_ids + } + + # -- Script Pinning ---------------------------------------------------- + + @db_operation("is_script_pinned") + async def is_script_pinned( + self, + organization_id: str, + script_id: str, + ) -> bool: + """Check if any active workflow_script row for this script_id is pinned.""" + async with self.Session() as session: + query = ( + select(WorkflowScriptModel.is_pinned) + .where( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.script_id == script_id, + WorkflowScriptModel.is_pinned.is_(True), + WorkflowScriptModel.deleted_at.is_(None), + ) + .limit(1) + ) + result = await session.scalars(query) + return result.first() is not None + + @db_operation("pin_workflow_script") + async def pin_workflow_script( + self, + organization_id: 
str, + workflow_permanent_id: str, + cache_key_value: str, + pinned_by: str | None = None, + ) -> WorkflowScriptModel | None: + """Pin all workflow scripts for a given cache key value.""" + async with self.Session() as session: + stmt = ( + update(WorkflowScriptModel) + .where( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScriptModel.cache_key_value == cache_key_value, + WorkflowScriptModel.deleted_at.is_(None), + ) + .values( + is_pinned=True, + pinned_at=datetime.now(timezone.utc), + pinned_by=pinned_by, + ) + ) + await session.execute(stmt) + await session.commit() + + # Return the first updated model for the response + query = ( + select(WorkflowScriptModel) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(cache_key_value=cache_key_value) + .filter_by(deleted_at=None) + .limit(1) + ) + result = await session.scalars(query) + return result.first() + + @db_operation("unpin_workflow_script") + async def unpin_workflow_script( + self, + organization_id: str, + workflow_permanent_id: str, + cache_key_value: str, + ) -> WorkflowScriptModel | None: + """Unpin workflow scripts for a given cache key value.""" + async with self.Session() as session: + stmt = ( + update(WorkflowScriptModel) + .where( + WorkflowScriptModel.organization_id == organization_id, + WorkflowScriptModel.workflow_permanent_id == workflow_permanent_id, + WorkflowScriptModel.cache_key_value == cache_key_value, + WorkflowScriptModel.deleted_at.is_(None), + ) + .values( + is_pinned=False, + pinned_at=None, + pinned_by=None, + ) + ) + await session.execute(stmt) + await session.commit() + + # Return the first updated model for the response + query = ( + select(WorkflowScriptModel) + .filter_by(organization_id=organization_id) + .filter_by(workflow_permanent_id=workflow_permanent_id) + .filter_by(cache_key_value=cache_key_value) + .filter_by(deleted_at=None) + .limit(1) + ) + result = await session.scalars(query) + return result.first() + + # -- Script Fallback Episode CRUD -------------------------------------- + + @db_operation("create_fallback_episode") + async def create_fallback_episode( + self, + organization_id: str, + workflow_permanent_id: str, + workflow_run_id: str, + block_label: str, + fallback_type: str, + script_revision_id: str | None = None, + error_message: str | None = None, + classify_result: str | None = None, + agent_actions: list | dict | None = None, + page_url: str | None = None, + page_text_snapshot: str | None = None, + ) -> ScriptFallbackEpisode: + async with self.Session() as session: + episode = ScriptFallbackEpisodeModel( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + workflow_run_id=workflow_run_id, + block_label=block_label, + fallback_type=fallback_type, + script_revision_id=script_revision_id, + error_message=sanitize_postgres_text(error_message) if error_message else None, + classify_result=sanitize_postgres_text(classify_result) if classify_result else None, + agent_actions=agent_actions, + page_url=sanitize_postgres_text(page_url) if page_url else None, + page_text_snapshot=sanitize_postgres_text(page_text_snapshot) if page_text_snapshot else None, + ) + session.add(episode) + await session.commit() + await session.refresh(episode) + return ScriptFallbackEpisode.model_validate(episode) + + @db_operation("get_unreviewed_episodes") + async def get_unreviewed_episodes( + self, + 
workflow_permanent_id: str, + organization_id: str, + limit: int = 100, + script_revision_id: str | None = None, + ) -> list[ScriptFallbackEpisode]: + async with self.Session() as session: + query = ( + select(ScriptFallbackEpisodeModel) + .filter_by( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + reviewed=False, + ) + .order_by(ScriptFallbackEpisodeModel.created_at.asc()) + .limit(limit) + ) + if script_revision_id: + query = query.filter_by(script_revision_id=script_revision_id) + episodes = (await session.scalars(query)).all() + return [ScriptFallbackEpisode.model_validate(e) for e in episodes] + + @db_operation("update_fallback_episode") + async def update_fallback_episode( + self, + episode_id: str, + organization_id: str, + agent_actions: list | dict | None = None, + fallback_succeeded: bool | None = None, + ) -> None: + values: dict = {} + if agent_actions is not None: + values["agent_actions"] = agent_actions + if fallback_succeeded is not None: + values["fallback_succeeded"] = fallback_succeeded + if not values: + return + values["modified_at"] = datetime.now(timezone.utc) + async with self.Session() as session: + await session.execute( + update(ScriptFallbackEpisodeModel) + .where(ScriptFallbackEpisodeModel.episode_id == episode_id) + .where(ScriptFallbackEpisodeModel.organization_id == organization_id) + .values(**values) + ) + await session.commit() + + @db_operation("delete_fallback_episode") + async def delete_fallback_episode( + self, + episode_id: str, + organization_id: str, + ) -> None: + async with self.Session() as session: + await session.execute( + delete(ScriptFallbackEpisodeModel) + .where(ScriptFallbackEpisodeModel.episode_id == episode_id) + .where(ScriptFallbackEpisodeModel.organization_id == organization_id) + ) + await session.commit() + + @db_operation("get_fallback_episodes") + async def get_fallback_episodes( + self, + organization_id: str, + workflow_permanent_id: str, + page: int = 1, + page_size: int = 20, + workflow_run_id: str | None = None, + block_label: str | None = None, + reviewed: bool | None = None, + fallback_type: str | None = None, + ) -> list[ScriptFallbackEpisode]: + async with self.Session() as session: + query = select(ScriptFallbackEpisodeModel).filter( + ScriptFallbackEpisodeModel.organization_id == organization_id, + ScriptFallbackEpisodeModel.workflow_permanent_id == workflow_permanent_id, + ) + if workflow_run_id is not None: + query = query.filter(ScriptFallbackEpisodeModel.workflow_run_id == workflow_run_id) + if block_label is not None: + query = query.filter(ScriptFallbackEpisodeModel.block_label == block_label) + if reviewed is not None: + query = query.filter(ScriptFallbackEpisodeModel.reviewed == reviewed) + if fallback_type is not None: + query = query.filter(ScriptFallbackEpisodeModel.fallback_type == fallback_type) + + offset = (page - 1) * page_size + query = query.order_by(ScriptFallbackEpisodeModel.created_at.desc()).limit(page_size).offset(offset) + + result = await session.scalars(query) + return [ScriptFallbackEpisode.model_validate(row) for row in result.all()] + + @db_operation("get_fallback_episodes_count") + async def get_fallback_episodes_count( + self, + organization_id: str, + workflow_permanent_id: str | None = None, + workflow_run_id: str | None = None, + block_label: str | None = None, + reviewed: bool | None = None, + fallback_type: str | None = None, + script_revision_id: str | None = None, + ) -> int: + """Count fallback episodes matching the given filters. 
+ + At least one scoping filter (workflow_permanent_id, workflow_run_id, + or script_revision_id) should be provided. Without any, this returns + the total count for the entire organization which is rarely intended. + """ + if workflow_permanent_id is None and workflow_run_id is None and script_revision_id is None: + LOG.warning( + "get_fallback_episodes_count called without any scoping filter", + organization_id=organization_id, + ) + async with self.Session() as session: + query = ( + select(func.count()) + .select_from(ScriptFallbackEpisodeModel) + .filter( + ScriptFallbackEpisodeModel.organization_id == organization_id, + ) + ) + if workflow_permanent_id is not None: + query = query.filter(ScriptFallbackEpisodeModel.workflow_permanent_id == workflow_permanent_id) + if workflow_run_id is not None: + query = query.filter(ScriptFallbackEpisodeModel.workflow_run_id == workflow_run_id) + if block_label is not None: + query = query.filter(ScriptFallbackEpisodeModel.block_label == block_label) + if reviewed is not None: + query = query.filter(ScriptFallbackEpisodeModel.reviewed == reviewed) + if fallback_type is not None: + query = query.filter(ScriptFallbackEpisodeModel.fallback_type == fallback_type) + if script_revision_id is not None: + query = query.filter(ScriptFallbackEpisodeModel.script_revision_id == script_revision_id) + + result = await session.scalar(query) + return result or 0 + + @db_operation("get_fallback_episode") + async def get_fallback_episode( + self, + episode_id: str, + organization_id: str, + ) -> ScriptFallbackEpisode | None: + async with self.Session() as session: + query = select(ScriptFallbackEpisodeModel).filter( + ScriptFallbackEpisodeModel.episode_id == episode_id, + ScriptFallbackEpisodeModel.organization_id == organization_id, + ) + result = await session.scalar(query) + if result: + return ScriptFallbackEpisode.model_validate(result) + return None + + @db_operation("mark_episode_reviewed") + async def mark_episode_reviewed( + self, + episode_id: str, + organization_id: str, + reviewer_output: str | None = None, + new_script_revision_id: str | None = None, + ) -> None: + async with self.Session() as session: + await session.execute( + update(ScriptFallbackEpisodeModel) + .where(ScriptFallbackEpisodeModel.episode_id == episode_id) + .where(ScriptFallbackEpisodeModel.organization_id == organization_id) + .values( + reviewed=True, + reviewer_output=sanitize_postgres_text(reviewer_output) if reviewer_output else None, + new_script_revision_id=new_script_revision_id, + modified_at=datetime.now(timezone.utc), + ) + ) + await session.commit() + + @db_operation("get_recent_reviewed_episodes") + async def get_recent_reviewed_episodes( + self, + workflow_permanent_id: str, + organization_id: str, + limit: int = 20, + ) -> list[ScriptFallbackEpisode]: + """Return recently reviewed episodes for cross-run historical context. + + These give the reviewer visibility into past failures and fixes so it can + avoid repeating the same mistakes. 
+ """ + async with self.Session() as session: + query = ( + select(ScriptFallbackEpisodeModel) + .filter_by( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + reviewed=True, + ) + .order_by(ScriptFallbackEpisodeModel.created_at.desc()) + .limit(limit) + ) + episodes = (await session.scalars(query)).all() + return [ScriptFallbackEpisode.model_validate(e) for e in episodes] + + @db_operation("record_branch_hit") + async def record_branch_hit( + self, + organization_id: str, + workflow_permanent_id: str, + block_label: str, + branch_key: str, + ) -> None: + """Record a classify branch hit, upserting the hit count and last_hit_at.""" + now = datetime.now(timezone.utc) + async with self.Session() as session: + stmt = insert(ScriptBranchHitModel).values( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + block_label=block_label, + branch_key=branch_key, + hit_count=1, + first_hit_at=now, + last_hit_at=now, + ) + stmt = stmt.on_conflict_do_update( + index_elements=[ + "organization_id", + "workflow_permanent_id", + "block_label", + "branch_key", + ], + set_={ + "hit_count": ScriptBranchHitModel.hit_count + 1, + "last_hit_at": now, + }, + ) + await session.execute(stmt) + await session.commit() + + @db_operation("get_stale_branches") + async def get_stale_branches( + self, + organization_id: str, + workflow_permanent_id: str, + stale_days: int = 90, + limit: int = 200, + ) -> list[ScriptBranchHit]: + """Get branches that haven't been accessed in stale_days days.""" + cutoff = datetime.now(timezone.utc) - timedelta(days=stale_days) + async with self.Session() as session: + query = ( + select(ScriptBranchHitModel) + .filter_by( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + ) + .filter(ScriptBranchHitModel.last_hit_at < cutoff) + .order_by(ScriptBranchHitModel.last_hit_at.asc()) + .limit(limit) + ) + results = (await session.scalars(query)).all() + return [ScriptBranchHit.model_validate(r) for r in results] diff --git a/skyvern/forge/sdk/db/repositories/tasks.py b/skyvern/forge/sdk/db/repositories/tasks.py new file mode 100644 index 000000000..4493bcc42 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/tasks.py @@ -0,0 +1,727 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Any, Sequence + +import structlog +from sqlalchemy import and_, delete, distinct, func, select, tuple_, update + +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_alchemy_db import read_retry +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.enums import TaskType +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import ( + ActionModel, + StepModel, + TaskModel, + WorkflowRunModel, +) +from skyvern.forge.sdk.db.utils import convert_to_step, convert_to_task, hydrate_action, serialize_proxy_location +from skyvern.forge.sdk.models import Step, StepStatus +from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus +from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text +from skyvern.schemas.runs import ProxyLocationInput +from skyvern.schemas.steps import AgentStepOutput +from skyvern.webeye.actions.actions import Action + +LOG = structlog.get_logger() + + +class TasksRepository(BaseRepository): + @db_operation("create_task") + async def create_task( + self, + url: str, + title: str | None, + 
navigation_goal: str | None, + data_extraction_goal: str | None, + navigation_payload: dict[str, Any] | list | str | None, + status: str = "created", + complete_criterion: str | None = None, + terminate_criterion: str | None = None, + webhook_callback_url: str | None = None, + totp_verification_url: str | None = None, + totp_identifier: str | None = None, + organization_id: str | None = None, + proxy_location: ProxyLocationInput = None, + extracted_information_schema: dict[str, Any] | list | str | None = None, + workflow_run_id: str | None = None, + order: int | None = None, + retry: int | None = None, + max_steps_per_run: int | None = None, + error_code_mapping: dict[str, str] | None = None, + task_type: str = TaskType.general, + application: str | None = None, + include_action_history_in_verification: bool | None = None, + model: dict[str, Any] | None = None, + max_screenshot_scrolling_times: int | None = None, + extra_http_headers: dict[str, str] | None = None, + browser_session_id: str | None = None, + browser_address: str | None = None, + download_timeout: float | None = None, + ) -> Task: + # Sanitize text fields to remove NUL bytes and control characters + # that PostgreSQL cannot store in text columns + def _sanitize(v: str | None) -> str | None: + return sanitize_postgres_text(v) if isinstance(v, str) else v + + navigation_goal = _sanitize(navigation_goal) + data_extraction_goal = _sanitize(data_extraction_goal) + title = _sanitize(title) + url = sanitize_postgres_text(url) + complete_criterion = _sanitize(complete_criterion) + terminate_criterion = _sanitize(terminate_criterion) + + async with self.Session() as session: + new_task = TaskModel( + status=status, + task_type=task_type, + url=url, + title=title, + webhook_callback_url=webhook_callback_url, + totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, + navigation_goal=navigation_goal, + complete_criterion=complete_criterion, + terminate_criterion=terminate_criterion, + data_extraction_goal=data_extraction_goal, + navigation_payload=navigation_payload, + organization_id=organization_id, + proxy_location=serialize_proxy_location(proxy_location), + extracted_information_schema=extracted_information_schema, + workflow_run_id=workflow_run_id, + order=order, + retry=retry, + max_steps_per_run=max_steps_per_run, + error_code_mapping=error_code_mapping, + application=application, + include_action_history_in_verification=include_action_history_in_verification, + model=model, + max_screenshot_scrolling_times=max_screenshot_scrolling_times, + extra_http_headers=extra_http_headers, + browser_session_id=browser_session_id, + browser_address=browser_address, + download_timeout=download_timeout, + ) + session.add(new_task) + await session.commit() + await session.refresh(new_task) + return convert_to_task(new_task, self.debug_enabled) + + @db_operation("create_step") + async def create_step( + self, + task_id: str, + order: int, + retry_index: int, + organization_id: str | None = None, + status: StepStatus = StepStatus.created, + created_by: str | None = None, + ) -> Step: + async with self.Session() as session: + new_step = StepModel( + task_id=task_id, + order=order, + retry_index=retry_index, + status=status, + organization_id=organization_id, + created_by=created_by, + ) + session.add(new_step) + await session.commit() + await session.refresh(new_step) + return convert_to_step(new_step, debug_enabled=self.debug_enabled) + + @read_retry() + @db_operation("get_task", log_errors=False) + async def 
get_task(self, task_id: str, organization_id: str | None = None) -> Task | None: + """Get a task by its id""" + async with self.Session() as session: + query = select(TaskModel).filter_by(task_id=task_id) + if organization_id is not None: + query = query.filter_by(organization_id=organization_id) + if task_obj := (await session.scalars(query)).first(): + return convert_to_task(task_obj, self.debug_enabled) + else: + LOG.info( + "Task not found", + task_id=task_id, + organization_id=organization_id, + ) + return None + + @db_operation("get_tasks_by_ids") + async def get_tasks_by_ids( + self, + task_ids: list[str], + organization_id: str, + ) -> list[Task]: + async with self.Session() as session: + tasks = ( + await session.scalars( + select(TaskModel).filter(TaskModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id) + ) + ).all() + return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks] + + @db_operation("get_step") + async def get_step(self, step_id: str, organization_id: str | None = None) -> Step | None: + async with self.Session() as session: + if step := ( + await session.scalars( + select(StepModel).filter_by(step_id=step_id).filter_by(organization_id=organization_id) + ) + ).first(): + return convert_to_step(step, debug_enabled=self.debug_enabled) + + else: + return None + + @db_operation("get_task_steps") + async def get_task_steps(self, task_id: str, organization_id: str) -> list[Step]: + async with self.Session() as session: + if steps := ( + await session.scalars( + select(StepModel) + .filter_by(task_id=task_id) + .filter_by(organization_id=organization_id) + .order_by(StepModel.order) + .order_by(StepModel.retry_index) + ) + ).all(): + return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps] + else: + return [] + + @db_operation("get_steps_by_task_ids") + async def get_steps_by_task_ids(self, task_ids: list[str], organization_id: str | None = None) -> list[Step]: + async with self.Session() as session: + steps = ( + await session.scalars( + select(StepModel).filter(StepModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id) + ) + ).all() + return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps] + + @db_operation("get_total_unique_step_order_count_by_task_ids") + async def get_total_unique_step_order_count_by_task_ids( + self, + *, + task_ids: list[str], + organization_id: str, + ) -> int: + """ + Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids. + Equivalent SQL: SELECT count(DISTINCT (s.task_id, s.order)) FROM steps s WHERE s.task_id IN (:task_ids) + """ + async with self.Session() as session: + query = ( + select(func.count(distinct(tuple_(StepModel.task_id, StepModel.order)))) + .where(StepModel.task_id.in_(task_ids)) + .where(StepModel.organization_id == organization_id) + ) + return (await session.execute(query)).scalar() + + @db_operation("get_task_step_models") + async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]: + async with self.Session() as session: + return ( + await session.scalars( + select(StepModel) + .filter_by(task_id=task_id) + .filter_by(organization_id=organization_id) + .order_by(StepModel.order) + .order_by(StepModel.retry_index) + ) + ).all() + + @db_operation("get_task_step_count") + async def get_task_step_count(self, task_id: str, organization_id: str | None = None) -> int: + async with 
self.Session() as session: + result = await session.scalar( + select(func.count(StepModel.step_id)) + .filter_by(task_id=task_id) + .filter_by(organization_id=organization_id) + ) + return result or 0 + + @db_operation("get_task_actions") + async def get_task_actions(self, task_id: str, organization_id: str | None = None) -> list[Action]: + async with self.Session() as session: + query = ( + select(ActionModel) + .filter(ActionModel.organization_id == organization_id) + .filter(ActionModel.task_id == task_id) + .order_by(ActionModel.created_at) + ) + + actions = (await session.scalars(query)).all() + return [Action.model_validate(action) for action in actions] + + @db_operation("get_task_actions_hydrated") + async def get_task_actions_hydrated(self, task_id: str, organization_id: str | None = None) -> list[Action]: + async with self.Session() as session: + query = ( + select(ActionModel) + .filter(ActionModel.organization_id == organization_id) + .filter(ActionModel.task_id == task_id) + .order_by(ActionModel.created_at) + ) + + actions = (await session.scalars(query)).all() + return [hydrate_action(action) for action in actions] + + @db_operation("get_tasks_actions") + async def get_tasks_actions(self, task_ids: list[str], organization_id: str | None = None) -> list[Action]: + async with self.Session() as session: + query = ( + select(ActionModel) + .filter(ActionModel.organization_id == organization_id) + .filter(ActionModel.task_id.in_(task_ids)) + .order_by(ActionModel.created_at.desc()) + ) + actions = (await session.scalars(query)).all() + return [hydrate_action(action) for action in actions] + + @db_operation("get_action_count_for_step") + async def get_action_count_for_step(self, step_id: str, task_id: str, organization_id: str) -> int: + """Get count of actions for a step. 
Uses composite index for efficiency.""" + async with self.Session() as session: + query = ( + select(func.count()) + .select_from(ActionModel) + .where(ActionModel.organization_id == organization_id) + .where(ActionModel.task_id == task_id) + .where(ActionModel.step_id == step_id) + ) + result = await session.scalar(query) + return result or 0 + + @db_operation("get_first_step") + async def get_first_step(self, task_id: str, organization_id: str | None = None) -> Step | None: + async with self.Session() as session: + if step := ( + await session.scalars( + select(StepModel) + .filter_by(task_id=task_id) + .filter_by(organization_id=organization_id) + .order_by(StepModel.order.asc()) + .order_by(StepModel.retry_index.asc()) + ) + ).first(): + return convert_to_step(step, debug_enabled=self.debug_enabled) + else: + LOG.info( + "First step not found", + task_id=task_id, + organization_id=organization_id, + ) + return None + + @db_operation("get_latest_step") + async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None: + async with self.Session() as session: + if step := ( + await session.scalars( + select(StepModel) + .filter_by(task_id=task_id) + .filter_by(organization_id=organization_id) + .filter(StepModel.status != StepStatus.canceled) + .order_by(StepModel.order.desc()) + .order_by(StepModel.retry_index.desc()) + ) + ).first(): + return convert_to_step(step, debug_enabled=self.debug_enabled) + else: + LOG.info( + "Latest step not found", + task_id=task_id, + organization_id=organization_id, + ) + return None + + @db_operation("update_step") + async def update_step( + self, + task_id: str, + step_id: str, + status: StepStatus | None = None, + output: AgentStepOutput | None = None, + is_last: bool | None = None, + retry_index: int | None = None, + organization_id: str | None = None, + incremental_cost: float | None = None, + incremental_input_tokens: int | None = None, + incremental_output_tokens: int | None = None, + incremental_reasoning_tokens: int | None = None, + incremental_cached_tokens: int | None = None, + created_by: str | None = None, + ) -> Step: + async with self.Session() as session: + if step := ( + await session.scalars( + select(StepModel) + .filter_by(task_id=task_id) + .filter_by(step_id=step_id) + .filter_by(organization_id=organization_id) + ) + ).first(): + if status is not None: + step.status = status + + if status.is_terminal() and step.finished_at is None: + step.finished_at = datetime.now(timezone.utc) + if output is not None: + step.output = output.model_dump(exclude_none=True) + if is_last is not None: + step.is_last = is_last + if retry_index is not None: + step.retry_index = retry_index + if incremental_cost is not None: + step.step_cost = incremental_cost + float(step.step_cost or 0) + if incremental_input_tokens is not None: + step.input_token_count = incremental_input_tokens + (step.input_token_count or 0) + if incremental_output_tokens is not None: + step.output_token_count = incremental_output_tokens + (step.output_token_count or 0) + if incremental_reasoning_tokens is not None: + step.reasoning_token_count = incremental_reasoning_tokens + (step.reasoning_token_count or 0) + if incremental_cached_tokens is not None: + step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0) + if created_by is not None: + step.created_by = created_by + + await session.commit() + updated_step = await self.get_step(step_id, organization_id) + if not updated_step: + raise NotFoundError("Step not found") + 
return updated_step + else: + raise NotFoundError("Step not found") + + @db_operation("clear_task_failure_reason") + async def clear_task_failure_reason(self, organization_id: str, task_id: str) -> Task: + async with self.Session() as session: + if task := ( + await session.scalars( + select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) + ) + ).first(): + task.failure_reason = None + await session.commit() + await session.refresh(task) + return convert_to_task(task, debug_enabled=self.debug_enabled) + else: + raise NotFoundError("Task not found") + + @db_operation("update_task") + async def update_task( + self, + task_id: str, + status: TaskStatus | None = None, + extracted_information: dict[str, Any] | list | str | None = None, + webhook_failure_reason: str | None = None, + failure_reason: str | None = None, + errors: list[dict[str, Any]] | None = None, + max_steps_per_run: int | None = None, + organization_id: str | None = None, + ) -> Task: + if ( + status is None + and extracted_information is None + and failure_reason is None + and errors is None + and max_steps_per_run is None + and webhook_failure_reason is None + ): + raise ValueError( + "At least one of status, extracted_information, failure_reason, errors, max_steps_per_run, or webhook_failure_reason must be provided to update the task" + ) + async with self.Session() as session: + if task := ( + await session.scalars( + select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) + ) + ).first(): + if status is not None: + task.status = status + if status == TaskStatus.queued and task.queued_at is None: + task.queued_at = datetime.now(timezone.utc) + if status == TaskStatus.running and task.started_at is None: + task.started_at = datetime.now(timezone.utc) + if status.is_final() and task.finished_at is None: + task.finished_at = datetime.now(timezone.utc) + if extracted_information is not None: + task.extracted_information = extracted_information + if failure_reason is not None: + task.failure_reason = failure_reason + if errors is not None: + task.errors = (task.errors or []) + errors + if max_steps_per_run is not None: + task.max_steps_per_run = max_steps_per_run + if webhook_failure_reason is not None: + task.webhook_failure_reason = webhook_failure_reason + await session.commit() + updated_task = await self.get_task(task_id, organization_id=organization_id) + if not updated_task: + raise NotFoundError("Task not found") + return updated_task + else: + raise NotFoundError("Task not found") + + @db_operation("update_task_2fa_state") + async def update_task_2fa_state( + self, + task_id: str, + organization_id: str, + waiting_for_verification_code: bool, + verification_code_identifier: str | None = None, + verification_code_polling_started_at: datetime | None = None, + ) -> Task: + """Update task 2FA verification code waiting state.""" + async with self.Session() as session: + if task := ( + await session.scalars( + select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) + ) + ).first(): + task.waiting_for_verification_code = waiting_for_verification_code + if verification_code_identifier is not None: + task.verification_code_identifier = verification_code_identifier + if verification_code_polling_started_at is not None: + task.verification_code_polling_started_at = verification_code_polling_started_at + if not waiting_for_verification_code: + # Clear identifiers when no longer waiting + task.verification_code_identifier = None + task.verification_code_polling_started_at = None + await 
session.commit() + updated_task = await self.get_task(task_id, organization_id=organization_id) + if not updated_task: + raise NotFoundError("Task not found") + return updated_task + else: + raise NotFoundError("Task not found") + + @db_operation("bulk_update_tasks") + async def bulk_update_tasks( + self, + task_ids: list[str], + status: TaskStatus | None = None, + failure_reason: str | None = None, + ) -> None: + """Bulk update tasks by their IDs. + + Args: + task_ids: List of task IDs to update + status: Optional status to set for all tasks + failure_reason: Optional failure reason to set for all tasks + """ + if not task_ids: + return + + async with self.Session() as session: + update_values = {} + if status: + update_values["status"] = status.value + if failure_reason: + update_values["failure_reason"] = failure_reason + + if update_values: + update_stmt = update(TaskModel).where(TaskModel.task_id.in_(task_ids)).values(**update_values) + await session.execute(update_stmt) + await session.commit() + + @db_operation("get_tasks") + async def get_tasks( + self, + page: int = 1, + page_size: int = 10, + task_status: list[TaskStatus] | None = None, + workflow_run_id: str | None = None, + organization_id: str | None = None, + only_standalone_tasks: bool = False, + application: str | None = None, + order_by_column: OrderBy = OrderBy.created_at, + order: SortDirection = SortDirection.desc, + ) -> list[Task]: + """ + Get all tasks. + :param page: Starts at 1 + :param page_size: Number of tasks per page + :param task_status: Optional list of statuses to filter by + :param workflow_run_id: Optional workflow run id to filter by + :param organization_id: Organization to scope the query to + :param only_standalone_tasks: If True, only return tasks not attached to a workflow run + :param application: Optional application to filter by + :param order_by_column: Column to sort by + :param order: Sort direction + :return: The requested page of tasks + """ + if page < 1: + raise ValueError(f"Page must be greater than 0, got {page}") + + async with self.Session() as session: + db_page = page - 1 # offset logic is 0 based + query = ( + select(TaskModel, WorkflowRunModel.workflow_permanent_id) + .join(WorkflowRunModel, TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id, isouter=True) + .filter(TaskModel.organization_id == organization_id) + ) + if task_status: + query = query.filter(TaskModel.status.in_(task_status)) + if workflow_run_id: + query = query.filter(TaskModel.workflow_run_id == workflow_run_id) + if only_standalone_tasks: + query = query.filter(TaskModel.workflow_run_id.is_(None)) + if application: + query = query.filter(TaskModel.application == application) + order_by_col = getattr(TaskModel, order_by_column) + query = ( + query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc()) + .limit(page_size) + .offset(db_page * page_size) + ) + + results = (await session.execute(query)).all() + + return [ + convert_to_task(task, debug_enabled=self.debug_enabled, workflow_permanent_id=workflow_permanent_id) + for task, workflow_permanent_id in results + ] + + @db_operation("get_tasks_count") + async def get_tasks_count( + self, + organization_id: str, + task_status: list[TaskStatus] | None = None, + workflow_run_id: str | None = None, + only_standalone_tasks: bool = False, + application: str | None = None, + ) -> int: + async with self.Session() as session: + count_query = ( + select(func.count()).select_from(TaskModel).filter(TaskModel.organization_id == organization_id) + ) + if task_status: + count_query = count_query.filter(TaskModel.status.in_(task_status)) + if workflow_run_id: + count_query = count_query.filter(TaskModel.workflow_run_id == workflow_run_id) + if only_standalone_tasks: + count_query = count_query.filter(TaskModel.workflow_run_id.is_(None)) + if application: + count_query = 
count_query.filter(TaskModel.application == application) + return (await session.execute(count_query)).scalar_one() + + @db_operation("get_running_tasks_info_globally") + async def get_running_tasks_info_globally( + self, + stale_threshold_hours: int = 24, + ) -> tuple[int, int]: + """ + Get information about running tasks across all organizations. + Used by cleanup service to determine if cleanup should be skipped. + + Args: + stale_threshold_hours: Tasks not updated for this many hours are considered stale. + + Returns: + Tuple of (active_task_count, stale_task_count). + Active tasks are those updated within the threshold. + Stale tasks are those not updated within the threshold but still in running status. + """ + async with self.Session() as session: + running_statuses = [TaskStatus.created, TaskStatus.queued, TaskStatus.running] + stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=stale_threshold_hours) + + # Count active tasks (recently updated) + active_query = ( + select(func.count()) + .select_from(TaskModel) + .filter(TaskModel.status.in_(running_statuses)) + .filter(TaskModel.modified_at >= stale_cutoff) + ) + active_count = (await session.execute(active_query)).scalar_one() + + # Count stale tasks (not updated for a long time) + stale_query = ( + select(func.count()) + .select_from(TaskModel) + .filter(TaskModel.status.in_(running_statuses)) + .filter(TaskModel.modified_at < stale_cutoff) + ) + stale_count = (await session.execute(stale_query)).scalar_one() + + return (active_count, stale_count) + + @db_operation("get_latest_task_by_workflow_id") + async def get_latest_task_by_workflow_id( + self, + organization_id: str, + workflow_id: str, + before: datetime | None = None, + ) -> Task | None: + async with self.Session() as session: + query = select(TaskModel).filter_by(organization_id=organization_id).filter_by(workflow_id=workflow_id) + if before: + query = query.filter(TaskModel.created_at < before) + task = (await session.scalars(query.order_by(TaskModel.created_at.desc()))).first() + if task: + return convert_to_task(task, debug_enabled=self.debug_enabled) + return None + + @db_operation("get_last_task_for_workflow_run") + async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None: + async with self.Session() as session: + if task := ( + await session.scalars( + select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at.desc()) + ) + ).first(): + return convert_to_task(task, debug_enabled=self.debug_enabled) + return None + + @db_operation("get_tasks_by_workflow_run_id") + async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]: + async with self.Session() as session: + tasks = ( + await session.scalars( + select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at) + ) + ).all() + return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks] + + @db_operation("delete_task_steps") + async def delete_task_steps(self, organization_id: str, task_id: str) -> None: + async with self.Session() as session: + # delete steps by filtering organization_id and task_id + stmt = delete(StepModel).where( + and_( + StepModel.organization_id == organization_id, + StepModel.task_id == task_id, + ) + ) + await session.execute(stmt) + await session.commit() + + @db_operation("get_previous_actions_for_task") + async def get_previous_actions_for_task(self, task_id: str) -> list[Action]: + async with self.Session() as session: + query = ( + 
select(ActionModel) + .filter_by(task_id=task_id) + .order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at) + ) + actions = (await session.scalars(query)).all() + return [Action.model_validate(action) for action in actions] + + @db_operation("delete_task_actions") + async def delete_task_actions(self, organization_id: str, task_id: str) -> None: + async with self.Session() as session: + # delete actions by filtering organization_id and task_id + stmt = delete(ActionModel).where( + and_( + ActionModel.organization_id == organization_id, + ActionModel.task_id == task_id, + ) + ) + await session.execute(stmt) + await session.commit() diff --git a/skyvern/forge/sdk/db/repositories/workflow_parameters.py b/skyvern/forge/sdk/db/repositories/workflow_parameters.py new file mode 100644 index 000000000..004ffef5a --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/workflow_parameters.py @@ -0,0 +1,699 @@ +from __future__ import annotations + +import json +from datetime import datetime, timedelta, timezone +from typing import Any + +import structlog +from sqlalchemy import select + +from skyvern.config import settings +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import ( + ActionModel, + AISuggestionModel, + AWSSecretParameterModel, + AzureVaultCredentialParameterModel, + Base, + BitwardenCreditCardDataParameterModel, + BitwardenLoginCredentialParameterModel, + BitwardenSensitiveInformationParameterModel, + CredentialParameterModel, + OnePasswordCredentialParameterModel, + OutputParameterModel, + TaskGenerationModel, + TaskModel, + TaskRunModel, + WorkflowCopilotChatMessageModel, + WorkflowCopilotChatModel, + WorkflowParameterModel, +) +from skyvern.forge.sdk.db.utils import ( + convert_to_aws_secret_parameter, + convert_to_output_parameter, + convert_to_workflow_copilot_chat_message, + convert_to_workflow_parameter, + hydrate_action, +) +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion +from skyvern.forge.sdk.schemas.runs import Run +from skyvern.forge.sdk.schemas.task_generations import TaskGeneration +from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus +from skyvern.forge.sdk.schemas.workflow_copilot import ( + WorkflowCopilotChat, + WorkflowCopilotChatMessage, + WorkflowCopilotChatSender, +) +from skyvern.forge.sdk.workflow.models.parameter import ( + PARAMETER_TYPE, + AWSSecretParameter, + AzureVaultCredentialParameter, + BitwardenCreditCardDataParameter, + BitwardenLoginCredentialParameter, + BitwardenSensitiveInformationParameter, + ContextParameter, + CredentialParameter, + OnePasswordCredentialParameter, + OutputParameter, + WorkflowParameter, + WorkflowParameterType, +) +from skyvern.schemas.runs import RunType +from skyvern.webeye.actions.actions import Action + +from . 
import _UNSET + +LOG = structlog.get_logger() + + +class WorkflowParametersRepository(BaseRepository): + """Database operations for workflow parameters, copilot chat, task generation, actions, and runs.""" + + @db_operation("create_workflow_parameter") + async def create_workflow_parameter( + self, + workflow_id: str, + workflow_parameter_type: WorkflowParameterType, + key: str, + default_value: Any, + description: str | None = None, + ) -> WorkflowParameter: + async with self.Session() as session: + if default_value is None: + pass + elif workflow_parameter_type == WorkflowParameterType.JSON: + default_value = json.dumps(default_value) + else: + default_value = str(default_value) + workflow_parameter = WorkflowParameterModel( + workflow_id=workflow_id, + workflow_parameter_type=workflow_parameter_type, + key=key, + default_value=default_value, + description=description, + ) + session.add(workflow_parameter) + await session.commit() + await session.refresh(workflow_parameter) + return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled) + + @db_operation("create_aws_secret_parameter") + async def create_aws_secret_parameter( + self, + workflow_id: str, + key: str, + aws_key: str, + description: str | None = None, + ) -> AWSSecretParameter: + async with self.Session() as session: + aws_secret_parameter = AWSSecretParameterModel( + workflow_id=workflow_id, + key=key, + aws_key=aws_key, + description=description, + ) + session.add(aws_secret_parameter) + await session.commit() + await session.refresh(aws_secret_parameter) + return convert_to_aws_secret_parameter(aws_secret_parameter) + + @db_operation("create_output_parameter") + async def create_output_parameter( + self, + workflow_id: str, + key: str, + description: str | None = None, + ) -> OutputParameter: + async with self.Session() as session: + output_parameter = OutputParameterModel( + key=key, + description=description, + workflow_id=workflow_id, + ) + session.add(output_parameter) + await session.commit() + await session.refresh(output_parameter) + return convert_to_output_parameter(output_parameter) + + @staticmethod + def _convert_parameter_to_model(parameter: PARAMETER_TYPE) -> Base: + """Convert a parameter object to its corresponding SQLAlchemy model.""" + if isinstance(parameter, WorkflowParameter): + if parameter.default_value is None: + default_value = None + elif parameter.workflow_parameter_type == WorkflowParameterType.JSON: + default_value = json.dumps(parameter.default_value) + else: + default_value = str(parameter.default_value) + return WorkflowParameterModel( + workflow_parameter_id=parameter.workflow_parameter_id, + workflow_parameter_type=parameter.workflow_parameter_type.value, + key=parameter.key, + description=parameter.description, + workflow_id=parameter.workflow_id, + default_value=default_value, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, OutputParameter): + return OutputParameterModel( + output_parameter_id=parameter.output_parameter_id, + key=parameter.key, + description=parameter.description, + workflow_id=parameter.workflow_id, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, AWSSecretParameter): + return AWSSecretParameterModel( + aws_secret_parameter_id=parameter.aws_secret_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + aws_key=parameter.aws_key, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, BitwardenLoginCredentialParameter): + return 
BitwardenLoginCredentialParameterModel( + bitwarden_login_credential_parameter_id=parameter.bitwarden_login_credential_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key, + bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key, + bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key, + bitwarden_collection_id=parameter.bitwarden_collection_id, + bitwarden_item_id=parameter.bitwarden_item_id, + url_parameter_key=parameter.url_parameter_key, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, BitwardenSensitiveInformationParameter): + return BitwardenSensitiveInformationParameterModel( + bitwarden_sensitive_information_parameter_id=parameter.bitwarden_sensitive_information_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key, + bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key, + bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key, + bitwarden_collection_id=parameter.bitwarden_collection_id, + bitwarden_identity_key=parameter.bitwarden_identity_key, + bitwarden_identity_fields=parameter.bitwarden_identity_fields, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, BitwardenCreditCardDataParameter): + return BitwardenCreditCardDataParameterModel( + bitwarden_credit_card_data_parameter_id=parameter.bitwarden_credit_card_data_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key, + bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key, + bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key, + bitwarden_collection_id=parameter.bitwarden_collection_id, + bitwarden_item_id=parameter.bitwarden_item_id, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, CredentialParameter): + return CredentialParameterModel( + credential_parameter_id=parameter.credential_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + credential_id=parameter.credential_id, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, OnePasswordCredentialParameter): + return OnePasswordCredentialParameterModel( + onepassword_credential_parameter_id=parameter.onepassword_credential_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + vault_id=parameter.vault_id, + item_id=parameter.item_id, + deleted_at=parameter.deleted_at, + ) + elif isinstance(parameter, AzureVaultCredentialParameter): + return AzureVaultCredentialParameterModel( + azure_vault_credential_parameter_id=parameter.azure_vault_credential_parameter_id, + workflow_id=parameter.workflow_id, + key=parameter.key, + description=parameter.description, + vault_name=parameter.vault_name, + username_key=parameter.username_key, + password_key=parameter.password_key, + totp_secret_key=parameter.totp_secret_key, + deleted_at=parameter.deleted_at, + ) + else: + raise ValueError(f"Unsupported workflow definition parameter type: {type(parameter).__name__}") + + 
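+    # A minimal usage sketch of the dispatch above (field values are hypothetical, and
+    # it assumes WorkflowParameter is constructible with just these fields):
+    #
+    #     param = WorkflowParameter(
+    #         workflow_parameter_id="wpar_123",
+    #         workflow_parameter_type=WorkflowParameterType.JSON,
+    #         key="payload",
+    #         workflow_id="w_123",
+    #         default_value={"a": 1},
+    #     )
+    #     model = WorkflowParametersRepository._convert_parameter_to_model(param)
+    #     assert isinstance(model, WorkflowParameterModel)
+    #     assert model.default_value == '{"a": 1}'  # JSON defaults are serialized to text
+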
@db_operation("save_workflow_definition_parameters") + async def save_workflow_definition_parameters(self, parameters: list[PARAMETER_TYPE]) -> None: + """Save multiple workflow definition parameters in a single transaction.""" + + # ContextParameter is not persisted + parameters_to_save = [p for p in parameters if not isinstance(p, ContextParameter)] + if not parameters_to_save: + return + + async with self.Session() as session: + for parameter in parameters_to_save: + model = self._convert_parameter_to_model(parameter) + session.add(model) + await session.commit() + + @db_operation("get_workflow_output_parameters") + async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]: + async with self.Session() as session: + output_parameters = ( + await session.scalars(select(OutputParameterModel).filter_by(workflow_id=workflow_id)) + ).all() + return [convert_to_output_parameter(parameter) for parameter in output_parameters] + + @db_operation("get_workflow_output_parameters_by_ids") + async def get_workflow_output_parameters_by_ids(self, output_parameter_ids: list[str]) -> list[OutputParameter]: + async with self.Session() as session: + output_parameters = ( + await session.scalars( + select(OutputParameterModel).filter( + OutputParameterModel.output_parameter_id.in_(output_parameter_ids) + ) + ) + ).all() + return [convert_to_output_parameter(parameter) for parameter in output_parameters] + + @db_operation("get_workflow_parameters") + async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]: + async with self.Session() as session: + workflow_parameters = ( + await session.scalars(select(WorkflowParameterModel).filter_by(workflow_id=workflow_id)) + ).all() + return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters] + + @db_operation("get_workflow_parameter") + async def get_workflow_parameter( + self, workflow_parameter_id: str, organization_id: str | None = None + ) -> WorkflowParameter | None: + async with self.Session() as session: + if workflow_parameter := ( + await session.scalars( + select(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id) + ) + ).first(): + return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled) + return None + + @db_operation("create_task_generation") + async def create_task_generation( + self, + organization_id: str, + user_prompt: str, + user_prompt_hash: str, + url: str | None = None, + navigation_goal: str | None = None, + navigation_payload: dict[str, Any] | None = None, + data_extraction_goal: str | None = None, + extracted_information_schema: dict[str, Any] | None = None, + suggested_title: str | None = None, + llm: str | None = None, + llm_prompt: str | None = None, + llm_response: str | None = None, + source_task_generation_id: str | None = None, + ) -> TaskGeneration: + async with self.Session() as session: + new_task_generation = TaskGenerationModel( + organization_id=organization_id, + user_prompt=user_prompt, + user_prompt_hash=user_prompt_hash, + url=url, + navigation_goal=navigation_goal, + navigation_payload=navigation_payload, + data_extraction_goal=data_extraction_goal, + extracted_information_schema=extracted_information_schema, + llm=llm, + llm_prompt=llm_prompt, + llm_response=llm_response, + suggested_title=suggested_title, + source_task_generation_id=source_task_generation_id, + ) + session.add(new_task_generation) + await session.commit() + await session.refresh(new_task_generation) + return 
TaskGeneration.model_validate(new_task_generation) + + @db_operation("create_ai_suggestion") + async def create_ai_suggestion( + self, + organization_id: str, + ai_suggestion_type: str, + ) -> AISuggestion: + async with self.Session() as session: + new_ai_suggestion = AISuggestionModel( + organization_id=organization_id, + ai_suggestion_type=ai_suggestion_type, + ) + session.add(new_ai_suggestion) + await session.commit() + await session.refresh(new_ai_suggestion) + return AISuggestion.model_validate(new_ai_suggestion) + + @db_operation("create_workflow_copilot_chat") + async def create_workflow_copilot_chat( + self, + organization_id: str, + workflow_permanent_id: str, + ) -> WorkflowCopilotChat: + async with self.Session() as session: + new_chat = WorkflowCopilotChatModel( + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + ) + session.add(new_chat) + await session.commit() + await session.refresh(new_chat) + return WorkflowCopilotChat.model_validate(new_chat) + + @db_operation("update_workflow_copilot_chat") + async def update_workflow_copilot_chat( + self, + organization_id: str, + workflow_copilot_chat_id: str, + proposed_workflow: dict | None | object = _UNSET, + auto_accept: bool | None = None, + ) -> WorkflowCopilotChat | None: + async with self.Session() as session: + chat = ( + await session.scalars( + select(WorkflowCopilotChatModel) + .where(WorkflowCopilotChatModel.organization_id == organization_id) + .where(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id) + ) + ).first() + if not chat: + return None + + if proposed_workflow is not _UNSET: + chat.proposed_workflow = proposed_workflow + if auto_accept is not None: + chat.auto_accept = auto_accept + + await session.commit() + await session.refresh(chat) + return WorkflowCopilotChat.model_validate(chat) + + @db_operation("create_workflow_copilot_chat_message") + async def create_workflow_copilot_chat_message( + self, + organization_id: str, + workflow_copilot_chat_id: str, + sender: WorkflowCopilotChatSender, + content: str, + global_llm_context: str | None = None, + ) -> WorkflowCopilotChatMessage: + async with self.Session() as session: + new_message = WorkflowCopilotChatMessageModel( + workflow_copilot_chat_id=workflow_copilot_chat_id, + organization_id=organization_id, + sender=sender, + content=content, + global_llm_context=global_llm_context, + ) + session.add(new_message) + await session.commit() + await session.refresh(new_message) + return convert_to_workflow_copilot_chat_message(new_message, self.debug_enabled) + + @db_operation("get_workflow_copilot_chat_messages") + async def get_workflow_copilot_chat_messages( + self, + workflow_copilot_chat_id: str, + ) -> list[WorkflowCopilotChatMessage]: + async with self.Session() as session: + query = ( + select(WorkflowCopilotChatMessageModel) + .filter(WorkflowCopilotChatMessageModel.workflow_copilot_chat_id == workflow_copilot_chat_id) + .order_by(WorkflowCopilotChatMessageModel.workflow_copilot_chat_message_id.asc()) + ) + messages = (await session.scalars(query)).all() + return [convert_to_workflow_copilot_chat_message(message, self.debug_enabled) for message in messages] + + @db_operation("get_workflow_copilot_chat_by_id") + async def get_workflow_copilot_chat_by_id( + self, + organization_id: str, + workflow_copilot_chat_id: str, + ) -> WorkflowCopilotChat | None: + async with self.Session() as session: + query = ( + select(WorkflowCopilotChatModel) + .filter(WorkflowCopilotChatModel.organization_id == 
organization_id) + .filter(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id) + .order_by(WorkflowCopilotChatModel.created_at.desc()) + .limit(1) + ) + chat = (await session.scalars(query)).first() + if not chat: + return None + return WorkflowCopilotChat.model_validate(chat) + + @db_operation("get_latest_workflow_copilot_chat") + async def get_latest_workflow_copilot_chat( + self, + organization_id: str, + workflow_permanent_id: str, + ) -> WorkflowCopilotChat | None: + async with self.Session() as session: + query = ( + select(WorkflowCopilotChatModel) + .filter(WorkflowCopilotChatModel.organization_id == organization_id) + .filter(WorkflowCopilotChatModel.workflow_permanent_id == workflow_permanent_id) + .order_by(WorkflowCopilotChatModel.created_at.desc()) + .limit(1) + ) + chat = (await session.scalars(query)).first() + if not chat: + return None + return WorkflowCopilotChat.model_validate(chat) + + @db_operation("get_task_generation_by_prompt_hash") + async def get_task_generation_by_prompt_hash( + self, + user_prompt_hash: str, + query_window_hours: int = settings.PROMPT_CACHE_WINDOW_HOURS, + ) -> TaskGeneration | None: + before_time = datetime.now(timezone.utc) - timedelta(hours=query_window_hours) + async with self.Session() as session: + query = ( + select(TaskGenerationModel) + .filter_by(user_prompt_hash=user_prompt_hash) + .filter(TaskGenerationModel.llm.is_not(None)) + .filter(TaskGenerationModel.created_at > before_time) + ) + task_generation = (await session.scalars(query)).first() + if not task_generation: + return None + return TaskGeneration.model_validate(task_generation) + + @db_operation("create_action") + async def create_action(self, action: Action) -> Action: + async with self.Session() as session: + new_action = ActionModel( + action_type=action.action_type, + source_action_id=action.source_action_id, + organization_id=action.organization_id, + workflow_run_id=action.workflow_run_id, + task_id=action.task_id, + step_id=action.step_id, + step_order=action.step_order, + action_order=action.action_order, + status=action.status, + reasoning=action.reasoning, + intention=action.intention, + response=action.response, + element_id=action.element_id, + skyvern_element_hash=action.skyvern_element_hash, + skyvern_element_data=action.skyvern_element_data, + screenshot_artifact_id=action.screenshot_artifact_id, + action_json=action.model_dump(), + confidence_float=action.confidence_float, + created_by=action.created_by, + ) + session.add(new_action) + await session.commit() + await session.refresh(new_action) + return hydrate_action(new_action) + + @db_operation("update_action_reasoning") + async def update_action_reasoning( + self, + organization_id: str, + action_id: str, + reasoning: str, + ) -> Action: + async with self.Session() as session: + action = ( + await session.scalars( + select(ActionModel).filter_by(action_id=action_id).filter_by(organization_id=organization_id) + ) + ).first() + if action: + action.reasoning = reasoning + await session.commit() + await session.refresh(action) + return Action.model_validate(action) + raise NotFoundError(f"Action {action_id}") + + @db_operation("retrieve_action_plan") + async def retrieve_action_plan(self, task: Task) -> list[Action]: + async with self.Session() as session: + subquery = ( + select(TaskModel.task_id) + .filter(TaskModel.url == task.url) + .filter(TaskModel.navigation_goal == task.navigation_goal) + .filter(TaskModel.status == TaskStatus.completed) + 
.order_by(TaskModel.created_at.desc()) + .limit(1) + .subquery() + ) + + query = ( + select(ActionModel) + .filter(ActionModel.task_id == subquery.c.task_id) + .order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at) + ) + + actions = (await session.scalars(query)).all() + return [Action.model_validate(action) for action in actions] + + @db_operation("create_task_run") + async def create_task_run( + self, + task_run_type: RunType, + organization_id: str, + run_id: str, + title: str | None = None, + url: str | None = None, + url_hash: str | None = None, + ) -> Run: + async with self.Session() as session: + task_run = TaskRunModel( + task_run_type=task_run_type, + organization_id=organization_id, + run_id=run_id, + title=title, + url=url, + url_hash=url_hash, + ) + session.add(task_run) + await session.commit() + await session.refresh(task_run) + return Run.model_validate(task_run) + + @db_operation("update_task_run") + async def update_task_run( + self, + organization_id: str, + run_id: str, + title: str | None = None, + url: str | None = None, + url_hash: str | None = None, + ) -> None: + async with self.Session() as session: + task_run = ( + await session.scalars( + select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id) + ) + ).first() + if not task_run: + raise NotFoundError(f"TaskRun {run_id} not found") + + if title: + task_run.title = title + if url: + task_run.url = url + if url_hash: + task_run.url_hash = url_hash + await session.commit() + + @db_operation("update_job_run_compute_cost") + async def update_job_run_compute_cost( + self, + organization_id: str, + run_id: str, + instance_type: str | None = None, + vcpu_millicores: int | None = None, + memory_mb: int | None = None, + duration_ms: int | None = None, + compute_cost: float | None = None, + ) -> None: + """Update compute cost metrics for a job run.""" + async with self.Session() as session: + task_run = ( + await session.scalars( + select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id) + ) + ).first() + if not task_run: + LOG.warning( + "TaskRun not found for compute cost update", + run_id=run_id, + organization_id=organization_id, + ) + return + + if instance_type is not None: + task_run.instance_type = instance_type + if vcpu_millicores is not None: + task_run.vcpu_millicores = vcpu_millicores + if memory_mb is not None: + task_run.memory_mb = memory_mb + if duration_ms is not None: + task_run.duration_ms = duration_ms + if compute_cost is not None: + task_run.compute_cost = compute_cost + await session.commit() + + @db_operation("cache_task_run") + async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run: + async with self.Session() as session: + task_run = ( + await session.scalars( + select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id) + ) + ).first() + if task_run: + task_run.cached = True + await session.commit() + await session.refresh(task_run) + return Run.model_validate(task_run) + raise NotFoundError(f"Run {run_id} not found") + + @db_operation("get_cached_task_run") + async def get_cached_task_run( + self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None + ) -> Run | None: + async with self.Session() as session: + query = select(TaskRunModel) + if task_run_type: + query = query.filter_by(task_run_type=task_run_type) + if url_hash: + query = query.filter_by(url_hash=url_hash) + if organization_id: + query = 
query.filter_by(organization_id=organization_id) + query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc()) + task_run = (await session.scalars(query)).first() + return Run.model_validate(task_run) if task_run else None + + @db_operation("get_run") + async def get_run( + self, + run_id: str, + organization_id: str | None = None, + ) -> Run | None: + async with self.Session() as session: + query = select(TaskRunModel).filter_by(run_id=run_id) + if organization_id: + query = query.filter_by(organization_id=organization_id) + task_run = (await session.scalars(query)).first() + return Run.model_validate(task_run) if task_run else None diff --git a/skyvern/forge/sdk/db/repositories/workflow_runs.py b/skyvern/forge/sdk/db/repositories/workflow_runs.py new file mode 100644 index 000000000..e38076fe3 --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/workflow_runs.py @@ -0,0 +1,903 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Callable + +import structlog +from sqlalchemy import Text, and_, cast, exists, func, literal, literal_column, or_, select, update +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.exc import SQLAlchemyError + +from skyvern.exceptions import WorkflowParameterNotFound, WorkflowRunNotFound +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db.base_alchemy_db import read_retry +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.enums import WorkflowRunTriggerType +from skyvern.forge.sdk.db.exceptions import NotFoundError + +if TYPE_CHECKING: + from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory + +from skyvern.forge.sdk.db.models import ( + TaskModel, + WorkflowModel, + WorkflowParameterModel, + WorkflowRunBlockModel, + WorkflowRunModel, + WorkflowRunOutputParameterModel, + WorkflowRunParameterModel, +) +from skyvern.forge.sdk.db.protocols import WorkflowParameterReader +from skyvern.forge.sdk.db.utils import ( + convert_to_task, + convert_to_workflow_run, + convert_to_workflow_run_output_parameter, + convert_to_workflow_run_parameter, + serialize_proxy_location, +) +from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs +from skyvern.forge.sdk.schemas.tasks import Task +from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter +from skyvern.forge.sdk.workflow.models.workflow import ( + WorkflowRun, + WorkflowRunOutputParameter, + WorkflowRunParameter, + WorkflowRunStatus, +) +from skyvern.schemas.runs import ProxyLocationInput + +from . import _UNSET + +LOG = structlog.get_logger() + + +class WorkflowRunsRepository(BaseRepository): + """Database operations for workflow runs.""" + + def __init__( + self, + session_factory: _SessionFactory, + debug_enabled: bool = False, + is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None, + workflow_parameter_reader: WorkflowParameterReader | None = None, + dialect_name: str = "postgresql", + ) -> None: + super().__init__(session_factory, debug_enabled, is_retryable_error_fn) + self._workflow_parameter_reader = workflow_parameter_reader + self._dialect_name = dialect_name + + @db_operation("get_running_workflow_runs_info_globally") + async def get_running_workflow_runs_info_globally( + self, + stale_threshold_hours: int = 24, + ) -> tuple[int, int]: + """ + Get information about running workflow runs across all organizations. 
+ Used by cleanup service to determine if cleanup should be skipped. + + Args: + stale_threshold_hours: Workflow runs not updated for this many hours are considered stale. + + Returns: + Tuple of (active_workflow_count, stale_workflow_count). + Active workflows are those updated within the threshold. + Stale workflows are those not updated within the threshold but still in running status. + """ + async with self.Session() as session: + running_statuses = [ + WorkflowRunStatus.created, + WorkflowRunStatus.queued, + WorkflowRunStatus.running, + WorkflowRunStatus.paused, + ] + stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=stale_threshold_hours) + + # Count active workflow runs (recently updated) + active_query = ( + select(func.count()) + .select_from(WorkflowRunModel) + .filter(WorkflowRunModel.status.in_(running_statuses)) + .filter(WorkflowRunModel.modified_at >= stale_cutoff) + ) + active_count = (await session.execute(active_query)).scalar_one() + + # Count stale workflow runs (not updated for a long time) + stale_query = ( + select(func.count()) + .select_from(WorkflowRunModel) + .filter(WorkflowRunModel.status.in_(running_statuses)) + .filter(WorkflowRunModel.modified_at < stale_cutoff) + ) + stale_count = (await session.execute(stale_query)).scalar_one() + + return (active_count, stale_count) + + @db_operation("create_workflow_run") + async def create_workflow_run( + self, + workflow_permanent_id: str, + workflow_id: str, + organization_id: str, + browser_session_id: str | None = None, + browser_profile_id: str | None = None, + proxy_location: ProxyLocationInput = None, + webhook_callback_url: str | None = None, + totp_verification_url: str | None = None, + totp_identifier: str | None = None, + parent_workflow_run_id: str | None = None, + max_screenshot_scrolling_times: int | None = None, + extra_http_headers: dict[str, str] | None = None, + browser_address: str | None = None, + sequential_key: str | None = None, + run_with: str | None = None, + debug_session_id: str | None = None, + ai_fallback: bool | None = None, + code_gen: bool | None = None, + workflow_run_id: str | None = None, + trigger_type: WorkflowRunTriggerType | None = None, + workflow_schedule_id: str | None = None, + ) -> WorkflowRun: + async with self.Session() as session: + kwargs: dict[str, Any] = {} + if workflow_run_id is not None: + kwargs["workflow_run_id"] = workflow_run_id + workflow_run = WorkflowRunModel( + workflow_permanent_id=workflow_permanent_id, + workflow_id=workflow_id, + organization_id=organization_id, + browser_session_id=browser_session_id, + browser_profile_id=browser_profile_id, + proxy_location=serialize_proxy_location(proxy_location), + status="created", + webhook_callback_url=webhook_callback_url, + totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, + parent_workflow_run_id=parent_workflow_run_id, + max_screenshot_scrolling_times=max_screenshot_scrolling_times, + extra_http_headers=extra_http_headers, + browser_address=browser_address, + sequential_key=sequential_key, + run_with=run_with, + debug_session_id=debug_session_id, + ai_fallback=ai_fallback, + code_gen=code_gen, + trigger_type=trigger_type.value if trigger_type else None, + workflow_schedule_id=workflow_schedule_id, + **kwargs, + ) + session.add(workflow_run) + await session.commit() + await session.refresh(workflow_run) + return convert_to_workflow_run(workflow_run) + + @db_operation("update_workflow_run") + async def update_workflow_run( + self, + workflow_run_id: str, + status: 
WorkflowRunStatus | None = None, + failure_reason: str | None = None, + webhook_failure_reason: str | None = None, + ai_fallback_triggered: bool | None = None, + job_id: str | None = None, + run_with: str | None = None, + sequential_key: str | None = None, + ai_fallback: bool | None = None, + depends_on_workflow_run_id: str | None = None, + browser_session_id: str | None = None, + waiting_for_verification_code: bool | None = None, + verification_code_identifier: str | None = None, + verification_code_polling_started_at: datetime | None = None, + browser_profile_id: str | None | object = _UNSET, + ) -> WorkflowRun: + async with self.Session() as session: + workflow_run = ( + await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)) + ).first() + if workflow_run: + if status: + workflow_run.status = status + if status and status == WorkflowRunStatus.queued and workflow_run.queued_at is None: + workflow_run.queued_at = datetime.now(timezone.utc) + if status and status == WorkflowRunStatus.running and workflow_run.started_at is None: + workflow_run.started_at = datetime.now(timezone.utc) + if status and status.is_final() and workflow_run.finished_at is None: + workflow_run.finished_at = datetime.now(timezone.utc) + if failure_reason: + workflow_run.failure_reason = failure_reason + if webhook_failure_reason is not None: + workflow_run.webhook_failure_reason = webhook_failure_reason + if ai_fallback_triggered is not None: + workflow_run.script_run = {"ai_fallback_triggered": ai_fallback_triggered} + if job_id: + workflow_run.job_id = job_id + if run_with: + workflow_run.run_with = run_with + if sequential_key: + workflow_run.sequential_key = sequential_key + if ai_fallback is not None: + workflow_run.ai_fallback = ai_fallback + if depends_on_workflow_run_id: + workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id + if browser_session_id: + workflow_run.browser_session_id = browser_session_id + # 2FA verification code waiting state updates + if waiting_for_verification_code is not None: + workflow_run.waiting_for_verification_code = waiting_for_verification_code + if verification_code_identifier is not None: + workflow_run.verification_code_identifier = verification_code_identifier + if verification_code_polling_started_at is not None: + workflow_run.verification_code_polling_started_at = verification_code_polling_started_at + if waiting_for_verification_code is not None and not waiting_for_verification_code: + # Clear related fields when waiting is set to False + workflow_run.verification_code_identifier = None + workflow_run.verification_code_polling_started_at = None + if browser_profile_id is not _UNSET: + workflow_run.browser_profile_id = browser_profile_id + await session.commit() + await save_workflow_run_logs(workflow_run_id) + await session.refresh(workflow_run) + return convert_to_workflow_run(workflow_run) + else: + raise WorkflowRunNotFound(workflow_run_id) + + @db_operation("bulk_update_workflow_runs") + async def bulk_update_workflow_runs( + self, + workflow_run_ids: list[str], + status: WorkflowRunStatus | None = None, + failure_reason: str | None = None, + ) -> None: + """Bulk update workflow runs by their IDs. 
+
+        Args:
+            workflow_run_ids: List of workflow run IDs to update
+            status: Optional status to set for all workflow runs
+            failure_reason: Optional failure reason to set for all workflow runs
+        """
+        if not workflow_run_ids:
+            return
+
+        async with self.Session() as session:
+            update_values = {}
+            if status:
+                update_values["status"] = status.value
+            if failure_reason:
+                update_values["failure_reason"] = failure_reason
+
+            if update_values:
+                update_stmt = (
+                    update(WorkflowRunModel)
+                    .where(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
+                    .values(**update_values)
+                )
+                await session.execute(update_stmt)
+                await session.commit()
+
+    @db_operation("clear_workflow_run_failure_reason")
+    async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun:
+        async with self.Session() as session:
+            workflow_run = (
+                await session.scalars(
+                    select(WorkflowRunModel)
+                    .filter_by(workflow_run_id=workflow_run_id)
+                    .filter_by(organization_id=organization_id)
+                )
+            ).first()
+            if workflow_run:
+                workflow_run.failure_reason = None
+                await session.commit()
+                await session.refresh(workflow_run)
+                return convert_to_workflow_run(workflow_run)
+            else:
+                raise NotFoundError("Workflow run not found")
+
+    @db_operation("get_all_runs")
+    async def get_all_runs(
+        self,
+        organization_id: str,
+        page: int = 1,
+        page_size: int = 10,
+        status: list[WorkflowRunStatus] | None = None,
+        include_debugger_runs: bool = False,
+        search_key: str | None = None,
+    ) -> list[WorkflowRun | Task]:
+        async with self.Session() as session:
+            # temporary limit to 10 pages
+            if page > 10:
+                return []
+
+            limit = page * page_size
+
+            workflow_run_query = (
+                select(WorkflowRunModel, WorkflowModel.title)
+                .join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
+                .filter(WorkflowRunModel.organization_id == organization_id)
+                .filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
+            )
+
+            if not include_debugger_runs:
+                workflow_run_query = workflow_run_query.filter(WorkflowRunModel.debug_session_id.is_(None))
+
+            # Reuse the shared search helper instead of duplicating its EXISTS clauses inline.
+            workflow_run_query = self._apply_search_key_filter(workflow_run_query, search_key)
+
+            if status:
+                workflow_run_query = 
workflow_run_query.filter(WorkflowRunModel.status.in_(status)) + workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit) + workflow_run_query_result = (await session.execute(workflow_run_query)).all() + workflow_runs = [ + convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled) + for run, title in workflow_run_query_result + ] + + task_query = ( + select(TaskModel) + .filter(TaskModel.organization_id == organization_id) + .filter(TaskModel.workflow_run_id.is_(None)) + ) + if status: + task_query = task_query.filter(TaskModel.status.in_(status)) + task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit) + task_query_result = (await session.scalars(task_query)).all() + tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result] + + runs = workflow_runs + tasks + + runs.sort(key=lambda x: x.created_at, reverse=True) + + lower = (page - 1) * page_size + upper = page * page_size + + return runs[lower:upper] + + @read_retry() + @db_operation("get_workflow_run", log_errors=False) + async def get_workflow_run( + self, + workflow_run_id: str, + organization_id: str | None = None, + job_id: str | None = None, + status: WorkflowRunStatus | None = None, + ) -> WorkflowRun | None: + async with self.Session() as session: + get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id) + if organization_id: + get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id) + if job_id: + get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id) + if status: + get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value) + if workflow_run := (await session.scalars(get_workflow_run_query)).first(): + return convert_to_workflow_run(workflow_run) + return None + + async def get_run(self, run_id: str, organization_id: str | None = None) -> WorkflowRun | None: + """Alias satisfying the RunReader protocol.""" + return await self.get_workflow_run(run_id, organization_id=organization_id) + + @db_operation("get_last_queued_workflow_run") + async def get_last_queued_workflow_run( + self, + workflow_permanent_id: str, + organization_id: str | None = None, + sequential_key: str | None = None, + ) -> WorkflowRun | None: + async with self.Session() as session: + query = select(WorkflowRunModel).filter_by(workflow_permanent_id=workflow_permanent_id) + query = query.filter(WorkflowRunModel.browser_session_id.is_(None)) + if organization_id: + query = query.filter_by(organization_id=organization_id) + query = query.filter_by(status=WorkflowRunStatus.queued) + if sequential_key: + query = query.filter_by(sequential_key=sequential_key) + query = query.order_by(WorkflowRunModel.modified_at.desc()) + workflow_run = (await session.scalars(query)).first() + return convert_to_workflow_run(workflow_run) if workflow_run else None + + @db_operation("get_workflow_runs_by_ids") + async def get_workflow_runs_by_ids( + self, + workflow_run_ids: list[str], + workflow_permanent_id: str | None = None, + organization_id: str | None = None, + ) -> list[WorkflowRun]: + async with self.Session() as session: + query = select(WorkflowRunModel).filter(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids)) + if workflow_permanent_id: + query = query.filter_by(workflow_permanent_id=workflow_permanent_id) + if organization_id: + query = query.filter_by(organization_id=organization_id) + workflow_runs = (await session.scalars(query)).all() 
+            return [convert_to_workflow_run(workflow_run) for workflow_run in workflow_runs]
+
+    @db_operation("get_last_running_workflow_run")
+    async def get_last_running_workflow_run(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str | None = None,
+        sequential_key: str | None = None,
+    ) -> WorkflowRun | None:
+        async with self.Session() as session:
+            query = select(WorkflowRunModel).filter_by(workflow_permanent_id=workflow_permanent_id)
+            query = query.filter(WorkflowRunModel.browser_session_id.is_(None))
+            if organization_id:
+                query = query.filter_by(organization_id=organization_id)
+            query = query.filter_by(status=WorkflowRunStatus.running)
+            if sequential_key:
+                query = query.filter_by(sequential_key=sequential_key)
+            query = query.filter(
+                WorkflowRunModel.started_at.isnot(None)
+            )  # filter out workflow runs that do not have a started_at timestamp
+            query = query.order_by(WorkflowRunModel.started_at.desc())
+            workflow_run = (await session.scalars(query)).first()
+            return convert_to_workflow_run(workflow_run) if workflow_run else None
+
+    async def _get_last_workflow_run_by_filter(
+        self,
+        organization_id: str | None = None,
+        **filters: str,
+    ) -> WorkflowRun | None:
+        """Get the last queued or running workflow run matching the given column filters.
+
+        Used for browser_session_id and browser_address sequential execution.
+        """
+        async with self.Session() as session:
+            query = select(WorkflowRunModel).filter_by(**filters)
+            if organization_id:
+                query = query.filter_by(organization_id=organization_id)
+
+            # check if there's a queued run
+            queue_query = query.filter_by(status=WorkflowRunStatus.queued)
+            queue_query = queue_query.order_by(WorkflowRunModel.modified_at.desc())
+            workflow_run = (await session.scalars(queue_query)).first()
+            if workflow_run:
+                return convert_to_workflow_run(workflow_run)
+
+            # check if there's a running run
+            running_query = query.filter_by(status=WorkflowRunStatus.running)
+            running_query = running_query.filter(WorkflowRunModel.started_at.isnot(None))
+            running_query = running_query.order_by(WorkflowRunModel.started_at.desc())
+            workflow_run = (await session.scalars(running_query)).first()
+            if workflow_run:
+                return convert_to_workflow_run(workflow_run)
+            return None
+
+    async def get_last_workflow_run_for_browser_session(
+        self,
+        browser_session_id: str,
+        organization_id: str | None = None,
+    ) -> WorkflowRun | None:
+        return await self._get_last_workflow_run_by_filter(
+            organization_id=organization_id,
+            browser_session_id=browser_session_id,
+        )
+
+    async def get_last_workflow_run_for_browser_address(
+        self,
+        browser_address: str,
+        organization_id: str | None = None,
+    ) -> WorkflowRun | None:
+        return await self._get_last_workflow_run_by_filter(
+            organization_id=organization_id,
+            browser_address=browser_address,
+        )
+
+    @db_operation("get_workflows_depending_on")
+    async def get_workflows_depending_on(
+        self,
+        workflow_run_id: str,
+    ) -> list[WorkflowRun]:
+        """
+        Get all workflow runs that depend on the given workflow_run_id.
+
+        Used to find workflows that should be signaled when a workflow completes,
+        for sequential workflow dependency handling.
+ + Args: + workflow_run_id: The workflow_run_id to find dependents for + + Returns: + List of WorkflowRun objects that have depends_on_workflow_run_id set to workflow_run_id + """ + async with self.Session() as session: + query = select(WorkflowRunModel).filter_by(depends_on_workflow_run_id=workflow_run_id) + workflow_runs = (await session.scalars(query)).all() + return [convert_to_workflow_run(workflow_run) for workflow_run in workflow_runs] + + @staticmethod + def _apply_search_key_filter(query, search_key: str | None): # type: ignore[no-untyped-def] + if not search_key: + return query + key_like = f"%{search_key}%" + # Match workflow_run_id directly + id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like) + # Match parameter key or description (only for non-deleted parameter definitions) + # Use EXISTS to avoid duplicate rows and to keep pagination correct + param_key_desc_exists = exists( + select(1) + .select_from(WorkflowRunParameterModel) + .join( + WorkflowParameterModel, + WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id, + ) + .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .where(WorkflowParameterModel.deleted_at.is_(None)) + .where( + or_( + WorkflowParameterModel.key.ilike(key_like), + WorkflowParameterModel.description.ilike(key_like), + ) + ) + ) + # Match run parameter value directly (searches all values regardless of parameter definition status) + param_value_exists = exists( + select(1) + .select_from(WorkflowRunParameterModel) + .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .where(WorkflowRunParameterModel.value.ilike(key_like)) + ) + # Match extra HTTP headers (cast JSON to text for search, skip NULLs) + extra_headers_match = and_( + WorkflowRunModel.extra_http_headers.isnot(None), + func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like), + ) + return query.where(or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match)) + + def _apply_error_code_filter(self, query, error_code: str | None): # type: ignore[no-untyped-def] + if not error_code: + return query + + if self._dialect_name == "sqlite": + # Task errors: array of objects like [{"error_code": "timeout", ...}] + # Use json_each to iterate + json_extract to match the error_code field + error_code_in_tasks = exists( + select(1) + .select_from(TaskModel) + .where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .where( + exists( + select(1) + .select_from(func.json_each(TaskModel.errors)) + .where(func.json_extract(literal_column("json_each.value"), "$.error_code") == error_code) + ) + ) + ) + # Block errors: flat array of strings like ["timeout", "network_error"] + error_code_in_blocks = exists( + select(1) + .select_from(WorkflowRunBlockModel) + .where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .where( + exists( + select(1) + .select_from(func.json_each(WorkflowRunBlockModel.error_codes)) + .where(literal_column("json_each.value") == error_code) + ) + ) + ) + else: + # PostgreSQL: native JSONB containment + error_code_in_tasks = exists( + select(1) + .select_from(TaskModel) + .where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id) + .where(cast(TaskModel.errors, JSONB).contains(literal([{"error_code": error_code}], type_=JSONB))) + ) + error_code_in_blocks = exists( + select(1) + .select_from(WorkflowRunBlockModel) + .where(WorkflowRunBlockModel.workflow_run_id == 
WorkflowRunModel.workflow_run_id) + .where(cast(WorkflowRunBlockModel.error_codes, JSONB).contains(literal([error_code], type_=JSONB))) + ) + return query.where(or_(error_code_in_tasks, error_code_in_blocks)) + + @db_operation("get_workflow_runs") + async def get_workflow_runs( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + status: list[WorkflowRunStatus] | None = None, + ordering: tuple[str, str] | None = None, + search_key: str | None = None, + error_code: str | None = None, + ) -> list[WorkflowRun]: + async with self.Session() as session: + db_page = page - 1 # offset logic is 0 based + + query = ( + select(WorkflowRunModel, WorkflowModel.title) + .join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id) + .filter(WorkflowRunModel.organization_id == organization_id) + .filter(WorkflowRunModel.parent_workflow_run_id.is_(None)) + ) + + query = self._apply_search_key_filter(query, search_key) + query = self._apply_error_code_filter(query, error_code) + + if status: + query = query.filter(WorkflowRunModel.status.in_(status)) + + allowed_ordering_fields = { + "created_at": WorkflowRunModel.created_at, + "status": WorkflowRunModel.status, + } + + field, direction = ("created_at", "desc") + + if ordering and isinstance(ordering, tuple) and len(ordering) == 2: + req_field, req_direction = ordering + if req_field in allowed_ordering_fields and req_direction in ("asc", "desc"): + field, direction = req_field, req_direction + + order_column = allowed_ordering_fields[field] + + if direction == "asc": + query = query.order_by(order_column.asc()) + else: + query = query.order_by(order_column.desc()) + + query = query.limit(page_size).offset(db_page * page_size) + + workflow_runs = (await session.execute(query)).all() + + return [ + convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled) + for run, title in workflow_runs + ] + + @db_operation("get_workflow_runs_count") + async def get_workflow_runs_count( + self, + organization_id: str, + status: list[WorkflowRunStatus] | None = None, + ) -> int: + async with self.Session() as session: + count_query = ( + select(func.count()) + .select_from(WorkflowRunModel) + .filter(WorkflowRunModel.organization_id == organization_id) + ) + if status: + count_query = count_query.filter(WorkflowRunModel.status.in_(status)) + return (await session.execute(count_query)).scalar_one() + + @db_operation("get_workflow_runs_for_workflow_permanent_id") + async def get_workflow_runs_for_workflow_permanent_id( + self, + workflow_permanent_id: str, + organization_id: str, + page: int = 1, + page_size: int = 10, + status: list[WorkflowRunStatus] | None = None, + search_key: str | None = None, + error_code: str | None = None, + ) -> list[WorkflowRun]: + """ + Get runs for a workflow, with optional `search_key` on run ID, parameter key/description/value, + or extra HTTP headers. 
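+
+        Example (illustrative; identifiers are hypothetical):
+            runs = await db.workflow_runs.get_workflow_runs_for_workflow_permanent_id(
+                workflow_permanent_id="wpid_123",
+                organization_id="o_123",
+                status=[WorkflowRunStatus.failed],
+                error_code="timeout",
+            )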
+ """ + async with self.Session() as session: + db_page = page - 1 # offset logic is 0 based + query = ( + select(WorkflowRunModel, WorkflowModel.title) + .join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id) + .filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id) + .filter(WorkflowRunModel.organization_id == organization_id) + ) + query = self._apply_search_key_filter(query, search_key) + query = self._apply_error_code_filter(query, error_code) + if status: + query = query.filter(WorkflowRunModel.status.in_(status)) + query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size) + workflow_runs_and_titles_tuples = (await session.execute(query)).all() + workflow_runs = [ + convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled) + for run, title in workflow_runs_and_titles_tuples + ] + return workflow_runs + + @db_operation("get_workflow_runs_by_parent_workflow_run_id") + async def get_workflow_runs_by_parent_workflow_run_id( + self, + parent_workflow_run_id: str, + organization_id: str | None = None, + ) -> list[WorkflowRun]: + async with self.Session() as session: + query = select(WorkflowRunModel).filter(WorkflowRunModel.parent_workflow_run_id == parent_workflow_run_id) + if organization_id is not None: + query = query.filter(WorkflowRunModel.organization_id == organization_id) + workflow_runs = (await session.scalars(query)).all() + return [convert_to_workflow_run(run) for run in workflow_runs] + + @db_operation("get_workflow_run_output_parameters") + async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]: + async with self.Session() as session: + workflow_run_output_parameters = ( + await session.scalars( + select(WorkflowRunOutputParameterModel) + .filter_by(workflow_run_id=workflow_run_id) + .order_by(WorkflowRunOutputParameterModel.created_at) + ) + ).all() + return [ + convert_to_workflow_run_output_parameter(parameter, self.debug_enabled) + for parameter in workflow_run_output_parameters + ] + + @db_operation("get_workflow_run_output_parameter_by_id") + async def get_workflow_run_output_parameter_by_id( + self, workflow_run_id: str, output_parameter_id: str + ) -> WorkflowRunOutputParameter | None: + async with self.Session() as session: + parameter = ( + await session.scalars( + select(WorkflowRunOutputParameterModel) + .filter_by(workflow_run_id=workflow_run_id) + .filter_by(output_parameter_id=output_parameter_id) + .order_by(WorkflowRunOutputParameterModel.created_at) + ) + ).first() + + if parameter: + return convert_to_workflow_run_output_parameter(parameter, self.debug_enabled) + + return None + + @db_operation("create_or_update_workflow_run_output_parameter") + async def create_or_update_workflow_run_output_parameter( + self, + workflow_run_id: str, + output_parameter_id: str, + value: dict[str, Any] | list | str | None, + ) -> WorkflowRunOutputParameter: + async with self.Session() as session: + # check if the workflow run output parameter already exists + # if it does, update the value + if workflow_run_output_parameter := ( + await session.scalars( + select(WorkflowRunOutputParameterModel) + .filter_by(workflow_run_id=workflow_run_id) + .filter_by(output_parameter_id=output_parameter_id) + ) + ).first(): + LOG.info( + "Updating existing workflow run output parameter", + workflow_run_id=workflow_run_output_parameter.workflow_run_id, + output_parameter_id=workflow_run_output_parameter.output_parameter_id, 
+                )
+                workflow_run_output_parameter.value = value
+                await session.commit()
+                await session.refresh(workflow_run_output_parameter)
+                return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
+
+            # if it does not exist, create a new one
+            workflow_run_output_parameter = WorkflowRunOutputParameterModel(
+                workflow_run_id=workflow_run_id,
+                output_parameter_id=output_parameter_id,
+                value=value,
+            )
+            session.add(workflow_run_output_parameter)
+            await session.commit()
+            await session.refresh(workflow_run_output_parameter)
+            return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
+
+    @db_operation("update_workflow_run_output_parameter")
+    async def update_workflow_run_output_parameter(
+        self,
+        workflow_run_id: str,
+        output_parameter_id: str,
+        value: dict[str, Any] | list | str | None,
+    ) -> WorkflowRunOutputParameter:
+        async with self.Session() as session:
+            workflow_run_output_parameter = (
+                await session.scalars(
+                    select(WorkflowRunOutputParameterModel)
+                    .filter_by(workflow_run_id=workflow_run_id)
+                    .filter_by(output_parameter_id=output_parameter_id)
+                )
+            ).first()
+            if not workflow_run_output_parameter:
+                raise NotFoundError(
+                    f"WorkflowRunOutputParameter not found for {workflow_run_id} and {output_parameter_id}"
+                )
+            workflow_run_output_parameter.value = value
+            await session.commit()
+            await session.refresh(workflow_run_output_parameter)
+            return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
+
+    @db_operation("create_workflow_run_parameter")
+    async def create_workflow_run_parameter(
+        self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
+    ) -> WorkflowRunParameter:
+        workflow_parameter_id = workflow_parameter.workflow_parameter_id
+        async with self.Session() as session:
+            workflow_run_parameter = WorkflowRunParameterModel(
+                workflow_run_id=workflow_run_id,
+                workflow_parameter_id=workflow_parameter_id,
+                value=value,
+            )
+            session.add(workflow_run_parameter)
+            await session.commit()
+            await session.refresh(workflow_run_parameter)
+            return convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
+
+    @db_operation("get_workflow_run_parameters")
+    async def get_workflow_run_parameters(
+        self, workflow_run_id: str
+    ) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
+        # The injected workflow_parameter_reader is required to resolve parameter
+        # definitions, so fail fast before querying rather than inside the loop.
+        if self._workflow_parameter_reader is None:
+            raise RuntimeError("workflow_parameter_reader dependency not set")
+        async with self.Session() as session:
+            workflow_run_parameters = (
+                await session.scalars(select(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id))
+            ).all()
+            results = []
+            for workflow_run_parameter in workflow_run_parameters:
+                workflow_parameter = await self._workflow_parameter_reader.get_workflow_parameter(
+                    workflow_run_parameter.workflow_parameter_id
+                )
+                if not workflow_parameter:
+                    raise WorkflowParameterNotFound(workflow_parameter_id=workflow_run_parameter.workflow_parameter_id)
+                results.append(
+                    (
+                        workflow_parameter,
+                        convert_to_workflow_run_parameter(
+                            workflow_run_parameter,
+                            workflow_parameter,
+                            self.debug_enabled,
+                        ),
+                    )
+                )
+            return results
+
+    @db_operation("get_workflow_run_block_errors")
+    async def get_workflow_run_block_errors(
+        self,
+        workflow_run_id: str,
+        organization_id: str | None = None,
+    ) -> list[tuple[list[str], str | None]]:
+        """Return (error_codes, failure_reason) tuples for blocks with non-null error_codes."""
+        async with 
self.Session() as session: + query = select(WorkflowRunBlockModel.error_codes, WorkflowRunBlockModel.failure_reason).filter_by( + workflow_run_id=workflow_run_id + ) + if organization_id is not None: + query = query.filter_by(organization_id=organization_id) + query = query.where(WorkflowRunBlockModel.error_codes.isnot(None)) + rows = (await session.execute(query)).all() + return [(row.error_codes, row.failure_reason) for row in rows] diff --git a/skyvern/forge/sdk/db/repositories/workflows.py b/skyvern/forge/sdk/db/repositories/workflows.py new file mode 100644 index 000000000..fd9cfb9cd --- /dev/null +++ b/skyvern/forge/sdk/db/repositories/workflows.py @@ -0,0 +1,705 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import structlog +from sqlalchemy import exists, func, or_, select, update + +from skyvern.constants import DEFAULT_SCRIPT_RUN_ID +from skyvern.forge.sdk.db._error_handling import db_operation +from skyvern.forge.sdk.db._soft_delete import exclude_deleted +from skyvern.forge.sdk.db.base_repository import BaseRepository +from skyvern.forge.sdk.db.exceptions import NotFoundError +from skyvern.forge.sdk.db.models import ( + AWSSecretParameterModel, + AzureVaultCredentialParameterModel, + BitwardenCreditCardDataParameterModel, + BitwardenLoginCredentialParameterModel, + BitwardenSensitiveInformationParameterModel, + CredentialParameterModel, + FolderModel, + OnePasswordCredentialParameterModel, + OutputParameterModel, + WorkflowModel, + WorkflowParameterModel, + WorkflowRunModel, + WorkflowScheduleModel, + WorkflowTemplateModel, +) +from skyvern.forge.sdk.db.utils import convert_to_workflow, serialize_proxy_location +from skyvern.forge.sdk.workflow.models.workflow import Workflow +from skyvern.schemas.runs import ProxyLocationInput +from skyvern.schemas.workflows import WorkflowStatus + +LOG = structlog.get_logger() + + +class WorkflowsRepository(BaseRepository): + """Database operations for workflow management.""" + + @db_operation("create_workflow") + async def create_workflow( + self, + title: str, + workflow_definition: dict[str, Any], + organization_id: str | None = None, + description: str | None = None, + proxy_location: ProxyLocationInput = None, + webhook_callback_url: str | None = None, + max_screenshot_scrolling_times: int | None = None, + extra_http_headers: dict[str, str] | None = None, + totp_verification_url: str | None = None, + totp_identifier: str | None = None, + persist_browser_session: bool = False, + model: dict[str, Any] | None = None, + workflow_permanent_id: str | None = None, + version: int | None = None, + is_saved_task: bool = False, + status: WorkflowStatus = WorkflowStatus.published, + run_with: str | None = None, + ai_fallback: bool = True, + cache_key: str | None = None, + adaptive_caching: bool = False, + code_version: int | None = None, + generate_script_on_terminal: bool = False, + run_sequentially: bool = False, + sequential_key: str | None = None, + folder_id: str | None = None, + ) -> Workflow: + async with self.Session() as session: + workflow = WorkflowModel( + organization_id=organization_id, + title=title, + description=description, + workflow_definition=workflow_definition, + proxy_location=serialize_proxy_location(proxy_location), + webhook_callback_url=webhook_callback_url, + totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, + max_screenshot_scrolling_times=max_screenshot_scrolling_times, + extra_http_headers=extra_http_headers, + 
persist_browser_session=persist_browser_session, + model=model, + is_saved_task=is_saved_task, + status=status, + run_with=run_with, + ai_fallback=ai_fallback, + cache_key=cache_key or DEFAULT_SCRIPT_RUN_ID, + adaptive_caching=adaptive_caching, + code_version=code_version, + generate_script_on_terminal=generate_script_on_terminal, + run_sequentially=run_sequentially, + sequential_key=sequential_key, + folder_id=folder_id, + ) + if workflow_permanent_id: + workflow.workflow_permanent_id = workflow_permanent_id + if version: + workflow.version = version + session.add(workflow) + + # Update folder's modified_at if folder_id is provided + if folder_id: + # Validate folder exists and belongs to the same organization + folder_stmt = ( + select(FolderModel) + .where(FolderModel.folder_id == folder_id) + .where(FolderModel.organization_id == organization_id) + .where(FolderModel.deleted_at.is_(None)) + ) + folder_model = await session.scalar(folder_stmt) + if not folder_model: + raise ValueError( + f"Folder {folder_id} not found or does not belong to organization {organization_id}" + ) + folder_model.modified_at = datetime.now(timezone.utc) + + await session.commit() + await session.refresh(workflow) + return convert_to_workflow(workflow, self.debug_enabled) + + @db_operation("soft_delete_workflow_by_id") + async def soft_delete_workflow_by_id(self, workflow_id: str, organization_id: str) -> None: + async with self.Session() as session: + # soft delete the workflow by setting the deleted_at field to the current time + update_deleted_at_query = ( + update(WorkflowModel) + .where(WorkflowModel.workflow_id == workflow_id) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .values(deleted_at=datetime.now(timezone.utc)) + ) + await session.execute(update_deleted_at_query) + await session.commit() + + @db_operation("get_workflow") + async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None: + async with self.Session() as session: + get_workflow_query = exclude_deleted( + select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel + ) + if organization_id: + get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) + if workflow := (await session.scalars(get_workflow_query)).first(): + is_template = ( + await self.is_workflow_template( + workflow_permanent_id=workflow.workflow_permanent_id, + organization_id=workflow.organization_id, + ) + if organization_id + else False + ) + return convert_to_workflow( + workflow, + self.debug_enabled, + is_template=is_template, + ) + return None + + @db_operation("get_workflow_by_permanent_id") + async def get_workflow_by_permanent_id( + self, + workflow_permanent_id: str, + organization_id: str | None = None, + version: int | None = None, + ignore_version: int | None = None, + filter_deleted: bool = True, + ) -> Workflow | None: + get_workflow_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id) + if filter_deleted: + get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None)) + if organization_id: + get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id) + if version: + get_workflow_query = get_workflow_query.filter_by(version=version) + if ignore_version: + get_workflow_query = get_workflow_query.filter(WorkflowModel.version != ignore_version) + get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc()) + async with self.Session() 
+
+    @db_operation("get_workflow_by_permanent_id")
+    async def get_workflow_by_permanent_id(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str | None = None,
+        version: int | None = None,
+        ignore_version: int | None = None,
+        filter_deleted: bool = True,
+    ) -> Workflow | None:
+        get_workflow_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
+        if filter_deleted:
+            get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))
+        if organization_id:
+            get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
+        if version:
+            get_workflow_query = get_workflow_query.filter_by(version=version)
+        if ignore_version:
+            get_workflow_query = get_workflow_query.filter(WorkflowModel.version != ignore_version)
+        get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
+        async with self.Session() as session:
+            if workflow := (await session.scalars(get_workflow_query)).first():
+                is_template = (
+                    await self.is_workflow_template(
+                        workflow_permanent_id=workflow.workflow_permanent_id,
+                        organization_id=workflow.organization_id,
+                    )
+                    if organization_id
+                    else False
+                )
+                return convert_to_workflow(
+                    workflow,
+                    self.debug_enabled,
+                    is_template=is_template,
+                )
+            return None
+
+    @db_operation("get_workflow_for_workflow_run")
+    async def get_workflow_for_workflow_run(
+        self,
+        workflow_run_id: str,
+        organization_id: str | None = None,
+        filter_deleted: bool = True,
+    ) -> Workflow | None:
+        get_workflow_query = select(WorkflowModel)
+
+        if filter_deleted:
+            get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))
+
+        get_workflow_query = get_workflow_query.join(
+            WorkflowRunModel,
+            WorkflowRunModel.workflow_id == WorkflowModel.workflow_id,
+        )
+
+        if organization_id:
+            get_workflow_query = get_workflow_query.filter(WorkflowRunModel.organization_id == organization_id)
+
+        get_workflow_query = get_workflow_query.filter(WorkflowRunModel.workflow_run_id == workflow_run_id)
+        async with self.Session() as session:
+            if workflow := (await session.scalars(get_workflow_query)).first():
+                is_template = (
+                    await self.is_workflow_template(
+                        workflow_permanent_id=workflow.workflow_permanent_id,
+                        organization_id=workflow.organization_id,
+                    )
+                    if organization_id
+                    else False
+                )
+                return convert_to_workflow(
+                    workflow,
+                    self.debug_enabled,
+                    is_template=is_template,
+                )
+            return None
+
+    @db_operation("get_workflow_versions_by_permanent_id")
+    async def get_workflow_versions_by_permanent_id(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str | None = None,
+        filter_deleted: bool = True,
+    ) -> list[Workflow]:
+        """
+        Get all versions of a workflow by its permanent ID, ordered by version descending (newest first).
+        """
+        get_workflows_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
+        if filter_deleted:
+            get_workflows_query = get_workflows_query.filter(WorkflowModel.deleted_at.is_(None))
+        if organization_id:
+            get_workflows_query = get_workflows_query.filter_by(organization_id=organization_id)
+        get_workflows_query = get_workflows_query.order_by(WorkflowModel.version.desc())
+
+        async with self.Session() as session:
+            workflows = (await session.scalars(get_workflows_query)).all()
+            template_permanent_ids: set[str] = set()
+            if workflows and organization_id:
+                template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
+
+            return [
+                convert_to_workflow(
+                    workflow,
+                    self.debug_enabled,
+                    is_template=workflow.workflow_permanent_id in template_permanent_ids,
+                )
+                for workflow in workflows
+            ]
+ """ + if page < 1: + raise ValueError(f"Page must be greater than 0, got {page}") + db_page = page - 1 + async with self.Session() as session: + subquery = ( + select( + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids)) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + main_query = select(WorkflowModel).join( + subquery, + (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + if organization_id: + main_query = main_query.where(WorkflowModel.organization_id == organization_id) + if title: + main_query = main_query.where(WorkflowModel.title.ilike(f"%{title}%")) + if statuses: + main_query = main_query.where(WorkflowModel.status.in_(statuses)) + main_query = ( + main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size) + ) + workflows = (await session.scalars(main_query)).all() + + # Map template status by permanent_id so API responses surface is_template + template_permanent_ids: set[str] = set() + if workflows and organization_id: + template_permanent_ids = await self.get_org_template_permanent_ids(organization_id) + + return [ + convert_to_workflow( + workflow, + self.debug_enabled, + is_template=workflow.workflow_permanent_id in template_permanent_ids, + ) + for workflow in workflows + ] + + @db_operation("get_workflows_by_organization_id") + async def get_workflows_by_organization_id( + self, + organization_id: str, + page: int = 1, + page_size: int = 10, + only_saved_tasks: bool = False, + only_workflows: bool = False, + only_templates: bool = False, + search_key: str | None = None, + folder_id: str | None = None, + statuses: list[WorkflowStatus] | None = None, + ) -> list[Workflow]: + """ + Get all workflows with the latest version for the organization. + + Search semantics: + - If `search_key` is provided, its value is used as a unified search term for + `workflows.title`, `folders.title`, and workflow parameter metadata (key, description, and default_value). + - If `search_key` is not provided, no search filtering is applied. + - Parameter metadata search excludes soft-deleted parameter rows across parameter tables. 
+ """ + if page < 1: + raise ValueError(f"Page must be greater than 0, got {page}") + db_page = page - 1 + async with self.Session() as session: + subquery = ( + select( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + func.max(WorkflowModel.version).label("max_version"), + ) + .where(WorkflowModel.organization_id == organization_id) + .where(WorkflowModel.deleted_at.is_(None)) + .group_by( + WorkflowModel.organization_id, + WorkflowModel.workflow_permanent_id, + ) + .subquery() + ) + main_query = ( + select(WorkflowModel) + .join( + subquery, + (WorkflowModel.organization_id == subquery.c.organization_id) + & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id) + & (WorkflowModel.version == subquery.c.max_version), + ) + .outerjoin( + FolderModel, + (WorkflowModel.folder_id == FolderModel.folder_id) + & (FolderModel.organization_id == WorkflowModel.organization_id), + ) + ) + if only_saved_tasks: + main_query = main_query.where(WorkflowModel.is_saved_task.is_(True)) + elif only_workflows: + main_query = main_query.where(WorkflowModel.is_saved_task.is_(False)) + if only_templates: + # Filter by workflow_templates table (templates at permanent_id level) + template_subquery = select(WorkflowTemplateModel.workflow_permanent_id).where( + WorkflowTemplateModel.organization_id == organization_id, + WorkflowTemplateModel.deleted_at.is_(None), + ) + main_query = main_query.where(WorkflowModel.workflow_permanent_id.in_(template_subquery)) + if statuses: + main_query = main_query.where(WorkflowModel.status.in_(statuses)) + if folder_id: + main_query = main_query.where(WorkflowModel.folder_id == folder_id) + if search_key: + search_like = f"%{search_key}%" + title_like = WorkflowModel.title.ilike(search_like) + folder_title_like = FolderModel.title.ilike(search_like) + + parameter_filters = [ + # WorkflowParameterModel + exists( + select(1) + .select_from(WorkflowParameterModel) + .where(WorkflowParameterModel.workflow_id == WorkflowModel.workflow_id) + .where(WorkflowParameterModel.deleted_at.is_(None)) + .where( + or_( + WorkflowParameterModel.key.ilike(search_like), + WorkflowParameterModel.description.ilike(search_like), + WorkflowParameterModel.default_value.ilike(search_like), + ) + ) + ), + # OutputParameterModel + exists( + select(1) + .select_from(OutputParameterModel) + .where(OutputParameterModel.workflow_id == WorkflowModel.workflow_id) + .where(OutputParameterModel.deleted_at.is_(None)) + .where( + or_( + OutputParameterModel.key.ilike(search_like), + OutputParameterModel.description.ilike(search_like), + ) + ) + ), + # AWSSecretParameterModel + exists( + select(1) + .select_from(AWSSecretParameterModel) + .where(AWSSecretParameterModel.workflow_id == WorkflowModel.workflow_id) + .where(AWSSecretParameterModel.deleted_at.is_(None)) + .where( + or_( + AWSSecretParameterModel.key.ilike(search_like), + AWSSecretParameterModel.description.ilike(search_like), + ) + ) + ), + # BitwardenLoginCredentialParameterModel + exists( + select(1) + .select_from(BitwardenLoginCredentialParameterModel) + .where(BitwardenLoginCredentialParameterModel.workflow_id == WorkflowModel.workflow_id) + .where(BitwardenLoginCredentialParameterModel.deleted_at.is_(None)) + .where( + or_( + BitwardenLoginCredentialParameterModel.key.ilike(search_like), + BitwardenLoginCredentialParameterModel.description.ilike(search_like), + ) + ) + ), + # BitwardenSensitiveInformationParameterModel + exists( + select(1) + .select_from(BitwardenSensitiveInformationParameterModel) + 
+                    # BitwardenSensitiveInformationParameterModel
+                    exists(
+                        select(1)
+                        .select_from(BitwardenSensitiveInformationParameterModel)
+                        .where(BitwardenSensitiveInformationParameterModel.workflow_id == WorkflowModel.workflow_id)
+                        .where(BitwardenSensitiveInformationParameterModel.deleted_at.is_(None))
+                        .where(
+                            or_(
+                                BitwardenSensitiveInformationParameterModel.key.ilike(search_like),
+                                BitwardenSensitiveInformationParameterModel.description.ilike(search_like),
+                            )
+                        )
+                    ),
+                    # BitwardenCreditCardDataParameterModel
+                    exists(
+                        select(1)
+                        .select_from(BitwardenCreditCardDataParameterModel)
+                        .where(BitwardenCreditCardDataParameterModel.workflow_id == WorkflowModel.workflow_id)
+                        .where(BitwardenCreditCardDataParameterModel.deleted_at.is_(None))
+                        .where(
+                            or_(
+                                BitwardenCreditCardDataParameterModel.key.ilike(search_like),
+                                BitwardenCreditCardDataParameterModel.description.ilike(search_like),
+                            )
+                        )
+                    ),
+                    # OnePasswordCredentialParameterModel
+                    exists(
+                        select(1)
+                        .select_from(OnePasswordCredentialParameterModel)
+                        .where(OnePasswordCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
+                        .where(OnePasswordCredentialParameterModel.deleted_at.is_(None))
+                        .where(
+                            or_(
+                                OnePasswordCredentialParameterModel.key.ilike(search_like),
+                                OnePasswordCredentialParameterModel.description.ilike(search_like),
+                            )
+                        )
+                    ),
+                    # AzureVaultCredentialParameterModel
+                    exists(
+                        select(1)
+                        .select_from(AzureVaultCredentialParameterModel)
+                        .where(AzureVaultCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
+                        .where(AzureVaultCredentialParameterModel.deleted_at.is_(None))
+                        .where(
+                            or_(
+                                AzureVaultCredentialParameterModel.key.ilike(search_like),
+                                AzureVaultCredentialParameterModel.description.ilike(search_like),
+                            )
+                        )
+                    ),
+                    # CredentialParameterModel
+                    exists(
+                        select(1)
+                        .select_from(CredentialParameterModel)
+                        .where(CredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
+                        .where(CredentialParameterModel.deleted_at.is_(None))
+                        .where(
+                            or_(
+                                CredentialParameterModel.key.ilike(search_like),
+                                CredentialParameterModel.description.ilike(search_like),
+                            )
+                        )
+                    ),
+                ]
+                main_query = main_query.where(or_(title_like, folder_title_like, or_(*parameter_filters)))
+            main_query = (
+                main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
+            )
+            workflows = (await session.scalars(main_query)).all()
+            template_permanent_ids: set[str] = set()
+            if workflows and organization_id:
+                template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
+
+            return [
+                convert_to_workflow(
+                    workflow,
+                    self.debug_enabled,
+                    is_template=workflow.workflow_permanent_id in template_permanent_ids,
+                )
+                for workflow in workflows
+            ]
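The nine `exists()` predicates above differ only in the model and the searchable columns (only `WorkflowParameterModel` also searches `default_value`), so they could equivalently be generated from a table. A sketch under that assumption, reusing the same models and the `exists`/`or_`/`select` imports from this file:

searchable = [
    (WorkflowParameterModel, ("key", "description", "default_value")),
    (OutputParameterModel, ("key", "description")),
    (AWSSecretParameterModel, ("key", "description")),
    (BitwardenLoginCredentialParameterModel, ("key", "description")),
    (BitwardenSensitiveInformationParameterModel, ("key", "description")),
    (BitwardenCreditCardDataParameterModel, ("key", "description")),
    (OnePasswordCredentialParameterModel, ("key", "description")),
    (AzureVaultCredentialParameterModel, ("key", "description")),
    (CredentialParameterModel, ("key", "description")),
]
# Produces the same correlated EXISTS predicates as the explicit list above.
parameter_filters = [
    exists(
        select(1)
        .select_from(model)
        .where(model.workflow_id == WorkflowModel.workflow_id)
        .where(model.deleted_at.is_(None))
        .where(or_(*(getattr(model, col).ilike(search_like) for col in columns)))
    )
    for model, columns in searchable
]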
+
+    @db_operation("update_workflow")
+    async def update_workflow(
+        self,
+        workflow_id: str,
+        organization_id: str | None = None,
+        title: str | None = None,
+        description: str | None = None,
+        workflow_definition: dict[str, Any] | None = None,
+        version: int | None = None,
+        run_with: str | None = None,
+        cache_key: str | None = None,
+        status: str | None = None,
+        import_error: str | None = None,
+    ) -> Workflow:
+        async with self.Session() as session:
+            get_workflow_query = exclude_deleted(
+                select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel
+            )
+            if organization_id:
+                get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
+            if workflow := (await session.scalars(get_workflow_query)).first():
+                if title is not None:
+                    workflow.title = title
+                if description is not None:
+                    workflow.description = description
+                if workflow_definition is not None:
+                    workflow.workflow_definition = workflow_definition
+                if version is not None:
+                    workflow.version = version
+                if run_with is not None:
+                    workflow.run_with = run_with
+                if cache_key is not None:
+                    workflow.cache_key = cache_key
+                if status is not None:
+                    workflow.status = status
+                if import_error is not None:
+                    workflow.import_error = import_error
+                await session.commit()
+                await session.refresh(workflow)
+                is_template = (
+                    await self.is_workflow_template(
+                        workflow_permanent_id=workflow.workflow_permanent_id,
+                        organization_id=workflow.organization_id,
+                    )
+                    if organization_id
+                    else False
+                )
+                return convert_to_workflow(
+                    workflow,
+                    self.debug_enabled,
+                    is_template=is_template,
+                )
+            else:
+                raise NotFoundError("Workflow not found")
+
+    @db_operation("soft_delete_workflow_and_schedules_by_permanent_id")
+    async def soft_delete_workflow_and_schedules_by_permanent_id(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str | None = None,
+    ) -> list[str]:
+        """Soft-delete a workflow and its active schedules in a single DB transaction."""
+        async with self.Session() as session:
+            select_query = (
+                select(WorkflowScheduleModel.workflow_schedule_id)
+                .where(WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id)
+                .where(WorkflowScheduleModel.deleted_at.is_(None))
+            )
+            if organization_id is not None:
+                select_query = select_query.where(WorkflowScheduleModel.organization_id == organization_id)
+            result = await session.execute(select_query)
+            schedule_ids = list(result.scalars().all())
+
+            deleted_at = datetime.now(timezone.utc)
+            if schedule_ids:
+                update_schedules_query = (
+                    update(WorkflowScheduleModel)
+                    .where(WorkflowScheduleModel.workflow_schedule_id.in_(schedule_ids))
+                    .values(deleted_at=deleted_at)
+                )
+                await session.execute(update_schedules_query)
+
+            update_workflow_query = (
+                update(WorkflowModel)
+                .where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
+                .where(WorkflowModel.deleted_at.is_(None))
+            )
+            if organization_id is not None:
+                update_workflow_query = update_workflow_query.filter_by(organization_id=organization_id)
+            await session.execute(update_workflow_query.values(deleted_at=deleted_at))
+            await session.commit()
+            return schedule_ids
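Because the schedule and workflow updates above share one session and one commit, a crash between them cannot leave schedules pointing at a deleted workflow, and the returned schedule ids are the authoritative set of jobs to unschedule. A hypothetical call site (the `db.workflows` attribute path and the `scheduler` object are assumptions for illustration, not part of this diff):

# Soft-delete a workflow and retire its schedules atomically, then tell an
# in-memory scheduler (hypothetical API) to drop the corresponding jobs.
schedule_ids = await db.workflows.soft_delete_workflow_and_schedules_by_permanent_id(
    workflow_permanent_id="wpid_123",  # illustrative id
    organization_id="org_456",  # illustrative id
)
for schedule_id in schedule_ids:
    scheduler.remove_job(schedule_id)  # hypothetical scheduler call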
+
+    @db_operation("add_workflow_template")
+    async def add_workflow_template(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str,
+    ) -> None:
+        """Add a workflow to the templates table."""
+        async with self.Session() as session:
+            existing = (
+                await session.scalars(
+                    select(WorkflowTemplateModel)
+                    .where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
+                    .where(WorkflowTemplateModel.organization_id == organization_id)
+                )
+            ).first()
+            if existing:
+                if existing.deleted_at is not None:
+                    existing.deleted_at = None
+                    await session.commit()
+                return
+            template = WorkflowTemplateModel(
+                workflow_permanent_id=workflow_permanent_id,
+                organization_id=organization_id,
+            )
+            session.add(template)
+            await session.commit()
+
+    @db_operation("remove_workflow_template")
+    async def remove_workflow_template(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str,
+    ) -> None:
+        """Soft delete a workflow from the templates table."""
+        async with self.Session() as session:
+            update_deleted_at_query = (
+                update(WorkflowTemplateModel)
+                .where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
+                .where(WorkflowTemplateModel.organization_id == organization_id)
+                .where(WorkflowTemplateModel.deleted_at.is_(None))
+                .values(deleted_at=datetime.now(timezone.utc))
+            )
+            await session.execute(update_deleted_at_query)
+            await session.commit()
+
+    @db_operation("get_org_template_permanent_ids")
+    async def get_org_template_permanent_ids(
+        self,
+        organization_id: str,
+    ) -> set[str]:
+        """Get all workflow_permanent_ids that are templates for an organization."""
+        async with self.Session() as session:
+            result = await session.scalars(
+                select(WorkflowTemplateModel.workflow_permanent_id)
+                .where(WorkflowTemplateModel.organization_id == organization_id)
+                .where(WorkflowTemplateModel.deleted_at.is_(None))
+            )
+            return set(result.all())
+
+    @db_operation("is_workflow_template")
+    async def is_workflow_template(
+        self,
+        workflow_permanent_id: str,
+        organization_id: str,
+    ) -> bool:
+        """Check if a workflow is marked as a template."""
+        async with self.Session() as session:
+            result = (
+                await session.scalars(
+                    select(WorkflowTemplateModel)
+                    .where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
+                    .where(WorkflowTemplateModel.organization_id == organization_id)
+                    .where(WorkflowTemplateModel.deleted_at.is_(None))
+                )
+            ).first()
+            return result is not None
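Note the division of labor between the last two methods: single-workflow reads call `is_workflow_template` per row, while the paginated list endpoints above call `get_org_template_permanent_ids` once per page, trading one extra query for freedom from an N+1 pattern. A sketch of the batched shape, with `repo`, `workflows`, and `debug_enabled` as illustrative names:

# One query fetches all template ids for the org; membership checks are
# then in-memory set lookups instead of one DB round trip per workflow.
template_ids = await repo.get_org_template_permanent_ids(organization_id)
results = [
    convert_to_workflow(
        wf,
        debug_enabled,
        is_template=wf.workflow_permanent_id in template_ids,
    )
    for wf in workflows
]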
diff --git a/tests/unit/forge/__init__.py b/tests/unit/forge/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/forge/sdk/__init__.py b/tests/unit/forge/sdk/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/forge/sdk/db/__init__.py b/tests/unit/forge/sdk/db/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/forge/sdk/db/test_base_repository.py b/tests/unit/forge/sdk/db/test_base_repository.py
new file mode 100644
index 000000000..c7eaef2e8
--- /dev/null
+++ b/tests/unit/forge/sdk/db/test_base_repository.py
@@ -0,0 +1,56 @@
+from unittest.mock import MagicMock
+
+from sqlalchemy.exc import OperationalError
+
+from skyvern.forge.sdk.db.base_repository import BaseRepository
+from skyvern.forge.sdk.db.protocols import RunReader, TaskReader, WorkflowParameterReader, WorkflowReader
+from skyvern.forge.sdk.db.utils import serialize_proxy_location
+from skyvern.schemas.runs import GeoTarget, ProxyLocation
+
+
+class TestSerializeProxyLocation:
+    def test_none(self):
+        assert serialize_proxy_location(None) is None
+
+    def test_geo_target(self):
+        geo = GeoTarget(country="US", state="CA")
+        result = serialize_proxy_location(geo)
+        assert result is not None
+        assert "US" in result
+
+    def test_enum(self):
+        result = serialize_proxy_location(ProxyLocation.RESIDENTIAL)
+        assert result == "RESIDENTIAL"
+
+
+class TestBaseRepository:
+    def test_init_with_defaults(self):
+        mock_session = MagicMock()
+        repo = BaseRepository(session_factory=mock_session, debug_enabled=True)
+        assert repo.Session is mock_session
+        assert repo.debug_enabled is True
+
+    def test_is_retryable_error_default(self):
+        mock_session = MagicMock()
+        repo = BaseRepository(session_factory=mock_session)
+        error = OperationalError("statement", {}, Exception("server closed the connection unexpectedly"))
+        assert repo.is_retryable_error(error) is False  # default returns False
+
+    def test_is_retryable_error_custom(self):
+        mock_session = MagicMock()
+
+        def custom_fn(e):
+            return "closed" in str(e).lower()
+
+        repo = BaseRepository(session_factory=mock_session, is_retryable_error_fn=custom_fn)
+        error = OperationalError("statement", {}, Exception("server closed the connection"))
+        assert repo.is_retryable_error(error) is True
+
+
+class TestProtocols:
+    def test_protocols_importable(self):
+        """Protocols should be importable and runtime-checkable."""
+        assert TaskReader is not None
+        assert WorkflowReader is not None
+        assert WorkflowParameterReader is not None
+        assert RunReader is not None
diff --git a/tests/unit/forge/sdk/db/test_repositories.py b/tests/unit/forge/sdk/db/test_repositories.py
new file mode 100644
index 000000000..840f881e4
--- /dev/null
+++ b/tests/unit/forge/sdk/db/test_repositories.py
@@ -0,0 +1,238 @@
+"""Tests for all 14 OSS repository instantiations + dependency injection."""
+
+from unittest.mock import MagicMock, patch
+
+
+def test_credential_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
+
+    mock_session = MagicMock()
+    repo = CredentialRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "create_credential")
+    assert hasattr(repo, "get_credential")
+    assert hasattr(repo, "get_credentials")
+    assert hasattr(repo, "update_credential")
+    assert hasattr(repo, "delete_credential")
+    assert hasattr(repo, "create_organization_bitwarden_collection")
+    assert hasattr(repo, "get_organization_bitwarden_collection")
+
+
+def test_otp_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.otp import OTPRepository
+
+    mock_session = MagicMock()
+    repo = OTPRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "get_otp_codes")
+    assert hasattr(repo, "create_otp_code")
+
+
+def test_debug_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.debug import DebugRepository
+
+    mock_session = MagicMock()
+    repo = DebugRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "get_debug_session")
+    assert hasattr(repo, "create_debug_session")
+    assert hasattr(repo, "create_block_run")
+
+
+def test_organizations_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.organizations import OrganizationsRepository
+
+    mock_session = MagicMock()
+    repo = OrganizationsRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "get_organization")
+    assert hasattr(repo, "create_organization")
+    assert hasattr(repo, "create_org_auth_token")
+    assert hasattr(repo, "validate_org_auth_token")
+
+
+def test_schedules_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.schedules import SchedulesRepository
+
+    mock_session = MagicMock()
+    repo = SchedulesRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "create_workflow_schedule")
+    assert hasattr(repo, "get_workflow_schedules")
+
+
+def test_scripts_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.scripts import ScriptsRepository
+
+    mock_session = MagicMock()
+    repo = ScriptsRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "create_script")
+    assert hasattr(repo, "get_scripts")
+
+
+def test_workflow_parameters_repository_instantiation():
+    from skyvern.forge.sdk.db.repositories.workflow_parameters import WorkflowParametersRepository
+
+    mock_session = MagicMock()
+    repo = WorkflowParametersRepository(session_factory=mock_session, debug_enabled=False)
+    assert repo.Session is mock_session
+    assert hasattr(repo, "get_workflow_parameter")
hasattr(repo, "create_workflow_parameter") + + +def test_tasks_repository_instantiation(): + from skyvern.forge.sdk.db.repositories.tasks import TasksRepository + + mock_session = MagicMock() + repo = TasksRepository(session_factory=mock_session, debug_enabled=False) + assert repo.Session is mock_session + assert hasattr(repo, "create_task") + assert hasattr(repo, "get_task") + assert hasattr(repo, "create_step") + + +def test_workflows_repository_instantiation(): + from skyvern.forge.sdk.db.repositories.workflows import WorkflowsRepository + + mock_session = MagicMock() + repo = WorkflowsRepository(session_factory=mock_session, debug_enabled=False) + assert repo.Session is mock_session + assert hasattr(repo, "get_workflow") + assert hasattr(repo, "create_workflow") + assert hasattr(repo, "get_workflow_by_permanent_id") + + +def test_browser_sessions_repository_instantiation(): + from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository + + mock_session = MagicMock() + repo = BrowserSessionsRepository(session_factory=mock_session, debug_enabled=False) + assert repo.Session is mock_session + assert hasattr(repo, "create_browser_profile") + assert hasattr(repo, "get_browser_profile") + + +# ── Cross-dependency repositories ── + + +def test_workflow_runs_repository_with_dependency(): + from skyvern.forge.sdk.db.repositories.workflow_runs import WorkflowRunsRepository + + mock_session = MagicMock() + mock_param_reader = MagicMock() + repo = WorkflowRunsRepository( + session_factory=mock_session, + debug_enabled=False, + workflow_parameter_reader=mock_param_reader, + ) + assert repo.Session is mock_session + assert repo._workflow_parameter_reader is mock_param_reader + assert hasattr(repo, "get_workflow_run_parameters") + assert hasattr(repo, "create_workflow_run") + assert hasattr(repo, "get_workflow_run") + + +def test_artifacts_repository_with_dependency(): + from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository + + mock_session = MagicMock() + mock_run_reader = MagicMock() + repo = ArtifactsRepository( + session_factory=mock_session, + debug_enabled=False, + run_reader=mock_run_reader, + ) + assert repo.Session is mock_session + assert repo._run_reader is mock_run_reader + assert hasattr(repo, "create_artifact") + assert hasattr(repo, "get_artifact") + + +def test_folders_repository_with_dependency(): + from skyvern.forge.sdk.db.repositories.folders import FoldersRepository + + mock_session = MagicMock() + mock_workflow_reader = MagicMock() + repo = FoldersRepository( + session_factory=mock_session, + debug_enabled=False, + workflow_reader=mock_workflow_reader, + ) + assert repo.Session is mock_session + assert repo._workflow_reader is mock_workflow_reader + assert hasattr(repo, "create_folder") + assert hasattr(repo, "update_workflow_folder") + + +def test_observer_repository_with_dependency(): + from skyvern.forge.sdk.db.repositories.observer import ObserverRepository + + mock_session = MagicMock() + mock_task_reader = MagicMock() + repo = ObserverRepository( + session_factory=mock_session, + debug_enabled=False, + task_reader=mock_task_reader, + ) + assert repo.Session is mock_session + assert repo._task_reader is mock_task_reader + assert hasattr(repo, "create_workflow_run_block") + assert hasattr(repo, "get_workflow_run_blocks") + + +# ── AgentDB composition test ── + + +def test_agent_db_has_typed_repo_attributes(): + """After refactoring, AgentDB should expose typed repository attributes.""" + from 
+
+
+# ── AgentDB composition test ──
+
+
+def test_agent_db_has_typed_repo_attributes():
+    """After refactoring, AgentDB should expose typed repository attributes."""
+    from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
+    from skyvern.forge.sdk.db.repositories.tasks import TasksRepository
+
+    with patch("skyvern.forge.sdk.db.agent_db.create_async_engine"):
+        from skyvern.forge.sdk.db.agent_db import AgentDB
+
+        db = AgentDB("postgresql+asyncpg://test", debug_enabled=True)
+        assert isinstance(db.tasks, TasksRepository)
+        assert isinstance(db.credentials, CredentialRepository)
+        assert hasattr(db, "get_task")  # backward compat delegate
+        assert hasattr(db, "create_workflow")
+
+
+def test_agent_db_delegates_route_to_repositories():
+    """Verify delegate methods actually forward to the correct repository."""
+    from unittest.mock import AsyncMock
+    from unittest.mock import patch as mock_patch
+
+    with mock_patch("skyvern.forge.sdk.db.agent_db.create_async_engine"):
+        from skyvern.forge.sdk.db.agent_db import AgentDB
+
+        db = AgentDB("postgresql+asyncpg://test", debug_enabled=False)
+
+        # Patch a method on each major repository and verify the delegate calls it
+        delegates_to_check = [
+            ("get_task", "tasks"),
+            ("create_workflow", "workflows"),
+            ("create_artifact", "artifacts"),
+            ("get_organization", "organizations"),
+            ("get_credential", "credentials"),
+            ("create_workflow_run", "workflow_runs"),
+        ]
+
+        import asyncio
+
+        loop = asyncio.new_event_loop()
+        try:
+            for delegate_name, repo_attr in delegates_to_check:
+                repo = getattr(db, repo_attr)
+                mock_method = AsyncMock(return_value="sentinel")
+                original = getattr(repo, delegate_name)
+                setattr(repo, delegate_name, mock_method)
+                try:
+                    result = loop.run_until_complete(getattr(db, delegate_name)("arg1", key="val"))
+                    mock_method.assert_called_once_with("arg1", key="val")
+                    assert result == "sentinel", f"Delegate {delegate_name} did not return repository result"
+                finally:
+                    setattr(repo, delegate_name, original)
+        finally:
+            loop.close()
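For readers of the delegate test above: the backward-compatible delegates it exercises are thin async wrappers that forward arguments untouched and return the repository's result. A sketch of that shape with assumed names (the real delegate bodies are outside this hunk):

from typing import Any


class AgentDBSketch:
    def __init__(self, tasks_repo: Any) -> None:
        self.tasks = tasks_repo

    async def get_task(self, *args: Any, **kwargs: Any) -> Any:
        # Forward positionals and keywords unchanged and return the
        # repository's result; this is exactly what
        # test_agent_db_delegates_route_to_repositories asserts.
        return await self.tasks.get_task(*args, **kwargs)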
""" from skyvern.forge.sdk.db.agent_db import AgentDB - mro_names = {cls.__name__ for cls in AgentDB.__mro__} - expected_mixins = [ - "TasksMixin", - "WorkflowsMixin", - "WorkflowRunsMixin", - "WorkflowParametersMixin", - "SchedulesMixin", - "ArtifactsMixin", - "BrowserSessionsMixin", - "ScriptsMixin", - "OTPMixin", - "CredentialsMixin", - "FoldersMixin", - "OrganizationsMixin", - "ObserverMixin", - "DebugMixin", + expected_repos = [ + "tasks", + "workflows", + "workflow_runs", + "workflow_params", + "schedules", + "artifacts", + "browser_sessions", + "scripts", + "otp", + "credentials", + "folders", + "organizations", + "observer", + "debug", ] - for name in expected_mixins: - assert name in mro_names, f"{name} missing from AgentDB MRO" + # Instantiate with a dummy database string (sqlite in-memory) + db = AgentDB("sqlite+aiosqlite:///", debug_enabled=False) + for repo in expected_repos: + assert hasattr(db, repo), f"Repository '{repo}' missing from AgentDB instance" + assert isinstance(getattr(db, repo), BaseRepository), ( + f"AgentDB.{repo} should be a BaseRepository, got {type(getattr(db, repo))}" + ) diff --git a/tests/unit/test_db_operation_decorator.py b/tests/unit/test_db_operation_decorator.py index d571b6961..e039d80ff 100644 --- a/tests/unit/test_db_operation_decorator.py +++ b/tests/unit/test_db_operation_decorator.py @@ -177,7 +177,7 @@ async def test_db_operation_log_errors_false_suppresses_logging() -> None: @pytest.mark.asyncio async def test_db_operation_schedule_limit_exceeded_is_passthrough() -> None: """ScheduleLimitExceededError should be treated as business logic, not unexpected.""" - from skyvern.forge.sdk.db.mixins.schedules import ScheduleLimitExceededError + from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError class ScheduleDB: @db_operation("create_schedule")