mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
Repository Design Pattern (#SKY-8139) (#5279)
This commit is contained in:
parent
f691c128f3
commit
c91cd98d50
28 changed files with 8968 additions and 71 deletions
|
|
@ -13,20 +13,21 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB
|
||||
from skyvern.forge.sdk.db.mixins.artifacts import ArtifactsMixin
|
||||
from skyvern.forge.sdk.db.mixins.browser_sessions import BrowserSessionsMixin
|
||||
from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin
|
||||
from skyvern.forge.sdk.db.mixins.debug import DebugMixin
|
||||
from skyvern.forge.sdk.db.mixins.folders import FoldersMixin
|
||||
from skyvern.forge.sdk.db.mixins.observer import ObserverMixin
|
||||
from skyvern.forge.sdk.db.mixins.organizations import OrganizationsMixin
|
||||
from skyvern.forge.sdk.db.mixins.otp import OTPMixin
|
||||
from skyvern.forge.sdk.db.mixins.schedules import ScheduleLimitExceededError, SchedulesMixin
|
||||
from skyvern.forge.sdk.db.mixins.scripts import ScriptsMixin
|
||||
from skyvern.forge.sdk.db.mixins.tasks import TasksMixin
|
||||
from skyvern.forge.sdk.db.mixins.workflow_parameters import WorkflowParametersMixin
|
||||
from skyvern.forge.sdk.db.mixins.workflow_runs import WorkflowRunsMixin
|
||||
from skyvern.forge.sdk.db.mixins.workflows import WorkflowsMixin
|
||||
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError # noqa: F401
|
||||
from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository
|
||||
from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository
|
||||
from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
|
||||
from skyvern.forge.sdk.db.repositories.debug import DebugRepository
|
||||
from skyvern.forge.sdk.db.repositories.folders import FoldersRepository
|
||||
from skyvern.forge.sdk.db.repositories.observer import ObserverRepository
|
||||
from skyvern.forge.sdk.db.repositories.organizations import OrganizationsRepository
|
||||
from skyvern.forge.sdk.db.repositories.otp import OTPRepository
|
||||
from skyvern.forge.sdk.db.repositories.schedules import SchedulesRepository
|
||||
from skyvern.forge.sdk.db.repositories.scripts import ScriptsRepository
|
||||
from skyvern.forge.sdk.db.repositories.tasks import TasksRepository
|
||||
from skyvern.forge.sdk.db.repositories.workflow_parameters import WorkflowParametersRepository
|
||||
from skyvern.forge.sdk.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||
from skyvern.forge.sdk.db.repositories.workflows import WorkflowsRepository
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
_custom_json_serializer,
|
||||
)
|
||||
|
|
@ -104,23 +105,7 @@ def _build_engine(database_string: str) -> AsyncEngine:
|
|||
__all__ = ["AgentDB", "ScheduleLimitExceededError"]
|
||||
|
||||
|
||||
class AgentDB(
|
||||
TasksMixin,
|
||||
WorkflowsMixin,
|
||||
WorkflowRunsMixin,
|
||||
WorkflowParametersMixin,
|
||||
SchedulesMixin,
|
||||
ArtifactsMixin,
|
||||
BrowserSessionsMixin,
|
||||
ScriptsMixin,
|
||||
OTPMixin,
|
||||
CredentialsMixin,
|
||||
FoldersMixin,
|
||||
OrganizationsMixin,
|
||||
ObserverMixin,
|
||||
DebugMixin,
|
||||
BaseAlchemyDB,
|
||||
):
|
||||
class AgentDB(BaseAlchemyDB):
|
||||
def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
|
||||
super().__init__(db_engine or _build_engine(database_string))
|
||||
self.debug_enabled = debug_enabled
|
||||
|
|
@ -131,6 +116,819 @@ class AgentDB(
|
|||
asyncio.Lock() if self.engine.dialect.name == "sqlite" else None
|
||||
)
|
||||
|
||||
# -- Zero-dependency repositories --
|
||||
self.tasks = TasksRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.workflows = WorkflowsRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.workflow_params = WorkflowParametersRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.credentials = CredentialRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.otp = OTPRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.debug = DebugRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.organizations = OrganizationsRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.scripts = ScriptsRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.browser_sessions = BrowserSessionsRepository(self.Session, debug_enabled, self.is_retryable_error)
|
||||
self.schedules = SchedulesRepository(
|
||||
self.Session,
|
||||
debug_enabled,
|
||||
self.is_retryable_error,
|
||||
sqlite_schedule_lock=self._sqlite_schedule_lock,
|
||||
)
|
||||
|
||||
# -- Cross-dependency repositories --
|
||||
self.workflow_runs = WorkflowRunsRepository(
|
||||
self.Session,
|
||||
debug_enabled,
|
||||
self.is_retryable_error,
|
||||
workflow_parameter_reader=self.workflow_params,
|
||||
dialect_name=self.engine.dialect.name,
|
||||
)
|
||||
self.artifacts = ArtifactsRepository(
|
||||
self.Session,
|
||||
debug_enabled,
|
||||
self.is_retryable_error,
|
||||
run_reader=self.workflow_runs,
|
||||
)
|
||||
self.folders = FoldersRepository(
|
||||
self.Session,
|
||||
debug_enabled,
|
||||
self.is_retryable_error,
|
||||
workflow_reader=self.workflows,
|
||||
)
|
||||
self.observer = ObserverRepository(
|
||||
self.Session,
|
||||
debug_enabled,
|
||||
self.is_retryable_error,
|
||||
task_reader=self.tasks,
|
||||
)
|
||||
|
||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
    """Return True when *error* looks like a dropped server connection.

    The repositories call this predicate to decide whether a failed
    statement should be retried.  Detection is purely textual: the
    lower-cased error message is scanned for the connection-loss phrase.
    """
    return "server closed the connection" in str(error).lower()
|
||||
|
||||
# ======================================================================
|
||||
# Backward-compatible delegate methods
|
||||
# TODO(SKY-62): These delegates erase type information (*args: Any -> Any).
|
||||
# Migrate callers to use typed repository attributes directly
|
||||
# (e.g., db.tasks.get_task(...) instead of db.get_task(...)), then remove.
|
||||
# ======================================================================
|
||||
|
||||
# -- Task delegates --
|
||||
|
||||
# -- Task delegates --
# Deprecated pass-throughs kept for backward compatibility while callers
# migrate to the typed repository (use db.tasks.<method>(...) directly).
# Each coroutine forwards its arguments unchanged to self.tasks.

async def create_task(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.create_task(*args, **kwargs)

async def create_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.create_step(*args, **kwargs)

async def get_task(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_task(*args, **kwargs)

async def get_tasks_by_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_tasks_by_ids(*args, **kwargs)

async def get_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_step(*args, **kwargs)

async def get_task_steps(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_task_steps(*args, **kwargs)

async def get_steps_by_task_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_steps_by_task_ids(*args, **kwargs)

async def get_total_unique_step_order_count_by_task_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_total_unique_step_order_count_by_task_ids(*args, **kwargs)

async def get_task_step_models(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_task_step_models(*args, **kwargs)

async def get_task_step_count(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_task_step_count(*args, **kwargs)

async def get_task_actions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_task_actions(*args, **kwargs)

async def get_task_actions_hydrated(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_task_actions_hydrated(*args, **kwargs)

async def get_tasks_actions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_tasks_actions(*args, **kwargs)

async def get_action_count_for_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_action_count_for_step(*args, **kwargs)

async def get_first_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_first_step(*args, **kwargs)

async def get_latest_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_latest_step(*args, **kwargs)

async def update_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.update_step(*args, **kwargs)

async def clear_task_failure_reason(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.clear_task_failure_reason(*args, **kwargs)

async def update_task(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.update_task(*args, **kwargs)

async def update_task_2fa_state(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.update_task_2fa_state(*args, **kwargs)

async def bulk_update_tasks(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.bulk_update_tasks(*args, **kwargs)

async def get_tasks(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_tasks(*args, **kwargs)

async def get_tasks_count(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_tasks_count(*args, **kwargs)

async def get_running_tasks_info_globally(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_running_tasks_info_globally(*args, **kwargs)

async def get_latest_task_by_workflow_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_latest_task_by_workflow_id(*args, **kwargs)

async def get_last_task_for_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_last_task_for_workflow_run(*args, **kwargs)

async def get_tasks_by_workflow_run_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_tasks_by_workflow_run_id(*args, **kwargs)

async def delete_task_steps(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.delete_task_steps(*args, **kwargs)

async def get_previous_actions_for_task(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.get_previous_actions_for_task(*args, **kwargs)

async def delete_task_actions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.tasks.delete_task_actions(*args, **kwargs)
|
||||
|
||||
# -- Workflow delegates --
|
||||
|
||||
# -- Workflow delegates --
# Deprecated pass-throughs; prefer db.workflows.<method>(...) directly.
# Each coroutine forwards its arguments unchanged to self.workflows.

async def create_workflow(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.create_workflow(*args, **kwargs)

async def soft_delete_workflow_by_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.soft_delete_workflow_by_id(*args, **kwargs)

async def get_workflow(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_workflow(*args, **kwargs)

async def get_workflow_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_workflow_by_permanent_id(*args, **kwargs)

async def get_workflow_for_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_workflow_for_workflow_run(*args, **kwargs)

async def get_workflow_versions_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_workflow_versions_by_permanent_id(*args, **kwargs)

async def get_workflows_by_permanent_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_workflows_by_permanent_ids(*args, **kwargs)

async def get_workflows_by_organization_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_workflows_by_organization_id(*args, **kwargs)

async def update_workflow(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.update_workflow(*args, **kwargs)

async def soft_delete_workflow_and_schedules_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.soft_delete_workflow_and_schedules_by_permanent_id(*args, **kwargs)

async def add_workflow_template(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.add_workflow_template(*args, **kwargs)

async def remove_workflow_template(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.remove_workflow_template(*args, **kwargs)

async def get_org_template_permanent_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.get_org_template_permanent_ids(*args, **kwargs)

async def is_workflow_template(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflows.is_workflow_template(*args, **kwargs)
|
||||
|
||||
# -- Workflow run delegates --
|
||||
|
||||
# -- Workflow run delegates --
# Deprecated pass-throughs; prefer db.workflow_runs.<method>(...) directly.
# Each coroutine forwards its arguments unchanged to self.workflow_runs.

async def get_running_workflow_runs_info_globally(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_running_workflow_runs_info_globally(*args, **kwargs)

async def create_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.create_workflow_run(*args, **kwargs)

async def update_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.update_workflow_run(*args, **kwargs)

async def bulk_update_workflow_runs(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.bulk_update_workflow_runs(*args, **kwargs)

async def clear_workflow_run_failure_reason(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.clear_workflow_run_failure_reason(*args, **kwargs)

async def get_all_runs(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_all_runs(*args, **kwargs)

async def get_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_run(*args, **kwargs)

async def get_last_queued_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_last_queued_workflow_run(*args, **kwargs)

async def get_workflow_runs_by_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_runs_by_ids(*args, **kwargs)

async def get_last_running_workflow_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_last_running_workflow_run(*args, **kwargs)

async def get_last_workflow_run_for_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_last_workflow_run_for_browser_session(*args, **kwargs)

async def get_last_workflow_run_for_browser_address(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_last_workflow_run_for_browser_address(*args, **kwargs)

async def get_workflows_depending_on(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflows_depending_on(*args, **kwargs)

async def get_workflow_runs(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_runs(*args, **kwargs)

async def get_workflow_runs_count(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_runs_count(*args, **kwargs)

async def get_workflow_runs_for_workflow_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_runs_for_workflow_permanent_id(*args, **kwargs)

async def get_workflow_runs_by_parent_workflow_run_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(*args, **kwargs)

async def get_workflow_run_output_parameters(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_run_output_parameters(*args, **kwargs)

async def get_workflow_run_output_parameter_by_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_run_output_parameter_by_id(*args, **kwargs)

async def create_or_update_workflow_run_output_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.create_or_update_workflow_run_output_parameter(*args, **kwargs)

async def update_workflow_run_output_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.update_workflow_run_output_parameter(*args, **kwargs)

async def create_workflow_run_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.create_workflow_run_parameter(*args, **kwargs)

async def get_workflow_run_parameters(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_run_parameters(*args, **kwargs)

async def get_workflow_run_block_errors(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_runs.get_workflow_run_block_errors(*args, **kwargs)
|
||||
|
||||
# -- Workflow parameter delegates --
|
||||
|
||||
# -- Workflow parameter delegates --
# Deprecated pass-throughs; prefer db.workflow_params.<method>(...) directly.
# NOTE(review): task-generation, AI-suggestion, copilot-chat, action and
# task-run delegates are also routed through the workflow-parameters
# repository here — that is what the wiring shows; confirm it is intended
# rather than a placement left over from the mixin split.

async def create_workflow_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_workflow_parameter(*args, **kwargs)

async def create_aws_secret_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_aws_secret_parameter(*args, **kwargs)

async def create_output_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_output_parameter(*args, **kwargs)

async def save_workflow_definition_parameters(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.save_workflow_definition_parameters(*args, **kwargs)

async def get_workflow_output_parameters(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_workflow_output_parameters(*args, **kwargs)

async def get_workflow_output_parameters_by_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_workflow_output_parameters_by_ids(*args, **kwargs)

async def get_workflow_parameters(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_workflow_parameters(*args, **kwargs)

async def get_workflow_parameter(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_workflow_parameter(*args, **kwargs)

async def create_task_generation(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_task_generation(*args, **kwargs)

async def create_ai_suggestion(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_ai_suggestion(*args, **kwargs)

async def create_workflow_copilot_chat(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_workflow_copilot_chat(*args, **kwargs)

async def update_workflow_copilot_chat(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.update_workflow_copilot_chat(*args, **kwargs)

async def create_workflow_copilot_chat_message(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_workflow_copilot_chat_message(*args, **kwargs)

async def get_workflow_copilot_chat_messages(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_workflow_copilot_chat_messages(*args, **kwargs)

async def get_workflow_copilot_chat_by_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_workflow_copilot_chat_by_id(*args, **kwargs)

async def get_latest_workflow_copilot_chat(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_latest_workflow_copilot_chat(*args, **kwargs)

async def get_task_generation_by_prompt_hash(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_task_generation_by_prompt_hash(*args, **kwargs)

async def create_action(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_action(*args, **kwargs)

async def update_action_reasoning(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.update_action_reasoning(*args, **kwargs)

async def retrieve_action_plan(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.retrieve_action_plan(*args, **kwargs)

async def create_task_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.create_task_run(*args, **kwargs)

async def update_task_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.update_task_run(*args, **kwargs)

async def update_job_run_compute_cost(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.update_job_run_compute_cost(*args, **kwargs)

async def cache_task_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.cache_task_run(*args, **kwargs)

async def get_cached_task_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_cached_task_run(*args, **kwargs)

async def get_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.workflow_params.get_run(*args, **kwargs)
|
||||
|
||||
# -- Artifact delegates --
|
||||
|
||||
# -- Artifact delegates --
# Deprecated pass-throughs; prefer db.artifacts.<method>(...) directly.
# Each coroutine forwards its arguments unchanged to self.artifacts.

async def create_artifact(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.create_artifact(*args, **kwargs)

async def bulk_create_artifacts(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.bulk_create_artifacts(*args, **kwargs)

async def get_artifacts_for_task_v2(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifacts_for_task_v2(*args, **kwargs)

async def get_artifacts_for_task_step(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifacts_for_task_step(*args, **kwargs)

async def get_artifacts_for_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifacts_for_run(*args, **kwargs)

async def get_artifact_by_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifact_by_id(*args, **kwargs)

async def get_artifacts_by_ids(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifacts_by_ids(*args, **kwargs)

async def get_artifacts_by_entity_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifacts_by_entity_id(*args, **kwargs)

async def get_artifact_by_entity_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifact_by_entity_id(*args, **kwargs)

async def get_artifact(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifact(*args, **kwargs)

async def get_artifact_for_run(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_artifact_for_run(*args, **kwargs)

async def get_latest_artifact(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_latest_artifact(*args, **kwargs)

async def get_latest_n_artifacts(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.get_latest_n_artifacts(*args, **kwargs)

async def delete_task_artifacts(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.delete_task_artifacts(*args, **kwargs)

async def delete_task_v2_artifacts(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.delete_task_v2_artifacts(*args, **kwargs)

async def update_action_screenshot_artifact_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.artifacts.update_action_screenshot_artifact_id(*args, **kwargs)
|
||||
|
||||
# -- Browser session delegates --
|
||||
|
||||
# -- Browser session delegates --
# Deprecated pass-throughs; prefer db.browser_sessions.<method>(...) directly.
# Each coroutine forwards its arguments unchanged to self.browser_sessions.

async def create_browser_profile(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.create_browser_profile(*args, **kwargs)

async def get_browser_profile(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_browser_profile(*args, **kwargs)

async def list_browser_profiles(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.list_browser_profiles(*args, **kwargs)

async def delete_browser_profile(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.delete_browser_profile(*args, **kwargs)

async def get_active_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_active_persistent_browser_sessions(*args, **kwargs)

async def get_persistent_browser_sessions_history(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_persistent_browser_sessions_history(*args, **kwargs)

async def get_persistent_browser_session_by_runnable_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_persistent_browser_session_by_runnable_id(*args, **kwargs)

async def get_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_persistent_browser_session(*args, **kwargs)

async def create_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.create_persistent_browser_session(*args, **kwargs)

async def update_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.update_persistent_browser_session(*args, **kwargs)

async def set_persistent_browser_session_browser_address(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.set_persistent_browser_session_browser_address(*args, **kwargs)

async def update_persistent_browser_session_compute_cost(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.update_persistent_browser_session_compute_cost(*args, **kwargs)

async def mark_persistent_browser_session_deleted(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.mark_persistent_browser_session_deleted(*args, **kwargs)

async def occupy_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.occupy_persistent_browser_session(*args, **kwargs)

async def release_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.release_persistent_browser_session(*args, **kwargs)

async def close_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.close_persistent_browser_session(*args, **kwargs)

async def get_all_active_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_all_active_persistent_browser_sessions(*args, **kwargs)

async def archive_browser_session_address(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.archive_browser_session_address(*args, **kwargs)

async def get_uncompleted_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_uncompleted_persistent_browser_sessions(*args, **kwargs)

async def get_debug_session_by_browser_session_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.browser_sessions.get_debug_session_by_browser_session_id(*args, **kwargs)
|
||||
|
||||
# -- Schedule delegates --
|
||||
|
||||
# -- Schedule delegates --
# Deprecated pass-throughs; prefer db.schedules.<method>(...) directly.
# Each coroutine forwards its arguments unchanged to self.schedules.

async def create_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.create_workflow_schedule(*args, **kwargs)

async def create_workflow_schedule_with_limit(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.create_workflow_schedule_with_limit(*args, **kwargs)

async def set_temporal_schedule_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.set_temporal_schedule_id(*args, **kwargs)

async def update_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.update_workflow_schedule(*args, **kwargs)

async def get_workflow_schedule_by_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.get_workflow_schedule_by_id(*args, **kwargs)

async def get_workflow_schedules(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.get_workflow_schedules(*args, **kwargs)

async def get_all_enabled_schedules(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.get_all_enabled_schedules(*args, **kwargs)

async def has_schedule_fired_since(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.has_schedule_fired_since(*args, **kwargs)

async def update_workflow_schedule_enabled(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.update_workflow_schedule_enabled(*args, **kwargs)

async def delete_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.delete_workflow_schedule(*args, **kwargs)

async def restore_workflow_schedule(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.restore_workflow_schedule(*args, **kwargs)

async def count_workflow_schedules(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.count_workflow_schedules(*args, **kwargs)

async def list_organization_schedules(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.list_organization_schedules(*args, **kwargs)

async def soft_delete_orphaned_schedules(self, *args: Any, **kwargs: Any) -> Any:
    return await self.schedules.soft_delete_orphaned_schedules(*args, **kwargs)
|
||||
|
||||
# -- Script delegates --
|
||||
|
||||
# -- Script delegates --
# Deprecated pass-throughs; prefer db.scripts.<method>(...) directly.
# Each coroutine forwards its arguments unchanged to self.scripts.

async def create_script(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.create_script(*args, **kwargs)

async def get_scripts(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_scripts(*args, **kwargs)

async def get_script(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script(*args, **kwargs)

async def get_script_revision(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_revision(*args, **kwargs)

async def get_latest_script_version(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_latest_script_version(*args, **kwargs)

async def get_script_versions(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_versions(*args, **kwargs)

async def get_script_version_stats(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_version_stats(*args, **kwargs)

async def soft_delete_script_by_revision(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.soft_delete_script_by_revision(*args, **kwargs)

async def create_script_file(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.create_script_file(*args, **kwargs)

async def create_script_block(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.create_script_block(*args, **kwargs)

async def update_script_block(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.update_script_block(*args, **kwargs)

async def get_script_files(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_files(*args, **kwargs)

async def get_script_file_by_id(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_file_by_id(*args, **kwargs)

async def get_script_file_by_path(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_file_by_path(*args, **kwargs)

async def update_script_file(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.update_script_file(*args, **kwargs)

async def get_script_block(self, *args: Any, **kwargs: Any) -> Any:
    return await self.scripts.get_script_block(*args, **kwargs)
|
||||
|
||||
async def get_script_block_by_label(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_script_block_by_label(*args, **kwargs)
|
||||
|
||||
async def get_script_blocks_by_script_revision_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_script_blocks_by_script_revision_id(*args, **kwargs)
|
||||
|
||||
async def create_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.create_workflow_script(*args, **kwargs)
|
||||
|
||||
async def get_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_workflow_script(*args, **kwargs)
|
||||
|
||||
async def get_workflow_script_by_cache_key_value(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_workflow_script_by_cache_key_value(*args, **kwargs)
|
||||
|
||||
async def get_workflow_cache_key_count(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_workflow_cache_key_count(*args, **kwargs)
|
||||
|
||||
async def get_workflow_cache_key_values(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_workflow_cache_key_values(*args, **kwargs)
|
||||
|
||||
async def delete_workflow_cache_key_value(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.delete_workflow_cache_key_value(*args, **kwargs)
|
||||
|
||||
async def delete_workflow_scripts_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.delete_workflow_scripts_by_permanent_id(*args, **kwargs)
|
||||
|
||||
async def get_workflow_scripts_by_permanent_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_workflow_scripts_by_permanent_id(*args, **kwargs)
|
||||
|
||||
async def get_workflow_runs_for_script(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_workflow_runs_for_script(*args, **kwargs)
|
||||
|
||||
async def get_script_run_stats(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_script_run_stats(*args, **kwargs)
|
||||
|
||||
async def is_script_pinned(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.is_script_pinned(*args, **kwargs)
|
||||
|
||||
async def pin_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.pin_workflow_script(*args, **kwargs)
|
||||
|
||||
async def unpin_workflow_script(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.unpin_workflow_script(*args, **kwargs)
|
||||
|
||||
async def create_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.create_fallback_episode(*args, **kwargs)
|
||||
|
||||
async def get_unreviewed_episodes(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_unreviewed_episodes(*args, **kwargs)
|
||||
|
||||
async def update_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.update_fallback_episode(*args, **kwargs)
|
||||
|
||||
async def delete_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.delete_fallback_episode(*args, **kwargs)
|
||||
|
||||
async def get_fallback_episodes(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_fallback_episodes(*args, **kwargs)
|
||||
|
||||
async def get_fallback_episodes_count(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_fallback_episodes_count(*args, **kwargs)
|
||||
|
||||
async def get_fallback_episode(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_fallback_episode(*args, **kwargs)
|
||||
|
||||
async def mark_episode_reviewed(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.mark_episode_reviewed(*args, **kwargs)
|
||||
|
||||
async def get_recent_reviewed_episodes(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_recent_reviewed_episodes(*args, **kwargs)
|
||||
|
||||
async def record_branch_hit(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.record_branch_hit(*args, **kwargs)
|
||||
|
||||
async def get_stale_branches(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.scripts.get_stale_branches(*args, **kwargs)
|
||||
|
||||
# -- OTP delegates --
|
||||
|
||||
async def get_otp_codes(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.otp.get_otp_codes(*args, **kwargs)
|
||||
|
||||
async def get_otp_codes_by_run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.otp.get_otp_codes_by_run(*args, **kwargs)
|
||||
|
||||
async def get_recent_otp_codes(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.otp.get_recent_otp_codes(*args, **kwargs)
|
||||
|
||||
async def create_otp_code(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.otp.create_otp_code(*args, **kwargs)
|
||||
|
||||
# -- Credential delegates --
|
||||
|
||||
async def create_credential(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.create_credential(*args, **kwargs)
|
||||
|
||||
async def get_credential(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.get_credential(*args, **kwargs)
|
||||
|
||||
async def get_credentials(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.get_credentials(*args, **kwargs)
|
||||
|
||||
async def update_credential(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.update_credential(*args, **kwargs)
|
||||
|
||||
async def update_credential_vault_data(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.update_credential_vault_data(*args, **kwargs)
|
||||
|
||||
async def delete_credential(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.delete_credential(*args, **kwargs)
|
||||
|
||||
async def create_organization_bitwarden_collection(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.create_organization_bitwarden_collection(*args, **kwargs)
|
||||
|
||||
async def get_organization_bitwarden_collection(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.credentials.get_organization_bitwarden_collection(*args, **kwargs)
|
||||
|
||||
# -- Folder delegates --
|
||||
|
||||
async def create_folder(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.create_folder(*args, **kwargs)
|
||||
|
||||
async def get_folders(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.get_folders(*args, **kwargs)
|
||||
|
||||
async def get_folder(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.get_folder(*args, **kwargs)
|
||||
|
||||
async def update_folder(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.update_folder(*args, **kwargs)
|
||||
|
||||
async def get_workflow_permanent_ids_in_folder(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.get_workflow_permanent_ids_in_folder(*args, **kwargs)
|
||||
|
||||
async def soft_delete_folder(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.soft_delete_folder(*args, **kwargs)
|
||||
|
||||
async def get_folder_workflow_count(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.get_folder_workflow_count(*args, **kwargs)
|
||||
|
||||
async def get_folder_workflow_counts_batch(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.get_folder_workflow_counts_batch(*args, **kwargs)
|
||||
|
||||
async def update_workflow_folder(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.folders.update_workflow_folder(*args, **kwargs)
|
||||
|
||||
# -- Organization delegates --
|
||||
|
||||
async def get_active_verification_requests(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.get_active_verification_requests(*args, **kwargs)
|
||||
|
||||
async def get_all_organizations(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.get_all_organizations(*args, **kwargs)
|
||||
|
||||
async def get_organization(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.get_organization(*args, **kwargs)
|
||||
|
||||
async def get_organization_by_domain(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.get_organization_by_domain(*args, **kwargs)
|
||||
|
||||
async def create_organization(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.create_organization(*args, **kwargs)
|
||||
|
||||
async def update_organization(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.update_organization(*args, **kwargs)
|
||||
|
||||
async def get_valid_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.get_valid_org_auth_token(*args, **kwargs)
|
||||
|
||||
async def get_valid_org_auth_tokens(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.get_valid_org_auth_tokens(*args, **kwargs)
|
||||
|
||||
async def validate_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.validate_org_auth_token(*args, **kwargs)
|
||||
|
||||
async def create_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.create_org_auth_token(*args, **kwargs)
|
||||
|
||||
async def invalidate_org_auth_tokens(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.invalidate_org_auth_tokens(*args, **kwargs)
|
||||
|
||||
# -- Observer delegates --
|
||||
|
||||
async def get_task_v2(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_task_v2(*args, **kwargs)
|
||||
|
||||
async def delete_thoughts(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.delete_thoughts(*args, **kwargs)
|
||||
|
||||
async def get_task_v2_by_workflow_run_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_task_v2_by_workflow_run_id(*args, **kwargs)
|
||||
|
||||
async def get_thought(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_thought(*args, **kwargs)
|
||||
|
||||
async def get_thoughts(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_thoughts(*args, **kwargs)
|
||||
|
||||
async def create_task_v2(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.create_task_v2(*args, **kwargs)
|
||||
|
||||
async def create_thought(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.create_thought(*args, **kwargs)
|
||||
|
||||
async def update_thought(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.update_thought(*args, **kwargs)
|
||||
|
||||
async def update_task_v2(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.update_task_v2(*args, **kwargs)
|
||||
|
||||
async def create_workflow_run_block(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.create_workflow_run_block(*args, **kwargs)
|
||||
|
||||
async def delete_workflow_run_blocks(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.delete_workflow_run_blocks(*args, **kwargs)
|
||||
|
||||
async def update_workflow_run_block(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.update_workflow_run_block(*args, **kwargs)
|
||||
|
||||
async def get_workflow_run_block(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_workflow_run_block(*args, **kwargs)
|
||||
|
||||
async def get_workflow_run_block_by_task_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_workflow_run_block_by_task_id(*args, **kwargs)
|
||||
|
||||
async def get_workflow_run_blocks(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.observer.get_workflow_run_blocks(*args, **kwargs)
|
||||
|
||||
# -- Debug delegates --
|
||||
|
||||
async def get_debug_session(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.get_debug_session(*args, **kwargs)
|
||||
|
||||
async def get_latest_block_run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.get_latest_block_run(*args, **kwargs)
|
||||
|
||||
async def get_latest_completed_block_run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.get_latest_completed_block_run(*args, **kwargs)
|
||||
|
||||
async def create_block_run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.create_block_run(*args, **kwargs)
|
||||
|
||||
async def get_latest_debug_session_for_user(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.get_latest_debug_session_for_user(*args, **kwargs)
|
||||
|
||||
async def get_debug_session_by_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.get_debug_session_by_id(*args, **kwargs)
|
||||
|
||||
async def get_workflow_runs_by_debug_session_id(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.get_workflow_runs_by_debug_session_id(*args, **kwargs)
|
||||
|
||||
async def complete_debug_sessions(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.complete_debug_sessions(*args, **kwargs)
|
||||
|
||||
async def create_debug_session(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.debug.create_debug_session(*args, **kwargs)
|
||||
|
||||
# -- NEW delegate methods (missing from branch) --
|
||||
|
||||
async def get_artifact_by_id_no_org(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.artifacts.get_artifact_by_id_no_org(*args, **kwargs)
|
||||
|
||||
async def replace_org_auth_token(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return await self.organizations.replace_org_auth_token(*args, **kwargs)
|
||||
|
|
|
|||
33
skyvern/forge/sdk/db/base_repository.py
Normal file
33
skyvern/forge/sdk/db/base_repository.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Base class for all repository classes extracted from AgentDB mixins."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
"""Base for domain-specific repositories.
|
||||
|
||||
Provides the session factory, debug flag, and retryable-error check
|
||||
that decorators like ``read_retry`` and ``db_operation`` rely on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: _SessionFactory,
|
||||
debug_enabled: bool = False,
|
||||
is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
|
||||
) -> None:
|
||||
self.Session = session_factory
|
||||
self.debug_enabled = debug_enabled
|
||||
self._is_retryable_error_fn = is_retryable_error_fn
|
||||
|
||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
||||
if self._is_retryable_error_fn:
|
||||
return self._is_retryable_error_fn(error)
|
||||
return False
|
||||
|
|
@ -1,2 +1,13 @@
|
|||
class NotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ScheduleLimitExceededError(Exception):
|
||||
"""Raised when attempting to create a schedule that would exceed the per-workflow limit."""
|
||||
|
||||
def __init__(self, organization_id: str, workflow_permanent_id: str, current_count: int, max_allowed: int):
|
||||
self.organization_id = organization_id
|
||||
self.workflow_permanent_id = workflow_permanent_id
|
||||
self.current_count = current_count
|
||||
self.max_allowed = max_allowed
|
||||
super().__init__(f"Schedule limit {max_allowed} reached (current: {current_count})")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
# DEPRECATED: These mixin classes have been superseded by the repository pattern
|
||||
# in skyvern/forge/sdk/db/repositories/. AgentDB no longer inherits from these
|
||||
# mixins — it composes repository instances instead. These files are retained
|
||||
# temporarily to avoid breaking any downstream code that may import from here.
|
||||
# TODO(SKY-62): Remove after 2026-Q2 once all imports are verified migrated.
|
||||
from skyvern.forge.sdk.db.mixins.artifacts import ArtifactsMixin
|
||||
from skyvern.forge.sdk.db.mixins.base import BaseAlchemyDB, read_retry
|
||||
from skyvern.forge.sdk.db.mixins.browser_sessions import BrowserSessionsMixin
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any
|
|||
import structlog
|
||||
from sqlalchemy import func, or_, select, text, update
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation, register_passthrough_exception
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
|
|
@ -26,20 +27,6 @@ LOG = structlog.get_logger()
|
|||
_UNSET = object()
|
||||
|
||||
|
||||
class ScheduleLimitExceededError(Exception):
|
||||
"""Raised when attempting to create a schedule that would exceed the per-workflow limit."""
|
||||
|
||||
def __init__(self, organization_id: str, workflow_permanent_id: str, current_count: int, max_allowed: int):
|
||||
self.organization_id = organization_id
|
||||
self.workflow_permanent_id = workflow_permanent_id
|
||||
self.current_count = current_count
|
||||
self.max_allowed = max_allowed
|
||||
super().__init__(f"Schedule limit {max_allowed} reached (current: {current_count})")
|
||||
|
||||
|
||||
register_passthrough_exception(ScheduleLimitExceededError)
|
||||
|
||||
|
||||
class SchedulesMixin:
|
||||
"""Database operations for workflow schedules."""
|
||||
|
||||
|
|
|
|||
42
skyvern/forge/sdk/db/protocols.py
Normal file
42
skyvern/forge/sdk/db/protocols.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Typed Protocol contracts for cross-repository dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TaskReader(Protocol):
|
||||
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None: ...
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class WorkflowReader(Protocol):
|
||||
async def get_workflow_by_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str | None = None,
|
||||
version: int | None = None,
|
||||
ignore_version: int | None = None,
|
||||
filter_deleted: bool = True,
|
||||
) -> Workflow | None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class WorkflowParameterReader(Protocol):
|
||||
async def get_workflow_parameter(
|
||||
self,
|
||||
workflow_parameter_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> WorkflowParameter | None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class RunReader(Protocol):
|
||||
async def get_run(self, run_id: str, organization_id: str | None = None) -> WorkflowRun | None: ...
|
||||
4
skyvern/forge/sdk/db/repositories/__init__.py
Normal file
4
skyvern/forge/sdk/db/repositories/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""Domain-specific repository classes extracted from AgentDB mixins."""
|
||||
|
||||
# Sentinel for distinguishing "not passed" from "passed as None" in update methods.
|
||||
_UNSET = object()
|
||||
452
skyvern/forge/sdk/db/repositories/artifacts.py
Normal file
452
skyvern/forge/sdk/db/repositories/artifacts.py
Normal file
|
|
@ -0,0 +1,452 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, delete, or_, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.models import ActionModel, ArtifactModel
|
||||
from skyvern.forge.sdk.db.protocols import RunReader
|
||||
from skyvern.forge.sdk.db.utils import convert_to_artifact
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ArtifactsRepository(BaseRepository):
|
||||
"""Database operations for artifact management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: _SessionFactory,
|
||||
debug_enabled: bool = False,
|
||||
is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
|
||||
run_reader: RunReader | None = None,
|
||||
) -> None:
|
||||
super().__init__(session_factory, debug_enabled, is_retryable_error_fn)
|
||||
self._run_reader = run_reader
|
||||
|
||||
@db_operation("create_artifact")
|
||||
async def create_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
artifact_type: str,
|
||||
uri: str,
|
||||
organization_id: str,
|
||||
step_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
task_v2_id: str | None = None,
|
||||
run_id: str | None = None,
|
||||
thought_id: str | None = None,
|
||||
ai_suggestion_id: str | None = None,
|
||||
) -> Artifact:
|
||||
async with self.Session() as session:
|
||||
new_artifact = ArtifactModel(
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
uri=uri,
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
observer_cruise_id=task_v2_id,
|
||||
observer_thought_id=thought_id,
|
||||
run_id=run_id,
|
||||
ai_suggestion_id=ai_suggestion_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
await session.commit()
|
||||
await session.refresh(new_artifact)
|
||||
return convert_to_artifact(new_artifact, self.debug_enabled)
|
||||
|
||||
@db_operation("bulk_create_artifacts")
|
||||
async def bulk_create_artifacts(
|
||||
self,
|
||||
artifact_models: list[ArtifactModel],
|
||||
) -> list[Artifact]:
|
||||
"""
|
||||
Bulk create multiple artifacts in a single database transaction.
|
||||
|
||||
Args:
|
||||
artifact_models: List of ArtifactModel instances to insert
|
||||
|
||||
Returns:
|
||||
List of created Artifact objects
|
||||
"""
|
||||
if not artifact_models:
|
||||
return []
|
||||
|
||||
async with self.Session() as session:
|
||||
session.add_all(artifact_models)
|
||||
await session.commit()
|
||||
|
||||
# Refresh all artifacts to get their created_at and modified_at values
|
||||
for artifact in artifact_models:
|
||||
await session.refresh(artifact)
|
||||
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifact_models]
|
||||
|
||||
@db_operation("get_artifacts_for_task_v2")
|
||||
async def get_artifacts_for_task_v2(
|
||||
self,
|
||||
task_v2_id: str,
|
||||
organization_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
) -> list[Artifact]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(ArtifactModel)
|
||||
.filter_by(observer_cruise_id=task_v2_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
if artifact_types:
|
||||
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
query = query.order_by(ArtifactModel.created_at)
|
||||
if artifacts := (await session.scalars(query)).all():
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
else:
|
||||
return []
|
||||
|
||||
@db_operation("get_artifacts_for_task_step")
|
||||
async def get_artifacts_for_task_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[Artifact]:
|
||||
async with self.Session() as session:
|
||||
if artifacts := (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(ArtifactModel.created_at)
|
||||
)
|
||||
).all():
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
else:
|
||||
return []
|
||||
|
||||
@db_operation("get_artifacts_for_run")
|
||||
async def get_artifacts_for_run(
|
||||
self,
|
||||
run_id: str,
|
||||
organization_id: str,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
group_by_type: bool = False,
|
||||
sort_by: str = "created_at",
|
||||
) -> dict[ArtifactType, list[Artifact]] | list[Artifact]:
|
||||
"""Return artifacts associated with a run.
|
||||
|
||||
Args:
|
||||
run_id: The ID of the run to get artifacts for
|
||||
organization_id: The ID of the organization that owns the run
|
||||
artifact_types: Optional list of artifact types to filter by
|
||||
group_by_type: If True, returns a dictionary mapping artifact types to lists of artifacts.
|
||||
If False, returns a flat list of artifacts. Defaults to False.
|
||||
sort_by: Field to sort artifacts by. Must be one of: 'created_at', 'step_id', 'task_id'.
|
||||
Defaults to 'created_at'.
|
||||
|
||||
Returns:
|
||||
If group_by_type is True, returns a dictionary mapping artifact types to lists of artifacts.
|
||||
If group_by_type is False, returns a list of artifacts sorted by the specified field.
|
||||
|
||||
Raises:
|
||||
ValueError: If sort_by is not one of the allowed values
|
||||
"""
|
||||
allowed_sort_fields = {"created_at", "step_id", "task_id"}
|
||||
if sort_by not in allowed_sort_fields:
|
||||
raise ValueError(f"sort_by must be one of {allowed_sort_fields}")
|
||||
if self._run_reader is None:
|
||||
raise RuntimeError("run_reader dependency not set")
|
||||
run = await self._run_reader.get_run(run_id, organization_id=organization_id)
|
||||
if not run:
|
||||
return []
|
||||
|
||||
async with self.Session() as session:
|
||||
query = select(ArtifactModel).filter_by(organization_id=organization_id)
|
||||
|
||||
query = query.filter_by(run_id=run.run_id)
|
||||
|
||||
if artifact_types:
|
||||
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == "created_at":
|
||||
query = query.order_by(ArtifactModel.created_at)
|
||||
elif sort_by == "step_id":
|
||||
query = query.order_by(ArtifactModel.step_id, ArtifactModel.created_at)
|
||||
elif sort_by == "task_id":
|
||||
query = query.order_by(ArtifactModel.task_id, ArtifactModel.created_at)
|
||||
|
||||
# Execute query and convert to Artifact objects
|
||||
artifacts = [
|
||||
convert_to_artifact(artifact, self.debug_enabled) for artifact in (await session.scalars(query)).all()
|
||||
]
|
||||
|
||||
# Group artifacts by type if requested
|
||||
if group_by_type:
|
||||
result: dict[ArtifactType, list[Artifact]] = {}
|
||||
for artifact in artifacts:
|
||||
if artifact.artifact_type not in result:
|
||||
result[artifact.artifact_type] = []
|
||||
result[artifact.artifact_type].append(artifact)
|
||||
return result
|
||||
|
||||
return artifacts
|
||||
|
||||
@db_operation("get_artifact_by_id")
|
||||
async def get_artifact_by_id(
|
||||
self,
|
||||
artifact_id: str,
|
||||
organization_id: str,
|
||||
) -> Artifact | None:
|
||||
async with self.Session() as session:
|
||||
if artifact := (
|
||||
await session.scalars(
|
||||
select(ArtifactModel).filter_by(artifact_id=artifact_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_artifact_by_id_no_org(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> Artifact | None:
|
||||
"""Fetch an artifact by ID without an organization filter.
|
||||
|
||||
Only use this when the caller has already verified authorization through
|
||||
an out-of-band mechanism (e.g. a valid HMAC-signed URL).
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
if artifact := (await session.scalars(select(ArtifactModel).filter_by(artifact_id=artifact_id))).first():
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
else:
|
||||
return None
|
||||
|
||||
@db_operation("get_artifacts_by_ids")
|
||||
async def get_artifacts_by_ids(
|
||||
self,
|
||||
artifact_ids: list[str],
|
||||
organization_id: str,
|
||||
) -> list[Artifact]:
|
||||
if not artifact_ids:
|
||||
return []
|
||||
async with self.Session() as session:
|
||||
artifacts = (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter(ArtifactModel.artifact_id.in_(artifact_ids))
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).all()
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
|
||||
@db_operation("get_artifacts_by_entity_id")
|
||||
async def get_artifacts_by_entity_id(
|
||||
self,
|
||||
*,
|
||||
organization_id: str | None,
|
||||
artifact_type: ArtifactType | None = None,
|
||||
task_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
thought_id: str | None = None,
|
||||
task_v2_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Artifact]:
|
||||
async with self.Session() as session:
|
||||
# Build base query
|
||||
query = select(ArtifactModel)
|
||||
|
||||
if artifact_type is not None:
|
||||
query = query.filter_by(artifact_type=artifact_type)
|
||||
if task_id is not None:
|
||||
query = query.filter_by(task_id=task_id)
|
||||
if step_id is not None:
|
||||
query = query.filter_by(step_id=step_id)
|
||||
if workflow_run_id is not None:
|
||||
query = query.filter_by(workflow_run_id=workflow_run_id)
|
||||
if workflow_run_block_id is not None:
|
||||
query = query.filter_by(workflow_run_block_id=workflow_run_block_id)
|
||||
if thought_id is not None:
|
||||
query = query.filter_by(observer_thought_id=thought_id)
|
||||
if task_v2_id is not None:
|
||||
query = query.filter_by(observer_cruise_id=task_v2_id)
|
||||
# Handle backward compatibility where old artifact rows were stored with organization_id NULL
|
||||
if organization_id is not None:
|
||||
query = query.filter(
|
||||
or_(ArtifactModel.organization_id == organization_id, ArtifactModel.organization_id.is_(None))
|
||||
)
|
||||
|
||||
query = query.order_by(ArtifactModel.created_at.desc())
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
artifacts = (await session.scalars(query)).all()
|
||||
LOG.debug("Artifacts fetched", count=len(artifacts))
|
||||
return [convert_to_artifact(a, self.debug_enabled) for a in artifacts]
|
||||
|
||||
@db_operation("get_artifact_by_entity_id")
|
||||
async def get_artifact_by_entity_id(
|
||||
self,
|
||||
*,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
thought_id: str | None = None,
|
||||
task_v2_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
artifacts = await self.get_artifacts_by_entity_id(
|
||||
organization_id=organization_id,
|
||||
artifact_type=artifact_type,
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
thought_id=thought_id,
|
||||
task_v2_id=task_v2_id,
|
||||
limit=1,
|
||||
)
|
||||
return artifacts[0] if artifacts else None
|
||||
|
||||
@db_operation("get_artifact")
|
||||
async def get_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
async with self.Session() as session:
|
||||
artifact = (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(artifact_type=artifact_type)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("get_artifact_for_run")
|
||||
async def get_artifact_for_run(
|
||||
self,
|
||||
run_id: str,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
async with self.Session() as session:
|
||||
artifact = (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter(ArtifactModel.run_id == run_id)
|
||||
.filter(ArtifactModel.artifact_type == artifact_type)
|
||||
.filter(ArtifactModel.organization_id == organization_id)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("get_latest_artifact")
|
||||
async def get_latest_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
artifacts = await self.get_latest_n_artifacts(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
artifact_types=artifact_types,
|
||||
organization_id=organization_id,
|
||||
n=1,
|
||||
)
|
||||
if artifacts:
|
||||
return artifacts[0]
|
||||
return None
|
||||
|
||||
@db_operation("get_latest_n_artifacts")
|
||||
async def get_latest_n_artifacts(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
organization_id: str | None = None,
|
||||
n: int = 1,
|
||||
) -> list[Artifact] | None:
|
||||
async with self.Session() as session:
|
||||
artifact_query = select(ArtifactModel).filter_by(task_id=task_id)
|
||||
if organization_id:
|
||||
artifact_query = artifact_query.filter_by(organization_id=organization_id)
|
||||
if step_id:
|
||||
artifact_query = artifact_query.filter_by(step_id=step_id)
|
||||
if artifact_types:
|
||||
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
artifacts = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).fetchmany(n)
|
||||
if artifacts:
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
return None
|
||||
|
||||
@db_operation("delete_task_artifacts")
|
||||
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
|
||||
async with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(ArtifactModel).where(
|
||||
and_(
|
||||
ArtifactModel.organization_id == organization_id,
|
||||
ArtifactModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("delete_task_v2_artifacts")
|
||||
async def delete_task_v2_artifacts(self, task_v2_id: str, organization_id: str | None = None) -> None:
|
||||
async with self.Session() as session:
|
||||
stmt = delete(ArtifactModel).where(
|
||||
and_(
|
||||
ArtifactModel.observer_cruise_id == task_v2_id,
|
||||
ArtifactModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("update_action_screenshot_artifact_id")
|
||||
async def update_action_screenshot_artifact_id(
|
||||
self, *, organization_id: str, action_id: str, screenshot_artifact_id: str
|
||||
) -> None:
|
||||
async with self.Session() as session:
|
||||
await session.execute(
|
||||
update(ActionModel)
|
||||
.where(ActionModel.action_id == action_id, ActionModel.organization_id == organization_id)
|
||||
.values(screenshot_artifact_id=screenshot_artifact_id)
|
||||
)
|
||||
await session.commit()
|
||||
472
skyvern/forge/sdk/db/repositories/browser_sessions.py
Normal file
472
skyvern/forge/sdk/db/repositories/browser_sessions.py
Normal file
|
|
@ -0,0 +1,472 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import asc, case, select
|
||||
|
||||
from skyvern.exceptions import BrowserProfileNotFound
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import read_retry
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
BrowserProfileModel,
|
||||
DebugSessionModel,
|
||||
PersistentBrowserSessionModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import serialize_proxy_location
|
||||
from skyvern.forge.sdk.schemas.browser_profiles import BrowserProfile
|
||||
from skyvern.forge.sdk.schemas.debug_sessions import DebugSession
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
|
||||
Extensions,
|
||||
PersistentBrowserSession,
|
||||
PersistentBrowserType,
|
||||
)
|
||||
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class BrowserSessionsRepository(BaseRepository):
    """Database operations for browser profiles and persistent browser sessions."""

    @db_operation("create_browser_profile")
    async def create_browser_profile(
        self,
        organization_id: str,
        name: str,
        description: str | None = None,
    ) -> BrowserProfile:
        """Create a browser profile for an organization and return the validated schema."""
        async with self.Session() as session:
            browser_profile = BrowserProfileModel(
                organization_id=organization_id,
                name=name,
                description=description,
            )
            session.add(browser_profile)
            await session.commit()
            # Refresh so DB-generated fields (id, timestamps) are populated before validation.
            await session.refresh(browser_profile)
            return BrowserProfile.model_validate(browser_profile)

    @db_operation("get_browser_profile")
    async def get_browser_profile(
        self,
        profile_id: str,
        organization_id: str,
        include_deleted: bool = False,
    ) -> BrowserProfile | None:
        """Fetch a browser profile by ID within an organization, or None if not found.

        Soft-deleted profiles are excluded unless ``include_deleted`` is True.
        """
        async with self.Session() as session:
            query = (
                select(BrowserProfileModel)
                .filter_by(browser_profile_id=profile_id)
                .filter_by(organization_id=organization_id)
            )
            if not include_deleted:
                query = query.filter(BrowserProfileModel.deleted_at.is_(None))

            browser_profile = (await session.scalars(query)).first()
            if not browser_profile:
                return None
            return BrowserProfile.model_validate(browser_profile)

    @db_operation("list_browser_profiles")
    async def list_browser_profiles(
        self,
        organization_id: str,
        include_deleted: bool = False,
    ) -> list[BrowserProfile]:
        """List an organization's browser profiles, oldest first."""
        async with self.Session() as session:
            query = select(BrowserProfileModel).filter_by(organization_id=organization_id)
            if not include_deleted:
                query = query.filter(BrowserProfileModel.deleted_at.is_(None))
            browser_profiles = await session.scalars(query.order_by(asc(BrowserProfileModel.created_at)))
            return [BrowserProfile.model_validate(profile) for profile in browser_profiles.all()]

    @db_operation("delete_browser_profile")
    async def delete_browser_profile(
        self,
        profile_id: str,
        organization_id: str,
    ) -> None:
        """Soft-delete a browser profile by stamping ``deleted_at``.

        Raises:
            BrowserProfileNotFound: if no live profile matches the ID/organization.
        """
        async with self.Session() as session:
            query = (
                select(BrowserProfileModel)
                .filter_by(browser_profile_id=profile_id)
                .filter_by(organization_id=organization_id)
                .filter(BrowserProfileModel.deleted_at.is_(None))
            )
            browser_profile = (await session.scalars(query)).first()
            if not browser_profile:
                raise BrowserProfileNotFound(profile_id=profile_id, organization_id=organization_id)
            browser_profile.deleted_at = datetime.now(timezone.utc)
            await session.commit()

    @db_operation("get_active_persistent_browser_sessions")
    async def get_active_persistent_browser_sessions(
        self,
        organization_id: str,
        active_hours: int = 24,
    ) -> list[PersistentBrowserSession]:
        """Get all active persistent browser sessions for an organization."""
        async with self.Session() as session:
            # "Active" = not deleted, not completed, and created within the lookback window.
            result = await session.execute(
                select(PersistentBrowserSessionModel)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
                .filter_by(completed_at=None)
                .filter(
                    PersistentBrowserSessionModel.created_at
                    > datetime.now(timezone.utc) - timedelta(hours=active_hours)
                )
            )
            sessions = result.scalars().all()
            return [PersistentBrowserSession.model_validate(session) for session in sessions]

    @db_operation("get_persistent_browser_sessions_history")
    async def get_persistent_browser_sessions_history(
        self,
        organization_id: str,
        page: int = 1,
        page_size: int = 10,
        lookback_hours: int = 24 * 7,
    ) -> list[PersistentBrowserSession]:
        """Get persistent browser sessions history for an organization."""
        async with self.Session() as session:
            # Sort key: running ("open") sessions rank before everything else.
            open_first = case(
                (
                    PersistentBrowserSessionModel.status == "running",
                    0,  # open
                ),
                else_=1,  # not open
            )

            result = await session.execute(
                select(PersistentBrowserSessionModel)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
                .filter(
                    PersistentBrowserSessionModel.created_at
                    > (datetime.now(timezone.utc) - timedelta(hours=lookback_hours))
                )
                .order_by(
                    open_first.asc(),  # open sessions first
                    PersistentBrowserSessionModel.created_at.desc(),  # then newest within each group
                )
                .offset((page - 1) * page_size)
                .limit(page_size)
            )
            sessions = result.scalars().all()
            return [PersistentBrowserSession.model_validate(session) for session in sessions]

    @read_retry()
    @db_operation("get_persistent_browser_session_by_runnable_id", log_errors=False)
    async def get_persistent_browser_session_by_runnable_id(
        self, runnable_id: str, organization_id: str | None = None
    ) -> PersistentBrowserSession | None:
        """Get a specific persistent browser session."""
        async with self.Session() as session:
            # Only live, not-yet-completed sessions qualify; org filter is optional.
            query = (
                select(PersistentBrowserSessionModel)
                .filter_by(runnable_id=runnable_id)
                .filter_by(deleted_at=None)
                .filter_by(completed_at=None)
            )
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            persistent_browser_session = (await session.scalars(query)).first()
            if persistent_browser_session:
                return PersistentBrowserSession.model_validate(persistent_browser_session)
            return None

    @db_operation("get_persistent_browser_session")
    async def get_persistent_browser_session(
        self,
        session_id: str,
        organization_id: str | None = None,
    ) -> PersistentBrowserSession | None:
        """Get a specific persistent browser session."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                return PersistentBrowserSession.model_validate(persistent_browser_session)
            return None

    @db_operation("create_persistent_browser_session")
    async def create_persistent_browser_session(
        self,
        organization_id: str,
        runnable_type: str | None = None,
        runnable_id: str | None = None,
        timeout_minutes: int | None = None,
        proxy_location: ProxyLocationInput = ProxyLocation.RESIDENTIAL,
        extensions: list[Extensions] | None = None,
        browser_type: PersistentBrowserType | None = None,
        browser_profile_id: str | None = None,
    ) -> PersistentBrowserSession:
        """Create a new persistent browser session."""
        # Enum extensions are stored as their raw string values.
        extensions_str: list[str] | None = (
            [extension.value for extension in extensions] if extensions is not None else None
        )
        async with self.Session() as session:
            browser_session = PersistentBrowserSessionModel(
                organization_id=organization_id,
                runnable_type=runnable_type,
                runnable_id=runnable_id,
                timeout_minutes=timeout_minutes,
                proxy_location=serialize_proxy_location(proxy_location),
                extensions=extensions_str,
                browser_type=browser_type.value if browser_type else None,
                browser_profile_id=browser_profile_id,
            )
            session.add(browser_session)
            await session.commit()
            await session.refresh(browser_session)
            return PersistentBrowserSession.model_validate(browser_session)

    @db_operation("update_persistent_browser_session")
    async def update_persistent_browser_session(
        self,
        browser_session_id: str,
        *,
        status: str | None = None,
        timeout_minutes: int | None = None,
        organization_id: str | None = None,
        completed_at: datetime | None = None,
    ) -> PersistentBrowserSession:
        """Update selected fields of a live persistent browser session.

        Only truthy arguments are applied; passing None leaves a field untouched.

        Raises:
            NotFoundError: if no live session matches the ID/organization.
        """
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=browser_session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if not persistent_browser_session:
                raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")

            if status:
                persistent_browser_session.status = status
            if timeout_minutes:
                persistent_browser_session.timeout_minutes = timeout_minutes
            if completed_at:
                persistent_browser_session.completed_at = completed_at

            await session.commit()
            await session.refresh(persistent_browser_session)
            return PersistentBrowserSession.model_validate(persistent_browser_session)

    @db_operation("set_persistent_browser_session_browser_address")
    async def set_persistent_browser_session_browser_address(
        self,
        browser_session_id: str,
        browser_address: str | None,
        ip_address: str | None,
        ecs_task_arn: str | None,
        organization_id: str | None = None,
    ) -> None:
        """Set the browser address for a persistent browser session."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=browser_session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                if browser_address:
                    persistent_browser_session.browser_address = browser_address
                    # once the address is set, the session is started
                    persistent_browser_session.started_at = datetime.now(timezone.utc)
                if ip_address:
                    persistent_browser_session.ip_address = ip_address
                if ecs_task_arn:
                    persistent_browser_session.ecs_task_arn = ecs_task_arn
                await session.commit()
                await session.refresh(persistent_browser_session)
            else:
                raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")

    @db_operation("update_persistent_browser_session_compute_cost")
    async def update_persistent_browser_session_compute_cost(
        self,
        session_id: str,
        organization_id: str,
        instance_type: str,
        vcpu_millicores: int,
        memory_mb: int,
        duration_ms: int,
        compute_cost: float,
    ) -> None:
        """Update the compute cost fields for a persistent browser session"""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                persistent_browser_session.instance_type = instance_type
                persistent_browser_session.vcpu_millicores = vcpu_millicores
                persistent_browser_session.memory_mb = memory_mb
                persistent_browser_session.duration_ms = duration_ms
                persistent_browser_session.compute_cost = compute_cost
                await session.commit()
                await session.refresh(persistent_browser_session)
            else:
                raise NotFoundError(f"PersistentBrowserSession {session_id} not found")

    @db_operation("mark_persistent_browser_session_deleted")
    async def mark_persistent_browser_session_deleted(self, session_id: str, organization_id: str) -> None:
        """Mark a persistent browser session as deleted."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            if persistent_browser_session:
                # Soft delete: the row remains for history/auditing.
                persistent_browser_session.deleted_at = datetime.now(timezone.utc)
                await session.commit()
                await session.refresh(persistent_browser_session)
            else:
                raise NotFoundError(f"PersistentBrowserSession {session_id} not found")

    @db_operation("occupy_persistent_browser_session")
    async def occupy_persistent_browser_session(
        self, session_id: str, runnable_type: str, runnable_id: str, organization_id: str
    ) -> None:
        """Occupy a specific persistent browser session."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                # Record which runnable currently owns this session.
                persistent_browser_session.runnable_type = runnable_type
                persistent_browser_session.runnable_id = runnable_id
                await session.commit()
                await session.refresh(persistent_browser_session)
            else:
                raise NotFoundError(f"PersistentBrowserSession {session_id} not found")

    @db_operation("release_persistent_browser_session")
    async def release_persistent_browser_session(
        self,
        session_id: str,
        organization_id: str,
    ) -> PersistentBrowserSession:
        """Release a specific persistent browser session."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                # Clear ownership so another runnable can occupy the session.
                persistent_browser_session.runnable_type = None
                persistent_browser_session.runnable_id = None
                await session.commit()
                await session.refresh(persistent_browser_session)
                return PersistentBrowserSession.model_validate(persistent_browser_session)
            else:
                raise NotFoundError(f"PersistentBrowserSession {session_id} not found")

    @db_operation("close_persistent_browser_session")
    async def close_persistent_browser_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession:
        """Close a specific persistent browser session."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                # Idempotent: an already-completed session is returned unchanged.
                if persistent_browser_session.completed_at:
                    return PersistentBrowserSession.model_validate(persistent_browser_session)
                persistent_browser_session.completed_at = datetime.now(timezone.utc)
                persistent_browser_session.status = "completed"
                await session.commit()
                await session.refresh(persistent_browser_session)
                return PersistentBrowserSession.model_validate(persistent_browser_session)
            raise NotFoundError(f"PersistentBrowserSession {session_id} not found")

    @db_operation("get_all_active_persistent_browser_sessions")
    async def get_all_active_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
        """Get all active persistent browser sessions across all organizations."""
        # NOTE: returns ORM models, not validated schemas, unlike most methods here.
        async with self.Session() as session:
            result = await session.execute(select(PersistentBrowserSessionModel).filter_by(deleted_at=None))
            return result.scalars().all()

    @db_operation("archive_browser_session_address")
    async def archive_browser_session_address(self, session_id: str, organization_id: str) -> None:
        """Suffix browser_address with a unique tag so the unique constraint
        no longer blocks new sessions that reuse the same local address."""
        async with self.Session() as session:
            row = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()

            # No-op when the session is missing or has no address to archive.
            if not row or not row.browser_address:
                return
            # Already archived once; don't stack suffixes.
            if "::closed::" in row.browser_address:
                return

            row.browser_address = f"{row.browser_address}::closed::{uuid.uuid4().hex}"
            await session.commit()

    @db_operation("get_uncompleted_persistent_browser_sessions")
    async def get_uncompleted_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
        """Get all browser sessions that have not been completed or deleted."""
        async with self.Session() as session:
            result = await session.execute(
                select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
            )
            return result.scalars().all()

    @db_operation("get_debug_session_by_browser_session_id")
    async def get_debug_session_by_browser_session_id(
        self,
        browser_session_id: str,
        organization_id: str,
    ) -> DebugSession | None:
        """Fetch the live debug session attached to a browser session, or None."""
        async with self.Session() as session:
            query = (
                select(DebugSessionModel)
                .filter_by(browser_session_id=browser_session_id)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
            )
            model = (await session.scalars(query)).first()
            return DebugSession.model_validate(model) if model else None
|
||||
209
skyvern/forge/sdk/db/repositories/credentials.py
Normal file
209
skyvern/forge/sdk/db/repositories/credentials.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import CredentialModel, OrganizationBitwardenCollectionModel
|
||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType, CredentialVaultType
|
||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||
|
||||
from . import _UNSET
|
||||
|
||||
|
||||
class CredentialRepository(BaseRepository):
|
||||
"""Database operations for credential and Bitwarden collection management."""
|
||||
|
||||
@db_operation("create_credential")
|
||||
async def create_credential(
|
||||
self,
|
||||
organization_id: str,
|
||||
name: str,
|
||||
vault_type: CredentialVaultType,
|
||||
item_id: str,
|
||||
credential_type: CredentialType,
|
||||
username: str | None,
|
||||
totp_type: str,
|
||||
card_last4: str | None,
|
||||
card_brand: str | None,
|
||||
totp_identifier: str | None = None,
|
||||
secret_label: str | None = None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = CredentialModel(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=credential_type,
|
||||
username=username,
|
||||
totp_type=totp_type,
|
||||
totp_identifier=totp_identifier,
|
||||
card_last4=card_last4,
|
||||
card_brand=card_brand,
|
||||
secret_label=secret_label,
|
||||
)
|
||||
session.add(credential)
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return Credential.model_validate(credential)
|
||||
|
||||
@db_operation("get_credential")
|
||||
async def get_credential(self, credential_id: str, organization_id: str) -> Credential | None:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
)
|
||||
).first()
|
||||
if credential:
|
||||
return Credential.model_validate(credential)
|
||||
return None
|
||||
|
||||
@db_operation("get_credentials")
|
||||
async def get_credentials(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
vault_type: str | None = None,
|
||||
) -> list[Credential]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(CredentialModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
)
|
||||
if vault_type is not None:
|
||||
query = query.filter(CredentialModel.vault_type == vault_type)
|
||||
credentials = (
|
||||
await session.scalars(
|
||||
query.order_by(CredentialModel.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
|
||||
)
|
||||
).all()
|
||||
return [Credential.model_validate(credential) for credential in credentials]
|
||||
|
||||
@db_operation("update_credential")
|
||||
async def update_credential(
|
||||
self,
|
||||
credential_id: str,
|
||||
organization_id: str,
|
||||
name: str | None = None,
|
||||
browser_profile_id: str | None | object = _UNSET,
|
||||
tested_url: str | None | object = _UNSET,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
)
|
||||
).first()
|
||||
if not credential:
|
||||
raise NotFoundError(f"Credential {credential_id} not found")
|
||||
if name is not None:
|
||||
credential.name = name
|
||||
if browser_profile_id is not _UNSET:
|
||||
credential.browser_profile_id = browser_profile_id
|
||||
if tested_url is not _UNSET:
|
||||
credential.tested_url = tested_url
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return Credential.model_validate(credential)
|
||||
|
||||
@db_operation("update_credential_vault_data")
|
||||
async def update_credential_vault_data(
|
||||
self,
|
||||
credential_id: str,
|
||||
organization_id: str,
|
||||
item_id: str,
|
||||
name: str,
|
||||
credential_type: CredentialType,
|
||||
username: str | None = None,
|
||||
totp_type: str = "none",
|
||||
totp_identifier: str | None = None,
|
||||
card_last4: str | None = None,
|
||||
card_brand: str | None = None,
|
||||
secret_label: str | None = None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
.with_for_update()
|
||||
)
|
||||
).first()
|
||||
if not credential:
|
||||
raise NotFoundError(f"Credential {credential_id} not found")
|
||||
credential.item_id = item_id
|
||||
credential.name = name
|
||||
credential.credential_type = credential_type
|
||||
credential.username = username
|
||||
credential.totp_type = totp_type
|
||||
credential.totp_identifier = totp_identifier
|
||||
credential.card_last4 = card_last4
|
||||
credential.card_brand = card_brand
|
||||
credential.secret_label = secret_label
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return Credential.model_validate(credential)
|
||||
|
||||
@db_operation("delete_credential")
async def delete_credential(self, credential_id: str, organization_id: str) -> None:
    """Soft-delete a credential by stamping its deleted_at timestamp.

    Raises:
        NotFoundError: if no live (not already deleted) credential matches
            the id/org pair.
    """
    async with self.Session() as session:
        credential = (
            await session.scalars(
                select(CredentialModel)
                .filter_by(credential_id=credential_id)
                .filter_by(organization_id=organization_id)
                # Match the other credential queries in this repository:
                # already-deleted rows are invisible, so a repeat delete
                # raises NotFoundError instead of silently re-stamping
                # deleted_at with a newer timestamp.
                .filter(CredentialModel.deleted_at.is_(None))
            )
        ).first()
        if not credential:
            raise NotFoundError(f"Credential {credential_id} not found")
        credential.deleted_at = datetime.now(timezone.utc)
        await session.commit()
        # No refresh: nothing is returned, so re-loading the row is wasted I/O.
|
||||
|
||||
@db_operation("create_organization_bitwarden_collection")
async def create_organization_bitwarden_collection(
    self,
    organization_id: str,
    collection_id: str,
) -> OrganizationBitwardenCollection:
    """Persist the link between an organization and a Bitwarden collection."""
    async with self.Session() as session:
        record = OrganizationBitwardenCollectionModel(
            organization_id=organization_id,
            collection_id=collection_id,
        )
        session.add(record)
        await session.commit()
        # Re-load server-generated columns (ids, timestamps) before validating.
        await session.refresh(record)
        return OrganizationBitwardenCollection.model_validate(record)
|
||||
|
||||
@db_operation("get_organization_bitwarden_collection")
async def get_organization_bitwarden_collection(
    self,
    organization_id: str,
) -> OrganizationBitwardenCollection | None:
    """Return the organization's live Bitwarden collection link, or None."""
    async with self.Session() as session:
        query = (
            select(OrganizationBitwardenCollectionModel)
            .filter_by(organization_id=organization_id)
            # filter_by(deleted_at=None) compiles to "deleted_at IS NULL".
            .filter_by(deleted_at=None)
        )
        record = (await session.scalars(query)).first()
        return OrganizationBitwardenCollection.model_validate(record) if record else None
|
||||
254
skyvern/forge/sdk/db/repositories/debug.py
Normal file
254
skyvern/forge/sdk/db/repositories/debug.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
BlockRunModel,
|
||||
DebugSessionModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession, DebugSessionRun
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
|
||||
|
||||
class DebugRepository(BaseRepository):
    """Database operations for debug sessions and block runs."""

    async def _find_latest_created_debug_session(
        self,
        *,
        organization_id: str,
        user_id: str,
        workflow_permanent_id: str,
    ) -> DebugSession | None:
        """Newest live debug session in status "created" for the user/workflow pair.

        Shared by get_debug_session and get_latest_debug_session_for_user,
        which previously duplicated this exact query verbatim.
        """
        async with self.Session() as session:
            query = (
                select(DebugSessionModel)
                .filter_by(organization_id=organization_id)
                .filter_by(user_id=user_id)
                .filter_by(workflow_permanent_id=workflow_permanent_id)
                .filter_by(deleted_at=None)
                .filter_by(status="created")
                .order_by(DebugSessionModel.created_at.desc())
            )
            model = (await session.scalars(query)).first()
            return DebugSession.model_validate(model) if model else None

    @db_operation("get_debug_session")
    async def get_debug_session(
        self,
        *,
        organization_id: str,
        user_id: str,
        workflow_permanent_id: str,
    ) -> DebugSession | None:
        """Return the newest open ("created") debug session for this user/workflow, or None."""
        return await self._find_latest_created_debug_session(
            organization_id=organization_id,
            user_id=user_id,
            workflow_permanent_id=workflow_permanent_id,
        )

    @db_operation("get_latest_block_run")
    async def get_latest_block_run(
        self,
        *,
        organization_id: str,
        user_id: str,
        block_label: str,
    ) -> BlockRun | None:
        """Return the most recently created block run for a block label, or None."""
        async with self.Session() as session:
            query = (
                select(BlockRunModel)
                .filter_by(organization_id=organization_id)
                .filter_by(user_id=user_id)
                .filter_by(block_label=block_label)
                .order_by(BlockRunModel.created_at.desc())
            )
            model = (await session.scalars(query)).first()
            return BlockRun.model_validate(model) if model else None

    @db_operation("get_latest_completed_block_run")
    async def get_latest_completed_block_run(
        self,
        *,
        organization_id: str,
        user_id: str,
        block_label: str,
        workflow_permanent_id: str,
    ) -> BlockRun | None:
        """Return the newest block run whose parent workflow run completed, or None.

        Joins to the workflow run so only runs of *workflow_permanent_id*
        that finished with status ``completed`` are considered.
        """
        async with self.Session() as session:
            query = (
                select(BlockRunModel)
                .join(WorkflowRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .filter(BlockRunModel.organization_id == organization_id)
                .filter(BlockRunModel.user_id == user_id)
                .filter(BlockRunModel.block_label == block_label)
                .filter(WorkflowRunModel.status == WorkflowRunStatus.completed)
                .filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id)
                .order_by(BlockRunModel.created_at.desc())
            )
            model = (await session.scalars(query)).first()
            return BlockRun.model_validate(model) if model else None

    @db_operation("create_block_run")
    async def create_block_run(
        self,
        *,
        organization_id: str,
        user_id: str,
        block_label: str,
        output_parameter_id: str,
        workflow_run_id: str,
    ) -> None:
        """Record a block run for a workflow run. Returns nothing."""
        async with self.Session() as session:
            session.add(
                BlockRunModel(
                    organization_id=organization_id,
                    user_id=user_id,
                    block_label=block_label,
                    output_parameter_id=output_parameter_id,
                    workflow_run_id=workflow_run_id,
                )
            )
            await session.commit()

    @db_operation("get_latest_debug_session_for_user")
    async def get_latest_debug_session_for_user(
        self,
        *,
        organization_id: str,
        user_id: str,
        workflow_permanent_id: str,
    ) -> DebugSession | None:
        """Return the newest open ("created") debug session for this user/workflow, or None.

        Same lookup as get_debug_session; both entry points are kept for
        interface compatibility.
        """
        return await self._find_latest_created_debug_session(
            organization_id=organization_id,
            user_id=user_id,
            workflow_permanent_id=workflow_permanent_id,
        )

    @db_operation("get_debug_session_by_id")
    async def get_debug_session_by_id(
        self,
        debug_session_id: str,
        organization_id: str,
    ) -> DebugSession | None:
        """Return a live debug session by id, scoped to the organization, or None."""
        async with self.Session() as session:
            query = (
                select(DebugSessionModel)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
                .filter_by(debug_session_id=debug_session_id)
            )
            model = (await session.scalars(query)).first()
            return DebugSession.model_validate(model) if model else None

    @db_operation("get_workflow_runs_by_debug_session_id")
    async def get_workflow_runs_by_debug_session_id(
        self,
        debug_session_id: str,
        organization_id: str,
    ) -> list[DebugSessionRun]:
        """Return (workflow run x block run) pairs for a debug session, newest first."""
        async with self.Session() as session:
            query = (
                select(WorkflowRunModel, BlockRunModel)
                .join(BlockRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .filter(WorkflowRunModel.organization_id == organization_id)
                .filter(WorkflowRunModel.debug_session_id == debug_session_id)
                .order_by(WorkflowRunModel.created_at.desc())
            )
            results = (await session.execute(query)).all()
            return [
                DebugSessionRun(
                    ai_fallback=workflow_run.ai_fallback,
                    block_label=block_run.block_label,
                    browser_session_id=workflow_run.browser_session_id,
                    code_gen=workflow_run.code_gen,
                    debug_session_id=workflow_run.debug_session_id,
                    failure_reason=workflow_run.failure_reason,
                    output_parameter_id=block_run.output_parameter_id,
                    run_with=workflow_run.run_with,
                    # script_run is an optional mapping; surface only its id.
                    script_run_id=workflow_run.script_run.get("script_run_id") if workflow_run.script_run else None,
                    status=workflow_run.status,
                    workflow_id=workflow_run.workflow_id,
                    workflow_permanent_id=workflow_run.workflow_permanent_id,
                    workflow_run_id=workflow_run.workflow_run_id,
                    created_at=workflow_run.created_at,
                    queued_at=workflow_run.queued_at,
                    started_at=workflow_run.started_at,
                    finished_at=workflow_run.finished_at,
                )
                for workflow_run, block_run in results
            ]

    @db_operation("complete_debug_sessions")
    async def complete_debug_sessions(
        self,
        *,
        organization_id: str,
        user_id: str | None = None,
        workflow_permanent_id: str | None = None,
    ) -> list[DebugSession]:
        """Mark matching open ("created") debug sessions completed and return them.

        user_id / workflow_permanent_id narrow the match when provided.
        """
        async with self.Session() as session:
            query = (
                select(DebugSessionModel)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
                .filter_by(status="created")
            )
            if user_id:
                query = query.filter_by(user_id=user_id)
            if workflow_permanent_id:
                query = query.filter_by(workflow_permanent_id=workflow_permanent_id)

            models = (await session.scalars(query)).all()
            for model in models:
                model.status = "completed"
            # Validate before commit so the returned objects reflect the new
            # status without an extra refresh round-trip per row.
            debug_sessions = [DebugSession.model_validate(model) for model in models]
            await session.commit()
            return debug_sessions

    @db_operation("create_debug_session")
    async def create_debug_session(
        self,
        *,
        browser_session_id: str,
        organization_id: str,
        user_id: str,
        workflow_permanent_id: str,
        vnc_streaming_supported: bool,
    ) -> DebugSession:
        """Create a debug session in the "created" state and return it."""
        async with self.Session() as session:
            debug_session = DebugSessionModel(
                organization_id=organization_id,
                workflow_permanent_id=workflow_permanent_id,
                user_id=user_id,
                browser_session_id=browser_session_id,
                vnc_streaming_supported=vnc_streaming_supported,
                status="created",
            )
            session.add(debug_session)
            await session.commit()
            await session.refresh(debug_session)
            return DebugSession.model_validate(debug_session)
|
||||
370
skyvern/forge/sdk/db/repositories/folders.py
Normal file
370
skyvern/forge/sdk/db/repositories/folders.py
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from sqlalchemy import func, or_, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.models import FolderModel, WorkflowModel
|
||||
from skyvern.forge.sdk.db.protocols import WorkflowReader
|
||||
from skyvern.forge.sdk.db.utils import convert_to_workflow
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
|
||||
|
||||
class FoldersRepository(BaseRepository):
    """Database operations for folder management."""

    def __init__(
        self,
        session_factory: _SessionFactory,
        debug_enabled: bool = False,
        is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
        workflow_reader: WorkflowReader | None = None,
    ) -> None:
        super().__init__(session_factory, debug_enabled, is_retryable_error_fn)
        # Injected dependency: lets update_workflow_folder resolve a workflow's
        # latest version without importing the workflows repository directly.
        self._workflow_reader = workflow_reader

    @staticmethod
    def _latest_workflow_versions_subquery(organization_id: str):
        """Subquery mapping each live workflow_permanent_id to its max version.

        Previously duplicated inline by four methods below; extracted so the
        "latest version" definition lives in one place.
        """
        return (
            select(
                WorkflowModel.organization_id,
                WorkflowModel.workflow_permanent_id,
                func.max(WorkflowModel.version).label("max_version"),
            )
            .where(WorkflowModel.organization_id == organization_id)
            .where(WorkflowModel.deleted_at.is_(None))
            .group_by(
                WorkflowModel.organization_id,
                WorkflowModel.workflow_permanent_id,
            )
            .subquery()
        )

    @staticmethod
    def _join_latest_versions(stmt, subquery):
        """Join a WorkflowModel select to the latest-version subquery."""
        return stmt.join(
            subquery,
            (WorkflowModel.organization_id == subquery.c.organization_id)
            & (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
            & (WorkflowModel.version == subquery.c.max_version),
        )

    @db_operation("create_folder")
    async def create_folder(
        self,
        organization_id: str,
        title: str,
        description: str | None = None,
    ) -> FolderModel:
        """Create a new folder."""
        async with self.Session() as session:
            folder = FolderModel(
                organization_id=organization_id,
                title=title,
                description=description,
            )
            session.add(folder)
            await session.commit()
            await session.refresh(folder)
            return folder

    @db_operation("get_folders")
    async def get_folders(
        self,
        organization_id: str,
        page: int = 1,
        page_size: int = 10,
        search_query: str | None = None,
    ) -> list[FolderModel]:
        """Get all folders for an organization with pagination and optional search.

        Search matches title or description case-insensitively; results are
        ordered by most recently modified.
        """
        async with self.Session() as session:
            stmt = (
                select(FolderModel).filter_by(organization_id=organization_id).filter(FolderModel.deleted_at.is_(None))
            )

            if search_query:
                search_pattern = f"%{search_query}%"
                stmt = stmt.filter(
                    or_(
                        FolderModel.title.ilike(search_pattern),
                        FolderModel.description.ilike(search_pattern),
                    )
                )

            stmt = stmt.order_by(FolderModel.modified_at.desc())
            stmt = stmt.offset((page - 1) * page_size).limit(page_size)

            result = await session.execute(stmt)
            return list(result.scalars().all())

    @db_operation("get_folder")
    async def get_folder(
        self,
        folder_id: str,
        organization_id: str,
    ) -> FolderModel | None:
        """Get a live (not deleted) folder by ID, or None."""
        async with self.Session() as session:
            stmt = (
                select(FolderModel)
                .filter_by(folder_id=folder_id, organization_id=organization_id)
                .filter(FolderModel.deleted_at.is_(None))
            )
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    @db_operation("update_folder")
    async def update_folder(
        self,
        folder_id: str,
        organization_id: str,
        title: str | None = None,
        description: str | None = None,
    ) -> FolderModel | None:
        """Update a folder's title and/or description; returns None if not found."""
        async with self.Session() as session:
            stmt = (
                select(FolderModel)
                .filter_by(folder_id=folder_id, organization_id=organization_id)
                .filter(FolderModel.deleted_at.is_(None))
            )
            result = await session.execute(stmt)
            folder = result.scalar_one_or_none()
            if not folder:
                return None

            if title is not None:
                folder.title = title
            if description is not None:
                folder.description = description

            folder.modified_at = datetime.now(timezone.utc)
            await session.commit()
            await session.refresh(folder)
            return folder

    @db_operation("get_workflow_permanent_ids_in_folder")
    async def get_workflow_permanent_ids_in_folder(
        self,
        folder_id: str,
        organization_id: str,
    ) -> list[str]:
        """Get workflow permanent IDs (latest versions only) in a folder."""
        async with self.Session() as session:
            subquery = self._latest_workflow_versions_subquery(organization_id)
            stmt = self._join_latest_versions(
                select(WorkflowModel.workflow_permanent_id), subquery
            ).where(WorkflowModel.folder_id == folder_id)
            result = await session.execute(stmt)
            return list(result.scalars().all())

    @db_operation("soft_delete_folder")
    async def soft_delete_folder(
        self,
        folder_id: str,
        organization_id: str,
        delete_workflows: bool = False,
    ) -> bool:
        """Soft delete a folder. Optionally delete all workflows in the folder.

        Returns False when the folder does not exist or is already deleted.
        When delete_workflows is False, workflows are detached from the folder
        (folder_id cleared) instead of being deleted.
        """
        async with self.Session() as session:
            # Check if folder exists.
            folder_stmt = (
                select(FolderModel)
                .filter_by(folder_id=folder_id, organization_id=organization_id)
                .filter(FolderModel.deleted_at.is_(None))
            )
            folder = (await session.execute(folder_stmt)).scalar_one_or_none()
            if not folder:
                return False

            now = datetime.now(timezone.utc)
            if delete_workflows:
                # Find permanent ids whose latest version lives in this folder.
                subquery = self._latest_workflow_versions_subquery(organization_id)
                ids_stmt = self._join_latest_versions(
                    select(WorkflowModel.workflow_permanent_id), subquery
                ).where(WorkflowModel.folder_id == folder_id)
                workflow_permanent_ids = list((await session.execute(ids_stmt)).scalars().all())

                # Soft delete every version of those workflows in one bulk update.
                if workflow_permanent_ids:
                    await session.execute(
                        update(WorkflowModel)
                        .where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
                        .where(WorkflowModel.organization_id == organization_id)
                        .where(WorkflowModel.deleted_at.is_(None))
                        .values(deleted_at=now)
                    )
            else:
                # Just detach workflows from this folder.
                await session.execute(
                    update(WorkflowModel)
                    .where(WorkflowModel.folder_id == folder_id)
                    .where(WorkflowModel.organization_id == organization_id)
                    .values(folder_id=None, modified_at=now)
                )

            # Soft delete the folder itself.
            folder.deleted_at = now
            await session.commit()
            return True

    @db_operation("get_folder_workflow_count")
    async def get_folder_workflow_count(
        self,
        folder_id: str,
        organization_id: str,
    ) -> int:
        """Get the count of workflows (latest versions only) in a folder."""
        async with self.Session() as session:
            subquery = self._latest_workflow_versions_subquery(organization_id)
            stmt = self._join_latest_versions(
                select(func.count(WorkflowModel.workflow_permanent_id)), subquery
            ).where(WorkflowModel.folder_id == folder_id)
            result = await session.execute(stmt)
            return result.scalar_one()

    @db_operation("get_folder_workflow_counts_batch")
    async def get_folder_workflow_counts_batch(
        self,
        folder_ids: list[str],
        organization_id: str,
    ) -> dict[str, int]:
        """Get workflow counts for multiple folders in a single query.

        Folders with no workflows are absent from the returned mapping.
        """
        async with self.Session() as session:
            subquery = self._latest_workflow_versions_subquery(organization_id)
            stmt = (
                self._join_latest_versions(
                    select(
                        WorkflowModel.folder_id,
                        func.count(WorkflowModel.workflow_permanent_id).label("count"),
                    ),
                    subquery,
                )
                .where(WorkflowModel.folder_id.in_(folder_ids))
                .group_by(WorkflowModel.folder_id)
            )
            rows = (await session.execute(stmt)).all()
            return {row.folder_id: row.count for row in rows}

    @db_operation("update_workflow_folder")
    async def update_workflow_folder(
        self,
        workflow_permanent_id: str,
        organization_id: str,
        folder_id: str | None,
    ) -> Workflow | None:
        """Update folder assignment for the latest version of a workflow.

        Returns None when the workflow (or its latest version row) cannot be
        found. Raises ValueError when folder_id refers to a missing folder and
        RuntimeError when the workflow_reader dependency was not injected.
        """
        if self._workflow_reader is None:
            raise RuntimeError("workflow_reader dependency not set")
        # Resolve the latest version of the workflow via the injected reader.
        latest_workflow = await self._workflow_reader.get_workflow_by_permanent_id(
            workflow_permanent_id=workflow_permanent_id,
            organization_id=organization_id,
        )
        if not latest_workflow:
            return None

        async with self.Session() as session:
            # Validate the folder exists in-org when a folder_id is provided.
            if folder_id:
                stmt = (
                    select(FolderModel.folder_id)
                    .where(FolderModel.folder_id == folder_id)
                    .where(FolderModel.organization_id == organization_id)
                    .where(FolderModel.deleted_at.is_(None))
                )
                if (await session.scalar(stmt)) is None:
                    raise ValueError(f"Folder {folder_id} not found")

            workflow_model = await session.get(WorkflowModel, latest_workflow.workflow_id)
            if workflow_model:
                workflow_model.folder_id = folder_id
                workflow_model.modified_at = datetime.now(timezone.utc)

                # Touch the folder's modified_at in the same transaction.
                if folder_id:
                    folder_model = await session.get(FolderModel, folder_id)
                    if folder_model:
                        folder_model.modified_at = datetime.now(timezone.utc)

                await session.commit()
                await session.refresh(workflow_model)
                return convert_to_workflow(workflow_model, self.debug_enabled)
            return None
|
||||
588
skyvern/forge/sdk/db/repositories/observer.py
Normal file
588
skyvern/forge/sdk/db/repositories/observer.py
Normal file
|
|
@ -0,0 +1,588 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import read_retry
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
TaskV2Model,
|
||||
ThoughtModel,
|
||||
WorkflowRunBlockModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.protocols import TaskReader
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
convert_to_task_v2,
|
||||
convert_to_workflow_run_block,
|
||||
serialize_proxy_location,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
from skyvern.schemas.runs import ProxyLocationInput, RunEngine
|
||||
from skyvern.schemas.workflows import BlockStatus, BlockType
|
||||
|
||||
|
||||
class ObserverRepository(BaseRepository):
    """Database operations for observer tasks (TaskV2), thoughts, and workflow run blocks."""

    def __init__(
        self,
        session_factory: _SessionFactory,
        debug_enabled: bool = False,
        is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
        task_reader: TaskReader | None = None,
    ) -> None:
        """Initialize the repository.

        Args:
            session_factory: async session factory forwarded to BaseRepository.
            debug_enabled: forwarded to BaseRepository; also passed to the
                convert_to_* helpers by the query methods.
            is_retryable_error_fn: optional predicate deciding which
                SQLAlchemy errors the base class may retry.
            task_reader: optional injected task-lookup dependency.
                NOTE(review): not used by any method visible in this excerpt —
                presumably consumed by methods beyond this view; confirm.
        """
        super().__init__(session_factory, debug_enabled, is_retryable_error_fn)
        self._task_reader = task_reader
|
||||
|
||||
@read_retry()
@db_operation("get_task_v2", log_errors=False)
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
    """Fetch a TaskV2 by id (scoped to the organization), or None if absent."""
    async with self.Session() as session:
        # observer_cruise_id is the legacy column name backing task_v2_id.
        query = (
            select(TaskV2Model)
            .filter_by(observer_cruise_id=task_v2_id)
            .filter_by(organization_id=organization_id)
        )
        model = (await session.scalars(query)).first()
        if model is None:
            return None
        return convert_to_task_v2(model, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("delete_thoughts")
async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None:
    """Hard-delete every thought recorded for the given TaskV2."""
    async with self.Session() as session:
        # Multiple where() criteria are ANDed together.
        await session.execute(
            delete(ThoughtModel).where(
                ThoughtModel.observer_cruise_id == task_v2_id,
                ThoughtModel.organization_id == organization_id,
            )
        )
        await session.commit()
|
||||
|
||||
@db_operation("get_task_v2_by_workflow_run_id")
async def get_task_v2_by_workflow_run_id(
    self,
    workflow_run_id: str,
    organization_id: str | None = None,
) -> TaskV2 | None:
    """Fetch the TaskV2 attached to a workflow run, or None if there is none."""
    async with self.Session() as session:
        query = (
            select(TaskV2Model)
            .filter_by(organization_id=organization_id)
            .filter_by(workflow_run_id=workflow_run_id)
        )
        model = (await session.scalars(query)).first()
        if model is None:
            return None
        return convert_to_task_v2(model, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("get_thought")
async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None:
    """Fetch a single thought by id (scoped to the organization), or None."""
    async with self.Session() as session:
        query = (
            select(ThoughtModel)
            .filter_by(observer_thought_id=thought_id)
            .filter_by(organization_id=organization_id)
        )
        model = (await session.scalars(query)).first()
        return Thought.model_validate(model) if model else None
|
||||
|
||||
@db_operation("get_thoughts")
async def get_thoughts(
    self,
    *,
    task_v2_id: str,
    thought_types: list[ThoughtType],
    organization_id: str,
) -> list[Thought]:
    """Return a TaskV2's thoughts in creation order, optionally restricted by type."""
    async with self.Session() as session:
        stmt = (
            select(ThoughtModel)
            .filter_by(observer_cruise_id=task_v2_id)
            .filter_by(organization_id=organization_id)
            .order_by(ThoughtModel.created_at)
        )
        # An empty thought_types list means "no type filter", not "match nothing".
        if thought_types:
            stmt = stmt.filter(ThoughtModel.observer_thought_type.in_(thought_types))
        rows = (await session.scalars(stmt)).all()
        return [Thought.model_validate(row) for row in rows]
|
||||
|
||||
@db_operation("create_task_v2")
async def create_task_v2(
    self,
    workflow_run_id: str | None = None,
    workflow_id: str | None = None,
    workflow_permanent_id: str | None = None,
    prompt: str | None = None,
    url: str | None = None,
    organization_id: str | None = None,
    proxy_location: ProxyLocationInput = None,
    totp_identifier: str | None = None,
    totp_verification_url: str | None = None,
    webhook_callback_url: str | None = None,
    extracted_information_schema: dict | list | str | None = None,
    error_code_mapping: dict | None = None,
    model: dict[str, Any] | None = None,
    max_screenshot_scrolling_times: int | None = None,
    extra_http_headers: dict[str, str] | None = None,
    browser_address: str | None = None,
    run_with: str | None = None,
) -> TaskV2:
    """Insert a new TaskV2 row and return it as a schema object."""
    async with self.Session() as session:
        row = TaskV2Model(
            workflow_run_id=workflow_run_id,
            workflow_id=workflow_id,
            workflow_permanent_id=workflow_permanent_id,
            prompt=prompt,
            url=url,
            # proxy_location is serialized before storage.
            proxy_location=serialize_proxy_location(proxy_location),
            totp_identifier=totp_identifier,
            totp_verification_url=totp_verification_url,
            webhook_callback_url=webhook_callback_url,
            extracted_information_schema=extracted_information_schema,
            error_code_mapping=error_code_mapping,
            organization_id=organization_id,
            model=model,
            max_screenshot_scrolling_times=max_screenshot_scrolling_times,
            extra_http_headers=extra_http_headers,
            browser_address=browser_address,
            run_with=run_with,
        )
        session.add(row)
        await session.commit()
        # Reload server-generated columns before conversion.
        await session.refresh(row)
        return convert_to_task_v2(row, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("create_thought")
async def create_thought(
    self,
    task_v2_id: str,
    workflow_run_id: str | None = None,
    workflow_id: str | None = None,
    workflow_permanent_id: str | None = None,
    workflow_run_block_id: str | None = None,
    user_input: str | None = None,
    observation: str | None = None,
    thought: str | None = None,
    answer: str | None = None,
    thought_scenario: str | None = None,
    thought_type: str = ThoughtType.plan,
    output: dict[str, Any] | None = None,
    input_token_count: int | None = None,
    output_token_count: int | None = None,
    reasoning_token_count: int | None = None,
    cached_token_count: int | None = None,
    thought_cost: float | None = None,
    organization_id: str | None = None,
) -> Thought:
    """Insert a new thought row for a TaskV2 and return it as a schema object.

    Note the column-name mapping: task_v2_id is stored as observer_cruise_id,
    and thought_scenario/thought_type as observer_thought_scenario/_type.
    """
    async with self.Session() as session:
        row = ThoughtModel(
            observer_cruise_id=task_v2_id,
            workflow_run_id=workflow_run_id,
            workflow_id=workflow_id,
            workflow_permanent_id=workflow_permanent_id,
            workflow_run_block_id=workflow_run_block_id,
            user_input=user_input,
            observation=observation,
            thought=thought,
            answer=answer,
            observer_thought_scenario=thought_scenario,
            observer_thought_type=thought_type,
            output=output,
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            reasoning_token_count=reasoning_token_count,
            cached_token_count=cached_token_count,
            thought_cost=thought_cost,
            organization_id=organization_id,
        )
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return Thought.model_validate(row)
|
||||
|
||||
@db_operation("update_thought")
|
||||
async def update_thought(
|
||||
self,
|
||||
thought_id: str,
|
||||
workflow_run_block_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_id: str | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
observation: str | None = None,
|
||||
thought: str | None = None,
|
||||
answer: str | None = None,
|
||||
output: dict[str, Any] | None = None,
|
||||
input_token_count: int | None = None,
|
||||
output_token_count: int | None = None,
|
||||
reasoning_token_count: int | None = None,
|
||||
cached_token_count: int | None = None,
|
||||
thought_cost: float | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> Thought:
|
||||
async with self.Session() as session:
|
||||
thought_obj = (
|
||||
await session.scalars(
|
||||
select(ThoughtModel)
|
||||
.filter_by(observer_thought_id=thought_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if thought_obj:
|
||||
if workflow_run_block_id:
|
||||
thought_obj.workflow_run_block_id = workflow_run_block_id
|
||||
if workflow_run_id:
|
||||
thought_obj.workflow_run_id = workflow_run_id
|
||||
if workflow_id:
|
||||
thought_obj.workflow_id = workflow_id
|
||||
if workflow_permanent_id:
|
||||
thought_obj.workflow_permanent_id = workflow_permanent_id
|
||||
if observation:
|
||||
thought_obj.observation = observation
|
||||
if thought:
|
||||
thought_obj.thought = thought
|
||||
if answer:
|
||||
thought_obj.answer = answer
|
||||
if output:
|
||||
thought_obj.output = output
|
||||
if input_token_count:
|
||||
thought_obj.input_token_count = input_token_count
|
||||
if output_token_count:
|
||||
thought_obj.output_token_count = output_token_count
|
||||
if reasoning_token_count:
|
||||
thought_obj.reasoning_token_count = reasoning_token_count
|
||||
if cached_token_count:
|
||||
thought_obj.cached_token_count = cached_token_count
|
||||
if thought_cost:
|
||||
thought_obj.thought_cost = thought_cost
|
||||
await session.commit()
|
||||
await session.refresh(thought_obj)
|
||||
return Thought.model_validate(thought_obj)
|
||||
raise NotFoundError(f"Thought {thought_id}")
|
||||
|
||||
@db_operation("update_task_v2")
|
||||
async def update_task_v2(
|
||||
self,
|
||||
task_v2_id: str,
|
||||
status: TaskV2Status | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_id: str | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
url: str | None = None,
|
||||
prompt: str | None = None,
|
||||
summary: str | None = None,
|
||||
output: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
webhook_failure_reason: str | None = None,
|
||||
failure_category: list[dict[str, Any]] | None = None,
|
||||
) -> TaskV2:
|
||||
async with self.Session() as session:
|
||||
task_v2 = (
|
||||
await session.scalars(
|
||||
select(TaskV2Model)
|
||||
.filter_by(observer_cruise_id=task_v2_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if task_v2:
|
||||
if status:
|
||||
task_v2.status = status
|
||||
if status == TaskV2Status.queued and task_v2.queued_at is None:
|
||||
task_v2.queued_at = datetime.now(timezone.utc)
|
||||
if status == TaskV2Status.running and task_v2.started_at is None:
|
||||
task_v2.started_at = datetime.now(timezone.utc)
|
||||
if status.is_final() and task_v2.finished_at is None:
|
||||
task_v2.finished_at = datetime.now(timezone.utc)
|
||||
if workflow_run_id:
|
||||
task_v2.workflow_run_id = workflow_run_id
|
||||
if workflow_id:
|
||||
task_v2.workflow_id = workflow_id
|
||||
if workflow_permanent_id:
|
||||
task_v2.workflow_permanent_id = workflow_permanent_id
|
||||
if url:
|
||||
task_v2.url = url
|
||||
if prompt:
|
||||
task_v2.prompt = prompt
|
||||
if summary:
|
||||
task_v2.summary = summary
|
||||
if output:
|
||||
task_v2.output = output
|
||||
if webhook_failure_reason is not None:
|
||||
task_v2.webhook_failure_reason = webhook_failure_reason
|
||||
if failure_category is not None:
|
||||
task_v2.failure_category = failure_category
|
||||
await session.commit()
|
||||
await session.refresh(task_v2)
|
||||
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
|
||||
raise NotFoundError(f"TaskV2 {task_v2_id} not found")
|
||||
|
||||
@db_operation("create_workflow_run_block")
|
||||
async def create_workflow_run_block(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
parent_workflow_run_block_id: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
label: str | None = None,
|
||||
block_type: BlockType | None = None,
|
||||
status: BlockStatus = BlockStatus.running,
|
||||
output: dict | list | str | None = None,
|
||||
continue_on_failure: bool = False,
|
||||
engine: RunEngine | None = None,
|
||||
current_value: str | None = None,
|
||||
current_index: int | None = None,
|
||||
) -> WorkflowRunBlock:
|
||||
async with self.Session() as session:
|
||||
new_workflow_run_block = WorkflowRunBlockModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
parent_workflow_run_block_id=parent_workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
task_id=task_id,
|
||||
label=label,
|
||||
block_type=block_type,
|
||||
status=status,
|
||||
output=output,
|
||||
continue_on_failure=continue_on_failure,
|
||||
engine=engine,
|
||||
current_value=current_value,
|
||||
current_index=current_index,
|
||||
)
|
||||
session.add(new_workflow_run_block)
|
||||
await session.commit()
|
||||
await session.refresh(new_workflow_run_block)
|
||||
|
||||
task = None
|
||||
if task_id:
|
||||
if self._task_reader is None:
|
||||
raise RuntimeError("task_reader dependency not set")
|
||||
task = await self._task_reader.get_task(task_id, organization_id=organization_id)
|
||||
return convert_to_workflow_run_block(new_workflow_run_block, task=task)
|
||||
|
||||
@db_operation("delete_workflow_run_blocks")
|
||||
async def delete_workflow_run_blocks(self, workflow_run_id: str, organization_id: str | None = None) -> None:
|
||||
async with self.Session() as session:
|
||||
stmt = delete(WorkflowRunBlockModel).where(
|
||||
and_(
|
||||
WorkflowRunBlockModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowRunBlockModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("update_workflow_run_block")
|
||||
async def update_workflow_run_block(
|
||||
self,
|
||||
workflow_run_block_id: str,
|
||||
organization_id: str | None = None,
|
||||
status: BlockStatus | None = None,
|
||||
output: dict | list | str | None = None,
|
||||
failure_reason: str | None = None,
|
||||
task_id: str | None = None,
|
||||
loop_values: list | None = None,
|
||||
current_value: str | None = None,
|
||||
current_index: int | None = None,
|
||||
recipients: list[str] | None = None,
|
||||
attachments: list[str] | None = None,
|
||||
subject: str | None = None,
|
||||
body: str | None = None,
|
||||
prompt: str | None = None,
|
||||
wait_sec: int | None = None,
|
||||
description: str | None = None,
|
||||
block_workflow_run_id: str | None = None,
|
||||
engine: str | None = None,
|
||||
# HTTP request block parameters
|
||||
http_request_method: str | None = None,
|
||||
http_request_url: str | None = None,
|
||||
http_request_headers: dict[str, str] | None = None,
|
||||
http_request_body: dict[str, Any] | None = None,
|
||||
http_request_parameters: dict[str, Any] | None = None,
|
||||
http_request_timeout: int | None = None,
|
||||
http_request_follow_redirects: bool | None = None,
|
||||
ai_fallback_triggered: bool | None = None,
|
||||
# block-level error codes (e.g. ["FILE_PARSER_ERROR"])
|
||||
error_codes: list[str] | None = None,
|
||||
# human interaction block
|
||||
instructions: str | None = None,
|
||||
positive_descriptor: str | None = None,
|
||||
negative_descriptor: str | None = None,
|
||||
# conditional block
|
||||
executed_branch_id: str | None = None,
|
||||
executed_branch_expression: str | None = None,
|
||||
executed_branch_result: bool | None = None,
|
||||
executed_branch_next_block: str | None = None,
|
||||
) -> WorkflowRunBlock:
|
||||
async with self.Session() as session:
|
||||
workflow_run_block = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunBlockModel)
|
||||
.filter_by(workflow_run_block_id=workflow_run_block_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if workflow_run_block:
|
||||
if status:
|
||||
workflow_run_block.status = status
|
||||
if output:
|
||||
workflow_run_block.output = output
|
||||
if task_id:
|
||||
workflow_run_block.task_id = task_id
|
||||
if failure_reason:
|
||||
workflow_run_block.failure_reason = failure_reason
|
||||
# Use `is not None` instead of truthiness checks so that falsy
|
||||
# values like current_index=0, empty loop_values=[], or
|
||||
# current_value="" are correctly persisted. Without this,
|
||||
# the first loop iteration (index 0) loses its metadata.
|
||||
if loop_values is not None:
|
||||
workflow_run_block.loop_values = loop_values
|
||||
if current_value is not None:
|
||||
workflow_run_block.current_value = current_value
|
||||
if current_index is not None:
|
||||
workflow_run_block.current_index = current_index
|
||||
if recipients:
|
||||
workflow_run_block.recipients = recipients
|
||||
if attachments:
|
||||
workflow_run_block.attachments = attachments
|
||||
if subject:
|
||||
workflow_run_block.subject = subject
|
||||
if body:
|
||||
workflow_run_block.body = body
|
||||
if prompt:
|
||||
workflow_run_block.prompt = prompt
|
||||
if wait_sec:
|
||||
workflow_run_block.wait_sec = wait_sec
|
||||
if description:
|
||||
workflow_run_block.description = description
|
||||
if block_workflow_run_id:
|
||||
workflow_run_block.block_workflow_run_id = block_workflow_run_id
|
||||
if engine:
|
||||
workflow_run_block.engine = engine
|
||||
# HTTP request block fields
|
||||
if http_request_method:
|
||||
workflow_run_block.http_request_method = http_request_method
|
||||
if http_request_url:
|
||||
workflow_run_block.http_request_url = http_request_url
|
||||
if http_request_headers:
|
||||
workflow_run_block.http_request_headers = http_request_headers
|
||||
if http_request_body:
|
||||
workflow_run_block.http_request_body = http_request_body
|
||||
if http_request_parameters:
|
||||
workflow_run_block.http_request_parameters = http_request_parameters
|
||||
if http_request_timeout:
|
||||
workflow_run_block.http_request_timeout = http_request_timeout
|
||||
if http_request_follow_redirects is not None:
|
||||
workflow_run_block.http_request_follow_redirects = http_request_follow_redirects
|
||||
if ai_fallback_triggered is not None:
|
||||
workflow_run_block.script_run = {"ai_fallback_triggered": ai_fallback_triggered}
|
||||
if error_codes is not None:
|
||||
workflow_run_block.error_codes = error_codes
|
||||
# human interaction block fields
|
||||
if instructions:
|
||||
workflow_run_block.instructions = instructions
|
||||
if positive_descriptor:
|
||||
workflow_run_block.positive_descriptor = positive_descriptor
|
||||
if negative_descriptor:
|
||||
workflow_run_block.negative_descriptor = negative_descriptor
|
||||
# conditional block fields
|
||||
if executed_branch_id:
|
||||
workflow_run_block.executed_branch_id = executed_branch_id
|
||||
if executed_branch_expression is not None:
|
||||
workflow_run_block.executed_branch_expression = executed_branch_expression
|
||||
if executed_branch_result is not None:
|
||||
workflow_run_block.executed_branch_result = executed_branch_result
|
||||
if executed_branch_next_block is not None:
|
||||
workflow_run_block.executed_branch_next_block = executed_branch_next_block
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_block)
|
||||
else:
|
||||
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
|
||||
task = None
|
||||
task_id = workflow_run_block.task_id
|
||||
if task_id:
|
||||
if self._task_reader is None:
|
||||
raise RuntimeError("task_reader dependency not set")
|
||||
task = await self._task_reader.get_task(task_id, organization_id=workflow_run_block.organization_id)
|
||||
return convert_to_workflow_run_block(workflow_run_block, task=task)
|
||||
|
||||
@db_operation("get_workflow_run_block")
|
||||
async def get_workflow_run_block(
|
||||
self,
|
||||
workflow_run_block_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> WorkflowRunBlock:
|
||||
async with self.Session() as session:
|
||||
workflow_run_block = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunBlockModel)
|
||||
.filter_by(workflow_run_block_id=workflow_run_block_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if workflow_run_block:
|
||||
task = None
|
||||
task_id = workflow_run_block.task_id
|
||||
if task_id:
|
||||
if self._task_reader is None:
|
||||
raise RuntimeError("task_reader dependency not set")
|
||||
task = await self._task_reader.get_task(task_id, organization_id=organization_id)
|
||||
return convert_to_workflow_run_block(workflow_run_block, task=task)
|
||||
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
|
||||
|
||||
@db_operation("get_workflow_run_block_by_task_id")
|
||||
async def get_workflow_run_block_by_task_id(
|
||||
self,
|
||||
task_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> WorkflowRunBlock:
|
||||
async with self.Session() as session:
|
||||
workflow_run_block = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunBlockModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if workflow_run_block:
|
||||
task = None
|
||||
task_id = workflow_run_block.task_id
|
||||
if task_id:
|
||||
if self._task_reader is None:
|
||||
raise RuntimeError("task_reader dependency not set")
|
||||
task = await self._task_reader.get_task(task_id, organization_id=organization_id)
|
||||
return convert_to_workflow_run_block(workflow_run_block, task=task)
|
||||
raise NotFoundError(f"WorkflowRunBlock not found by {task_id}")
|
||||
|
||||
@db_operation("get_workflow_run_blocks")
|
||||
async def get_workflow_run_blocks(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[WorkflowRunBlock]:
|
||||
async with self.Session() as session:
|
||||
workflow_run_blocks = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunBlockModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(WorkflowRunBlockModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
if self._task_reader is None:
|
||||
raise RuntimeError("task_reader dependency not set")
|
||||
tasks = await self._task_reader.get_tasks_by_workflow_run_id(workflow_run_id)
|
||||
tasks_dict = {task.task_id: task for task in tasks}
|
||||
return [
|
||||
convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id))
|
||||
for workflow_run_block in workflow_run_blocks
|
||||
]
|
||||
375
skyvern/forge/sdk/db/repositories/organizations.py
Normal file
375
skyvern/forge/sdk/db/repositories/organizations.py
Normal file
|
|
@ -0,0 +1,375 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal, overload
|
||||
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import read_retry
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
OrganizationAuthTokenModel,
|
||||
OrganizationModel,
|
||||
TaskModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
convert_to_organization,
|
||||
convert_to_organization_auth_token,
|
||||
)
|
||||
from skyvern.forge.sdk.encrypt import encryptor
|
||||
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredential,
|
||||
AzureOrganizationAuthToken,
|
||||
BitwardenCredential,
|
||||
BitwardenOrganizationAuthToken,
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
|
||||
|
||||
class OrganizationsRepository(BaseRepository):
|
||||
"""Database operations for organization and auth-token management."""
|
||||
|
||||
@read_retry()
|
||||
@db_operation("get_active_verification_requests", log_errors=False)
|
||||
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
|
||||
"""Return active 2FA verification requests for an organization.
|
||||
|
||||
Queries both tasks and workflow runs where waiting_for_verification_code=True.
|
||||
Used to provide initial state when a WebSocket notification client connects.
|
||||
"""
|
||||
results: list[dict] = []
|
||||
async with self.Session() as session:
|
||||
# Tasks waiting for verification (exclude finalized tasks)
|
||||
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
|
||||
task_rows = (
|
||||
await session.scalars(
|
||||
select(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter_by(workflow_run_id=None)
|
||||
.filter(TaskModel.status.not_in(finalized_task_statuses))
|
||||
.filter(TaskModel.created_at > datetime.now(timezone.utc) - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for t in task_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": t.task_id,
|
||||
"workflow_run_id": None,
|
||||
"verification_code_identifier": t.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
t.verification_code_polling_started_at.isoformat()
|
||||
if t.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
# Workflow runs waiting for verification (exclude finalized runs)
|
||||
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
|
||||
wr_rows = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
|
||||
.filter(WorkflowRunModel.created_at > datetime.now(timezone.utc) - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for wr in wr_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": None,
|
||||
"workflow_run_id": wr.workflow_run_id,
|
||||
"verification_code_identifier": wr.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
wr.verification_code_polling_started_at.isoformat()
|
||||
if wr.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
@db_operation("get_all_organizations")
|
||||
async def get_all_organizations(self) -> list[Organization]:
|
||||
async with self.Session() as session:
|
||||
organizations = (await session.scalars(select(OrganizationModel))).all()
|
||||
return [convert_to_organization(organization) for organization in organizations]
|
||||
|
||||
@db_operation("get_organization")
|
||||
async def get_organization(self, organization_id: str) -> Organization | None:
|
||||
async with self.Session() as session:
|
||||
if organization := (
|
||||
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
|
||||
).first():
|
||||
return convert_to_organization(organization)
|
||||
else:
|
||||
return None
|
||||
|
||||
@db_operation("get_organization_by_domain")
|
||||
async def get_organization_by_domain(self, domain: str) -> Organization | None:
|
||||
async with self.Session() as session:
|
||||
if organization := (await session.scalars(select(OrganizationModel).filter_by(domain=domain))).first():
|
||||
return convert_to_organization(organization)
|
||||
return None
|
||||
|
||||
@db_operation("create_organization")
|
||||
async def create_organization(
|
||||
self,
|
||||
organization_name: str,
|
||||
webhook_callback_url: str | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
max_retries_per_step: int | None = None,
|
||||
domain: str | None = None,
|
||||
) -> Organization:
|
||||
async with self.Session() as session:
|
||||
org = OrganizationModel(
|
||||
organization_name=organization_name,
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
max_steps_per_run=max_steps_per_run,
|
||||
max_retries_per_step=max_retries_per_step,
|
||||
domain=domain,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
|
||||
return convert_to_organization(org)
|
||||
|
||||
@db_operation("update_organization")
|
||||
async def update_organization(
|
||||
self,
|
||||
organization_id: str,
|
||||
organization_name: str | None = None,
|
||||
webhook_callback_url: str | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
max_retries_per_step: int | None = None,
|
||||
) -> Organization:
|
||||
async with self.Session() as session:
|
||||
organization = (
|
||||
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
|
||||
).first()
|
||||
if not organization:
|
||||
raise NotFoundError
|
||||
if organization_name:
|
||||
organization.organization_name = organization_name
|
||||
if webhook_callback_url:
|
||||
organization.webhook_callback_url = webhook_callback_url
|
||||
if max_steps_per_run:
|
||||
organization.max_steps_per_run = max_steps_per_run
|
||||
if max_retries_per_step:
|
||||
organization.max_retries_per_step = max_retries_per_step
|
||||
await session.commit()
|
||||
await session.refresh(organization)
|
||||
return Organization.model_validate(organization)
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal["api", "onepassword_service_account", "custom_credential_service"],
|
||||
) -> OrganizationAuthToken | None: ...
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token( # type: ignore
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal["azure_client_secret_credential"],
|
||||
) -> AzureOrganizationAuthToken | None: ...
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token( # type: ignore
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal["bitwarden_credential"],
|
||||
) -> BitwardenOrganizationAuthToken | None: ...
|
||||
|
||||
@db_operation("get_valid_org_auth_token")
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[
|
||||
"api",
|
||||
"onepassword_service_account",
|
||||
"azure_client_secret_credential",
|
||||
"bitwarden_credential",
|
||||
"custom_credential_service",
|
||||
],
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken | None:
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return await convert_to_organization_auth_token(token, token_type)
|
||||
else:
|
||||
return None
|
||||
|
||||
    @db_operation("replace_org_auth_token")
    async def replace_org_auth_token(
        self,
        organization_id: str,
        token_type: OrganizationAuthTokenType,
        token: str | AzureClientSecretCredential | BitwardenCredential,
        encrypted_method: EncryptMethod | None = None,
    ) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
        """Atomically invalidate existing tokens and create a new one in a single transaction.

        The accepted `token` type depends on `token_type`: structured Azure and
        Bitwarden credentials are serialized to JSON for storage; all other
        token types take a plain string. Raises TypeError on a mismatch.
        When `encrypted_method` is given, only the ciphertext is stored.
        """
        if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
            if not isinstance(token, AzureClientSecretCredential):
                raise TypeError("Expected AzureClientSecretCredential for this token_type")
            plaintext_token = token.model_dump_json()
        elif token_type is OrganizationAuthTokenType.bitwarden_credential:
            if not isinstance(token, BitwardenCredential):
                raise TypeError("Expected BitwardenCredential for this token_type")
            plaintext_token = token.model_dump_json()
        else:
            if not isinstance(token, str):
                raise TypeError("Expected str token for this token_type")
            plaintext_token = token

        encrypted_token = ""
        if encrypted_method is not None:
            encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
            # Clear the plaintext so it is never persisted alongside ciphertext.
            plaintext_token = ""

        async with self.Session() as session:
            # Invalidate existing tokens
            await session.execute(
                update(OrganizationAuthTokenModel)
                .filter_by(organization_id=organization_id)
                .filter_by(token_type=token_type)
                .filter_by(valid=True)
                .values(valid=False)
            )
            # Create new token
            auth_token = OrganizationAuthTokenModel(
                organization_id=organization_id,
                token_type=token_type,
                token=plaintext_token,
                encrypted_token=encrypted_token,
                encrypted_method=encrypted_method.value if encrypted_method is not None else "",
            )
            session.add(auth_token)
            # Single commit covers both the invalidation and the insert,
            # so replacement is all-or-nothing.
            await session.commit()
            await session.refresh(auth_token)

        return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
||||
@db_operation("get_valid_org_auth_tokens")
|
||||
async def get_valid_org_auth_tokens(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> list[OrganizationAuthToken]:
|
||||
async with self.Session() as session:
|
||||
tokens = (
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
||||
|
||||
@db_operation("validate_org_auth_token")
|
||||
async def validate_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
valid: bool | None = True,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken | None:
|
||||
encrypted_token = ""
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
)
|
||||
if encrypted_token:
|
||||
query = query.filter_by(encrypted_token=encrypted_token)
|
||||
else:
|
||||
query = query.filter_by(token=token)
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return await convert_to_organization_auth_token(token_obj, token_type)
|
||||
else:
|
||||
return None
|
||||
|
||||
    @db_operation("create_org_auth_token")
    async def create_org_auth_token(
        self,
        organization_id: str,
        token_type: OrganizationAuthTokenType,
        token: str | AzureClientSecretCredential | BitwardenCredential,
        encrypted_method: EncryptMethod | None = None,
    ) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
        """Insert a new auth token row for an organization and return it.

        The accepted `token` type depends on `token_type`: structured Azure and
        Bitwarden credentials are serialized to JSON for storage; all other
        token types take a plain string. Raises TypeError on a mismatch.
        When `encrypted_method` is given, only the ciphertext is stored.
        Unlike replace_org_auth_token, existing tokens are NOT invalidated.
        """
        if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
            if not isinstance(token, AzureClientSecretCredential):
                raise TypeError("Expected AzureClientSecretCredential for this token_type")
            plaintext_token = token.model_dump_json()
        elif token_type is OrganizationAuthTokenType.bitwarden_credential:
            if not isinstance(token, BitwardenCredential):
                raise TypeError("Expected BitwardenCredential for this token_type")
            plaintext_token = token.model_dump_json()
        else:
            if not isinstance(token, str):
                raise TypeError("Expected str token for this token_type")
            plaintext_token = token

        encrypted_token = ""

        if encrypted_method is not None:
            encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
            # Clear the plaintext so it is never persisted alongside ciphertext.
            plaintext_token = ""

        async with self.Session() as session:
            auth_token = OrganizationAuthTokenModel(
                organization_id=organization_id,
                token_type=token_type,
                token=plaintext_token,
                encrypted_token=encrypted_token,
                encrypted_method=encrypted_method.value if encrypted_method is not None else "",
            )
            session.add(auth_token)
            await session.commit()
            await session.refresh(auth_token)

        return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
||||
@db_operation("invalidate_org_auth_tokens")
|
||||
async def invalidate_org_auth_tokens(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> None:
|
||||
"""Invalidate all existing tokens of a specific type for an organization."""
|
||||
async with self.Session() as session:
|
||||
await session.execute(
|
||||
update(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.values(valid=False)
|
||||
)
|
||||
await session.commit()
|
||||
154
skyvern/forge/sdk/db/repositories/otp.py
Normal file
154
skyvern/forge/sdk/db/repositories/otp.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import and_, asc, select
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.models import TOTPCodeModel
|
||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType, TOTPCode
|
||||
|
||||
|
||||
class OTPRepository(BaseRepository):
    """Database operations for OTP/TOTP management."""

    @db_operation("get_otp_codes")
    async def get_otp_codes(
        self,
        organization_id: str,
        totp_identifier: str,
        valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
        otp_type: OTPType | None = None,
        workflow_run_id: str | None = None,
        limit: int | None = None,
    ) -> list[TOTPCode]:
        """
        1. filter by:
        - organization_id
        - totp_identifier
        - workflow_run_id (optional)
        2. make sure created_at is within the valid lifespan
        3. sort by task_id/workflow_id/workflow_run_id nullslast and created_at desc
        4. apply an optional limit at the DB layer
        """
        # True only when the code is not tied to any task/workflow/run context.
        all_null = and_(
            TOTPCodeModel.task_id.is_(None),
            TOTPCodeModel.workflow_id.is_(None),
            TOTPCodeModel.workflow_run_id.is_(None),
        )
        async with self.Session() as session:
            query = (
                select(TOTPCodeModel)
                .filter_by(organization_id=organization_id)
                .filter_by(totp_identifier=totp_identifier)
                .filter(
                    TOTPCodeModel.created_at > datetime.now(timezone.utc) - timedelta(minutes=valid_lifespan_minutes)
                )
            )
            if otp_type:
                query = query.filter(TOTPCodeModel.otp_type == otp_type)
            if workflow_run_id is not None:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            # asc(all_null) puts context-bound codes (False) before context-free ones
            # (True), i.e. the "nullslast" ordering described in the docstring.
            query = query.order_by(asc(all_null), TOTPCodeModel.created_at.desc())
            if limit is not None:
                query = query.limit(limit)
            totp_codes = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(code) for code in totp_codes]

    @db_operation("get_otp_codes_by_run")
    async def get_otp_codes_by_run(
        self,
        organization_id: str,
        task_id: str | None = None,
        workflow_run_id: str | None = None,
        valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
        limit: int = 1,
    ) -> list[TOTPCode]:
        """Get OTP codes matching a specific task or workflow run (no totp_identifier required).

        Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
        The user submits codes manually via the UI, and this method finds them by run context.
        """
        # Without any run context there is nothing to match on.
        if not workflow_run_id and not task_id:
            return []
        async with self.Session() as session:
            query = (
                select(TOTPCodeModel)
                .filter_by(organization_id=organization_id)
                .filter(
                    TOTPCodeModel.created_at > datetime.now(timezone.utc) - timedelta(minutes=valid_lifespan_minutes)
                )
            )
            # workflow_run_id takes precedence over task_id when both are given.
            if workflow_run_id:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            elif task_id:
                query = query.filter(TOTPCodeModel.task_id == task_id)
            query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
            results = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(r) for r in results]

    @db_operation("get_recent_otp_codes")
    async def get_recent_otp_codes(
        self,
        organization_id: str,
        limit: int = 50,
        valid_lifespan_minutes: int | None = None,
        otp_type: OTPType | None = None,
        workflow_run_id: str | None = None,
        totp_identifier: str | None = None,
    ) -> list[TOTPCode]:
        """
        Return recent otp codes for an organization ordered by newest first with optional
        workflow_run_id filtering.
        """
        async with self.Session() as session:
            query = select(TOTPCodeModel).filter_by(organization_id=organization_id)

            # Unlike get_otp_codes, the lifespan filter here is opt-in.
            if valid_lifespan_minutes is not None:
                query = query.filter(
                    TOTPCodeModel.created_at > datetime.now(timezone.utc) - timedelta(minutes=valid_lifespan_minutes)
                )

            if otp_type:
                query = query.filter(TOTPCodeModel.otp_type == otp_type)
            if workflow_run_id is not None:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            if totp_identifier:
                query = query.filter(TOTPCodeModel.totp_identifier == totp_identifier)
            query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
            totp_codes = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(totp_code) for totp_code in totp_codes]

    @db_operation("create_otp_code")
    async def create_otp_code(
        self,
        organization_id: str,
        totp_identifier: str,
        content: str,
        code: str,
        otp_type: OTPType,
        task_id: str | None = None,
        workflow_id: str | None = None,
        workflow_run_id: str | None = None,
        source: str | None = None,
        expired_at: datetime | None = None,
    ) -> TOTPCode:
        """Insert a new OTP code row and return it as a validated TOTPCode schema."""
        async with self.Session() as session:
            new_totp_code = TOTPCodeModel(
                organization_id=organization_id,
                totp_identifier=totp_identifier,
                content=content,
                code=code,
                task_id=task_id,
                workflow_id=workflow_id,
                workflow_run_id=workflow_run_id,
                source=source,
                expired_at=expired_at,
                otp_type=otp_type,
            )
            session.add(new_totp_code)
            await session.commit()
            # Refresh to pick up DB-generated fields before validation.
            await session.refresh(new_totp_code)
            return TOTPCode.model_validate(new_totp_code)
||||
602
skyvern/forge/sdk/db/repositories/schedules.py
Normal file
602
skyvern/forge/sdk/db/repositories/schedules.py
Normal file
|
|
@ -0,0 +1,602 @@
|
|||
"""Database operations for workflow schedules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import exists, func, or_, select, text, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation, register_passthrough_exception
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError
|
||||
from skyvern.forge.sdk.db.models import WorkflowModel, WorkflowRunModel, WorkflowScheduleModel
|
||||
from skyvern.forge.sdk.db.utils import convert_to_workflow_schedule
|
||||
from skyvern.forge.sdk.schemas.workflow_schedules import OrganizationScheduleItem, WorkflowSchedule
|
||||
from skyvern.forge.sdk.workflow.schedules import compute_next_run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
|
||||
from . import _UNSET
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
register_passthrough_exception(ScheduleLimitExceededError)
|
||||
|
||||
|
||||
class SchedulesRepository(BaseRepository):
    """Database operations for workflow schedules."""

    def __init__(
        self,
        session_factory: _SessionFactory,
        debug_enabled: bool = False,
        is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
        sqlite_schedule_lock: asyncio.Lock | None = None,
    ) -> None:
        # sqlite_schedule_lock is only supplied on SQLite deployments; when set,
        # create_workflow_schedule_with_limit serializes via this process-level
        # lock instead of a PostgreSQL advisory lock.
        super().__init__(session_factory, debug_enabled, is_retryable_error_fn)
        self._sqlite_schedule_lock = sqlite_schedule_lock
||||
|
||||
@db_operation("create_workflow_schedule")
async def create_workflow_schedule(
    self,
    organization_id: str,
    workflow_permanent_id: str,
    cron_expression: str,
    timezone: str,
    enabled: bool,
    parameters: dict[str, Any] | None = None,
    temporal_schedule_id: str | None = None,
    name: str | None = None,
    description: str | None = None,
) -> WorkflowSchedule:
    """Insert a new workflow schedule row and return it as a schema object."""
    schedule_row = WorkflowScheduleModel(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        cron_expression=cron_expression,
        timezone=timezone,
        enabled=enabled,
        parameters=parameters,
        temporal_schedule_id=temporal_schedule_id,
        name=name,
        description=description,
    )
    async with self.Session() as session:
        session.add(schedule_row)
        await session.commit()
        # Refresh so DB-generated columns are populated before conversion.
        await session.refresh(schedule_row)
        return convert_to_workflow_schedule(schedule_row, self.debug_enabled)
||||
|
||||
@db_operation("create_workflow_schedule_with_limit")
async def create_workflow_schedule_with_limit(
    self,
    organization_id: str,
    workflow_permanent_id: str,
    max_schedules: int | None,
    cron_expression: str,
    timezone: str,
    enabled: bool,
    parameters: dict[str, Any] | None = None,
    name: str | None = None,
    description: str | None = None,
) -> tuple[WorkflowSchedule, int]:
    """Create a schedule atomically while enforcing a per-workflow limit.

    PostgreSQL deployments serialize concurrent creates for the same workflow
    with a transaction-scoped advisory lock, closing the TOCTOU window on the
    schedule count. SQLite deployments (single-writer, no advisory locks) hold
    the asyncio.Lock injected at construction across the count-check + insert.

    Returns (created_schedule, count_before_insert).
    Raises ScheduleLimitExceededError if count >= max_schedules.
    """
    if self._sqlite_schedule_lock is None:
        # PostgreSQL path: the inner helper takes a pg advisory lock itself.
        return await self._create_schedule_with_limit_inner(
            organization_id,
            workflow_permanent_id,
            max_schedules,
            cron_expression,
            timezone,
            enabled,
            parameters,
            name,
            description,
            use_advisory_lock=True,
        )
    # SQLite path: hold the process-level lock for the whole check + insert.
    async with self._sqlite_schedule_lock:
        return await self._create_schedule_with_limit_inner(
            organization_id,
            workflow_permanent_id,
            max_schedules,
            cron_expression,
            timezone,
            enabled,
            parameters,
            name,
            description,
            use_advisory_lock=False,
        )
||||
|
||||
# Intentionally not decorated with @db_operation — errors are caught by the
# outer create_workflow_schedule_with_limit which owns the operation name.
async def _create_schedule_with_limit_inner(
    self,
    organization_id: str,
    workflow_permanent_id: str,
    max_schedules: int | None,
    cron_expression: str,
    timezone: str,
    enabled: bool,
    parameters: dict[str, Any] | None,
    name: str | None,
    description: str | None,
    *,
    use_advisory_lock: bool,
) -> tuple[WorkflowSchedule, int]:
    """Count live schedules and insert a new one inside a single transaction.

    Returns (created_schedule, count_before_insert); raises
    ScheduleLimitExceededError when the count is at or above max_schedules.
    """
    async with self.Session() as session:
        if use_advisory_lock:
            # pg_advisory_xact_lock is held until the transaction ends (the
            # commit below), serializing concurrent creates for this
            # org + workflow pair so the count below cannot race.
            lock_key = f"schedule:{organization_id}:{workflow_permanent_id}"
            await session.execute(
                text("SELECT pg_advisory_xact_lock(hashtext(:key))"),
                {"key": lock_key},
            )

        # Count only live (non-soft-deleted) schedules for this workflow.
        count = (
            await session.execute(
                select(func.count()).where(
                    WorkflowScheduleModel.organization_id == organization_id,
                    WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id,
                    WorkflowScheduleModel.deleted_at.is_(None),
                )
            )
        ).scalar_one()

        if max_schedules is not None and count >= max_schedules:
            # Raising inside the context manager rolls the transaction back
            # and releases the advisory lock.
            raise ScheduleLimitExceededError(
                organization_id=organization_id,
                workflow_permanent_id=workflow_permanent_id,
                current_count=count,
                max_allowed=max_schedules,
            )

        workflow_schedule = WorkflowScheduleModel(
            organization_id=organization_id,
            workflow_permanent_id=workflow_permanent_id,
            cron_expression=cron_expression,
            timezone=timezone,
            enabled=enabled,
            parameters=parameters,
            name=name,
            description=description,
        )
        session.add(workflow_schedule)
        await session.commit()
        await session.refresh(workflow_schedule)
        return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled), count
||||
|
||||
@db_operation("set_temporal_schedule_id")
async def set_temporal_schedule_id(
    self,
    workflow_schedule_id: str,
    organization_id: str,
    temporal_schedule_id: str,
) -> WorkflowSchedule | None:
    """Attach a Temporal schedule id to a live schedule; None when not found."""
    stmt = select(WorkflowScheduleModel).filter_by(
        workflow_schedule_id=workflow_schedule_id,
        organization_id=organization_id,
        deleted_at=None,
    )
    async with self.Session() as session:
        schedule = (await session.scalars(stmt)).first()
        if schedule is None:
            return None

        schedule.temporal_schedule_id = temporal_schedule_id
        schedule.modified_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(schedule)
        return convert_to_workflow_schedule(schedule, self.debug_enabled)
||||
|
||||
@db_operation("update_workflow_schedule")
async def update_workflow_schedule(
    self,
    workflow_schedule_id: str,
    organization_id: str,
    cron_expression: str,
    timezone: str,
    enabled: bool,
    parameters: dict[str, Any] | None = None,
    temporal_schedule_id: str | None | object = _UNSET,
    name: str | None | object = _UNSET,
    description: str | None | object = _UNSET,
) -> WorkflowSchedule | None:
    """Update a live schedule in place; returns None if it does not exist.

    cron_expression, timezone, enabled and parameters are always overwritten.
    temporal_schedule_id, name and description use the _UNSET sentinel so a
    caller can distinguish "leave unchanged" (omitted) from "set to None".
    """
    async with self.Session() as session:
        workflow_schedule = (
            await session.scalars(
                select(WorkflowScheduleModel).filter_by(
                    workflow_schedule_id=workflow_schedule_id,
                    organization_id=organization_id,
                    deleted_at=None,
                )
            )
        ).first()

        if not workflow_schedule:
            return None

        workflow_schedule.cron_expression = cron_expression
        workflow_schedule.timezone = timezone
        workflow_schedule.enabled = enabled
        workflow_schedule.parameters = parameters
        # Sentinel checks: only overwrite when the caller explicitly passed a value.
        if temporal_schedule_id is not _UNSET:
            workflow_schedule.temporal_schedule_id = temporal_schedule_id
        if name is not _UNSET:
            workflow_schedule.name = name
        if description is not _UNSET:
            workflow_schedule.description = description
        workflow_schedule.modified_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(workflow_schedule)
        return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
||||
|
||||
@db_operation("get_workflow_schedule_by_id")
async def get_workflow_schedule_by_id(
    self,
    workflow_schedule_id: str,
    organization_id: str,
) -> WorkflowSchedule | None:
    """Fetch one live (non-deleted) schedule by id, or None when missing."""
    stmt = select(WorkflowScheduleModel).filter_by(
        workflow_schedule_id=workflow_schedule_id,
        organization_id=organization_id,
        deleted_at=None,
    )
    async with self.Session() as session:
        schedule_row = (await session.scalars(stmt)).first()
        if schedule_row is None:
            return None
        return convert_to_workflow_schedule(schedule_row, self.debug_enabled)
||||
|
||||
@db_operation("get_workflow_schedules")
async def get_workflow_schedules(
    self,
    workflow_permanent_id: str,
    organization_id: str,
) -> list[WorkflowSchedule]:
    """Return every live schedule attached to the given workflow."""
    stmt = select(WorkflowScheduleModel).filter_by(
        workflow_permanent_id=workflow_permanent_id,
        organization_id=organization_id,
        deleted_at=None,
    )
    async with self.Session() as session:
        schedule_rows = (await session.scalars(stmt)).all()
        return [convert_to_workflow_schedule(row, self.debug_enabled) for row in schedule_rows]
||||
|
||||
@db_operation("get_all_enabled_schedules")
async def get_all_enabled_schedules(
    self,
    organization_id: str | None = None,
) -> list[WorkflowSchedule]:
    """Fetch all enabled, non-deleted schedules, optionally filtered by org."""
    conditions = [
        WorkflowScheduleModel.enabled.is_(True),
        WorkflowScheduleModel.deleted_at.is_(None),
    ]
    if organization_id:
        conditions.append(WorkflowScheduleModel.organization_id == organization_id)
    async with self.Session() as session:
        schedule_rows = (await session.scalars(select(WorkflowScheduleModel).where(*conditions))).all()
        return [convert_to_workflow_schedule(row, self.debug_enabled) for row in schedule_rows]
||||
|
||||
@db_operation("has_schedule_fired_since")
async def has_schedule_fired_since(
    self,
    workflow_schedule_id: str,
    since: datetime,
) -> bool:
    """Check if a workflow_run exists for the given schedule since a timestamp."""
    exists_stmt = select(
        exists().where(
            WorkflowRunModel.workflow_schedule_id == workflow_schedule_id,
            WorkflowRunModel.created_at >= since,
        )
    )
    async with self.Session() as session:
        fired = (await session.execute(exists_stmt)).scalar()
    return bool(fired)
||||
|
||||
@db_operation("update_workflow_schedule_enabled")
async def update_workflow_schedule_enabled(
    self,
    workflow_schedule_id: str,
    organization_id: str,
    enabled: bool,
) -> WorkflowSchedule | None:
    """Flip the enabled flag on a live schedule; None when it does not exist."""
    stmt = select(WorkflowScheduleModel).filter_by(
        workflow_schedule_id=workflow_schedule_id,
        organization_id=organization_id,
        deleted_at=None,
    )
    async with self.Session() as session:
        schedule_row = (await session.scalars(stmt)).first()
        if schedule_row is None:
            return None
        schedule_row.enabled = enabled
        schedule_row.modified_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(schedule_row)
        return convert_to_workflow_schedule(schedule_row, self.debug_enabled)
||||
|
||||
@db_operation("delete_workflow_schedule")
async def delete_workflow_schedule(
    self,
    workflow_schedule_id: str,
    organization_id: str,
) -> WorkflowSchedule | None:
    """Soft-delete a live schedule by stamping deleted_at; None when missing."""
    stmt = select(WorkflowScheduleModel).filter_by(
        workflow_schedule_id=workflow_schedule_id,
        organization_id=organization_id,
        deleted_at=None,
    )
    async with self.Session() as session:
        schedule_row = (await session.scalars(stmt)).first()
        if schedule_row is None:
            return None

        schedule_row.deleted_at = datetime.now(UTC)
        schedule_row.modified_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(schedule_row)
        return convert_to_workflow_schedule(schedule_row, self.debug_enabled)
||||
|
||||
@db_operation("restore_workflow_schedule")
async def restore_workflow_schedule(
    self,
    workflow_schedule_id: str,
    organization_id: str,
) -> WorkflowSchedule | None:
    """Undo a soft delete; returns None unless a deleted row exists for the id."""
    stmt = (
        select(WorkflowScheduleModel)
        .filter_by(
            workflow_schedule_id=workflow_schedule_id,
            organization_id=organization_id,
        )
        .filter(WorkflowScheduleModel.deleted_at.isnot(None))
    )
    async with self.Session() as session:
        schedule_row = (await session.scalars(stmt)).first()
        if schedule_row is None:
            return None

        schedule_row.deleted_at = None
        schedule_row.modified_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(schedule_row)
        return convert_to_workflow_schedule(schedule_row, self.debug_enabled)
||||
|
||||
@db_operation("count_workflow_schedules")
async def count_workflow_schedules(
    self,
    organization_id: str,
    workflow_permanent_id: str,
) -> int:
    """Count live (non-deleted) schedules for one workflow in one org."""
    count_stmt = select(func.count()).where(
        WorkflowScheduleModel.organization_id == organization_id,
        WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id,
        WorkflowScheduleModel.deleted_at.is_(None),
    )
    async with self.Session() as session:
        return (await session.execute(count_stmt)).scalar_one()
||||
|
||||
@db_operation("list_organization_schedules")
async def list_organization_schedules(
    self,
    organization_id: str,
    page: int = 1,
    page_size: int = 10,
    enabled_filter: bool | None = None,
    search: str | None = None,
) -> tuple[list[OrganizationScheduleItem], int]:
    """
    List all schedules for an organization, joined with workflow titles.
    Returns (schedules, total_count).

    Pages are 1-indexed; `search` matches the workflow title or the schedule
    name (case-insensitive, wildcards escaped).
    """
    if page < 1:
        raise ValueError(f"Page must be greater than 0, got {page}")
    db_page = page - 1
    async with self.Session() as session:
        # Subquery to get the latest version title per workflow_permanent_id
        latest_version_sq = (
            select(
                WorkflowModel.workflow_permanent_id,
                func.max(WorkflowModel.version).label("max_version"),
            )
            .where(WorkflowModel.organization_id == organization_id)
            .where(WorkflowModel.deleted_at.is_(None))
            .group_by(WorkflowModel.workflow_permanent_id)
            .subquery()
        )

        # Resolve each workflow_permanent_id to the title of its latest version.
        workflow_title_sq = (
            select(
                WorkflowModel.workflow_permanent_id,
                WorkflowModel.title,
            )
            .join(
                latest_version_sq,
                (WorkflowModel.workflow_permanent_id == latest_version_sq.c.workflow_permanent_id)
                & (WorkflowModel.version == latest_version_sq.c.max_version),
            )
            .subquery()
        )

        # Base query: schedules joined with workflow titles (outer join, so
        # schedules whose workflow title can't be resolved are still listed)
        base_filter = (
            select(WorkflowScheduleModel, workflow_title_sq.c.title.label("workflow_title"))
            .outerjoin(
                workflow_title_sq,
                WorkflowScheduleModel.workflow_permanent_id == workflow_title_sq.c.workflow_permanent_id,
            )
            .where(WorkflowScheduleModel.organization_id == organization_id)
            .where(WorkflowScheduleModel.deleted_at.is_(None))
        )

        if enabled_filter is not None:
            base_filter = base_filter.where(WorkflowScheduleModel.enabled == enabled_filter)

        if search:
            # autoescape=True treats %/_ in the user's search text literally.
            base_filter = base_filter.where(
                or_(
                    workflow_title_sq.c.title.icontains(search, autoescape=True),
                    WorkflowScheduleModel.name.icontains(search, autoescape=True),
                )
            )

        # Count query
        count_query = select(func.count()).select_from(base_filter.subquery())
        total_count = (await session.execute(count_query)).scalar_one()

        # Data query with pagination
        data_query = (
            base_filter.order_by(WorkflowScheduleModel.created_at.desc())
            .limit(page_size)
            .offset(db_page * page_size)
        )
        rows = (await session.execute(data_query)).all()

        # Materialize row data while session is open
        raw_schedules = []
        for row in rows:
            schedule_model = row[0]
            raw_schedules.append(
                (
                    schedule_model.workflow_schedule_id,
                    schedule_model.organization_id,
                    schedule_model.workflow_permanent_id,
                    row[1] or "Untitled Workflow",
                    schedule_model.cron_expression,
                    schedule_model.timezone,
                    schedule_model.enabled,
                    schedule_model.parameters,
                    schedule_model.name,
                    schedule_model.description,
                    schedule_model.created_at,
                    schedule_model.modified_at,
                )
            )

    # Compute next_run outside session scope (pure CPU, no DB needed)
    schedules: list[OrganizationScheduleItem] = []
    for (
        ws_id,
        org_id,
        wpid,
        title,
        cron_expr,
        tz,
        enabled,
        params,
        name,
        description,
        created,
        modified,
    ) in raw_schedules:
        next_run = None
        if enabled:
            try:
                next_run = compute_next_run(cron_expr, tz)
            except Exception:
                # A bad cron/timezone shouldn't break the whole listing;
                # the item is returned with next_run=None instead.
                LOG.warning(
                    "Failed to compute next_run for schedule",
                    workflow_schedule_id=ws_id,
                    exc_info=True,
                )

        schedules.append(
            OrganizationScheduleItem(
                workflow_schedule_id=ws_id,
                organization_id=org_id,
                workflow_permanent_id=wpid,
                workflow_title=title,
                cron_expression=cron_expr,
                timezone=tz,
                enabled=enabled,
                parameters=params,
                name=name,
                description=description,
                next_run=next_run,
                created_at=created,
                modified_at=modified,
            )
        )

    return schedules, total_count
||||
|
||||
@db_operation("soft_delete_orphaned_schedules")
async def soft_delete_orphaned_schedules(self, limit: int = 500) -> list[tuple[str, str]]:
    """Soft-delete orphaned schedules and return their identities.

    A schedule is orphaned when no non-deleted workflow exists for its
    workflow_permanent_id. Uses a single UPDATE ... RETURNING statement so
    orphan detection and soft-deletion happen atomically in one DB round-trip.
    Returns up to `limit` (workflow_schedule_id, workflow_permanent_id) pairs.
    """
    async with self.Session() as session:
        # Correlated EXISTS: is there a live workflow backing this schedule?
        active_workflow_exists = (
            select(WorkflowModel.workflow_permanent_id)
            .where(WorkflowModel.workflow_permanent_id == WorkflowScheduleModel.workflow_permanent_id)
            .where(WorkflowModel.deleted_at.is_(None))
            .correlate(WorkflowScheduleModel)
            .exists()
        )
        # CTE selecting at most `limit` live schedules with no backing workflow.
        orphaned_schedules = (
            select(
                WorkflowScheduleModel.workflow_schedule_id.label("workflow_schedule_id"),
                WorkflowScheduleModel.workflow_permanent_id.label("workflow_permanent_id"),
            )
            .where(WorkflowScheduleModel.deleted_at.is_(None))
            .where(~active_workflow_exists)
            .limit(limit)
            .cte("orphaned_schedules")
        )
        update_query = (
            update(WorkflowScheduleModel)
            .where(
                WorkflowScheduleModel.workflow_schedule_id.in_(select(orphaned_schedules.c.workflow_schedule_id))
            )
            # Re-check deleted_at so a concurrent soft-delete isn't stamped twice.
            .where(WorkflowScheduleModel.deleted_at.is_(None))
            .values(deleted_at=datetime.now(UTC))
            .returning(
                WorkflowScheduleModel.workflow_schedule_id,
                WorkflowScheduleModel.workflow_permanent_id,
            )
        )
        result = await session.execute(update_query)
        await session.commit()
        return [(row[0], row[1]) for row in result.all()]
||||
1206
skyvern/forge/sdk/db/repositories/scripts.py
Normal file
1206
skyvern/forge/sdk/db/repositories/scripts.py
Normal file
File diff suppressed because it is too large
Load diff
727
skyvern/forge/sdk/db/repositories/tasks.py
Normal file
727
skyvern/forge/sdk/db/repositories/tasks.py
Normal file
|
|
@ -0,0 +1,727 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Sequence
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, delete, distinct, func, select, tuple_, update
|
||||
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import read_retry
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ActionModel,
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import convert_to_step, convert_to_task, hydrate_action, serialize_proxy_location
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
|
||||
from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
|
||||
from skyvern.schemas.runs import ProxyLocationInput
|
||||
from skyvern.schemas.steps import AgentStepOutput
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class TasksRepository(BaseRepository):
    @db_operation("create_task")
    async def create_task(
        self,
        url: str,
        title: str | None,
        navigation_goal: str | None,
        data_extraction_goal: str | None,
        navigation_payload: dict[str, Any] | list | str | None,
        status: str = "created",
        complete_criterion: str | None = None,
        terminate_criterion: str | None = None,
        webhook_callback_url: str | None = None,
        totp_verification_url: str | None = None,
        totp_identifier: str | None = None,
        organization_id: str | None = None,
        proxy_location: ProxyLocationInput = None,
        extracted_information_schema: dict[str, Any] | list | str | None = None,
        workflow_run_id: str | None = None,
        order: int | None = None,
        retry: int | None = None,
        max_steps_per_run: int | None = None,
        error_code_mapping: dict[str, str] | None = None,
        task_type: str = TaskType.general,
        application: str | None = None,
        include_action_history_in_verification: bool | None = None,
        model: dict[str, Any] | None = None,
        max_screenshot_scrolling_times: int | None = None,
        extra_http_headers: dict[str, str] | None = None,
        browser_session_id: str | None = None,
        browser_address: str | None = None,
        download_timeout: float | None = None,
    ) -> Task:
        """Insert a new task row and return it as a Task schema object."""

        # Sanitize text fields to remove NUL bytes and control characters
        # that PostgreSQL cannot store in text columns
        def _sanitize(v: str | None) -> str | None:
            return sanitize_postgres_text(v) if isinstance(v, str) else v

        navigation_goal = _sanitize(navigation_goal)
        data_extraction_goal = _sanitize(data_extraction_goal)
        title = _sanitize(title)
        # url is required (never None), so it is sanitized unconditionally.
        url = sanitize_postgres_text(url)
        complete_criterion = _sanitize(complete_criterion)
        terminate_criterion = _sanitize(terminate_criterion)

        async with self.Session() as session:
            new_task = TaskModel(
                status=status,
                task_type=task_type,
                url=url,
                title=title,
                webhook_callback_url=webhook_callback_url,
                totp_verification_url=totp_verification_url,
                totp_identifier=totp_identifier,
                navigation_goal=navigation_goal,
                complete_criterion=complete_criterion,
                terminate_criterion=terminate_criterion,
                data_extraction_goal=data_extraction_goal,
                navigation_payload=navigation_payload,
                organization_id=organization_id,
                # Proxy location is stored in its serialized DB form.
                proxy_location=serialize_proxy_location(proxy_location),
                extracted_information_schema=extracted_information_schema,
                workflow_run_id=workflow_run_id,
                order=order,
                retry=retry,
                max_steps_per_run=max_steps_per_run,
                error_code_mapping=error_code_mapping,
                application=application,
                include_action_history_in_verification=include_action_history_in_verification,
                model=model,
                max_screenshot_scrolling_times=max_screenshot_scrolling_times,
                extra_http_headers=extra_http_headers,
                browser_session_id=browser_session_id,
                browser_address=browser_address,
                download_timeout=download_timeout,
            )
            session.add(new_task)
            await session.commit()
            # Refresh to pick up DB-generated fields (ids, timestamps).
            await session.refresh(new_task)
            return convert_to_task(new_task, self.debug_enabled)
|
||||
|
||||
@db_operation("create_step")
async def create_step(
    self,
    task_id: str,
    order: int,
    retry_index: int,
    organization_id: str | None = None,
    status: StepStatus = StepStatus.created,
    created_by: str | None = None,
) -> Step:
    """Insert a new step row for a task and return it as a Step schema."""
    step_row = StepModel(
        task_id=task_id,
        order=order,
        retry_index=retry_index,
        status=status,
        organization_id=organization_id,
        created_by=created_by,
    )
    async with self.Session() as session:
        session.add(step_row)
        await session.commit()
        # Refresh so DB-generated columns are populated before conversion.
        await session.refresh(step_row)
        return convert_to_step(step_row, debug_enabled=self.debug_enabled)
|
||||
|
||||
@read_retry()
@db_operation("get_task", log_errors=False)
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
    """Fetch a single task by id, optionally scoped to an organization; None if absent."""
    async with self.Session() as session:
        stmt = select(TaskModel).filter_by(task_id=task_id)
        # Only narrow by organization when the caller supplied one.
        if organization_id is not None:
            stmt = stmt.filter_by(organization_id=organization_id)
        row = (await session.scalars(stmt)).first()
        if row is not None:
            return convert_to_task(row, self.debug_enabled)
        LOG.info(
            "Task not found",
            task_id=task_id,
            organization_id=organization_id,
        )
        return None
|
||||
|
||||
@db_operation("get_tasks_by_ids")
async def get_tasks_by_ids(
    self,
    task_ids: list[str],
    organization_id: str,
) -> list[Task]:
    """Return all tasks in `task_ids` that belong to the organization."""
    stmt = select(TaskModel).filter(TaskModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id)
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [convert_to_task(row, debug_enabled=self.debug_enabled) for row in rows]
|
||||
|
||||
@db_operation("get_step")
async def get_step(self, step_id: str, organization_id: str | None = None) -> Step | None:
    """Look up one step by id (and organization); None when no row matches."""
    async with self.Session() as session:
        row = (
            await session.scalars(
                select(StepModel).filter_by(step_id=step_id).filter_by(organization_id=organization_id)
            )
        ).first()
        if row is None:
            return None
        return convert_to_step(row, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("get_task_steps")
async def get_task_steps(self, task_id: str, organization_id: str) -> list[Step]:
    """Return every step of a task, ordered by (order, retry_index)."""
    stmt = (
        select(StepModel)
        .filter_by(task_id=task_id)
        .filter_by(organization_id=organization_id)
        .order_by(StepModel.order)
        .order_by(StepModel.retry_index)
    )
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        # A task with no steps simply yields an empty list.
        return [convert_to_step(row, debug_enabled=self.debug_enabled) for row in rows]
|
||||
|
||||
@db_operation("get_steps_by_task_ids")
async def get_steps_by_task_ids(self, task_ids: list[str], organization_id: str | None = None) -> list[Step]:
    """Return all steps belonging to any of the given tasks."""
    stmt = select(StepModel).filter(StepModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id)
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [convert_to_step(row, debug_enabled=self.debug_enabled) for row in rows]
|
||||
|
||||
@db_operation("get_total_unique_step_order_count_by_task_ids")
async def get_total_unique_step_order_count_by_task_ids(
    self,
    *,
    task_ids: list[str],
    organization_id: str,
) -> int:
    """
    Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids.

    Equivalent SQL:
        select count(distinct(s.task_id, s.order)) from steps s where s.task_id in task_ids

    Retried steps share the same (task_id, order), so this counts logical steps,
    not physical step rows.
    """
    async with self.Session() as session:
        query = (
            select(func.count(distinct(tuple_(StepModel.task_id, StepModel.order))))
            .where(StepModel.task_id.in_(task_ids))
            .where(StepModel.organization_id == organization_id)
        )
        # count(...) always produces exactly one row, so scalar_one() is safe and
        # keeps the declared `-> int` honest (plain .scalar() is typed as Optional).
        return (await session.execute(query)).scalar_one()
|
||||
|
||||
@db_operation("get_task_step_models")
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
    """Return raw StepModel rows for a task, ordered by (order, retry_index)."""
    stmt = (
        select(StepModel)
        .filter_by(task_id=task_id)
        .filter_by(organization_id=organization_id)
        .order_by(StepModel.order)
        .order_by(StepModel.retry_index)
    )
    async with self.Session() as session:
        result = await session.scalars(stmt)
        return result.all()
|
||||
|
||||
@db_operation("get_task_step_count")
async def get_task_step_count(self, task_id: str, organization_id: str | None = None) -> int:
    """Count the step rows recorded for a task."""
    stmt = (
        select(func.count(StepModel.step_id))
        .filter_by(task_id=task_id)
        .filter_by(organization_id=organization_id)
    )
    async with self.Session() as session:
        count = await session.scalar(stmt)
        return count or 0
|
||||
|
||||
@db_operation("get_task_actions")
async def get_task_actions(self, task_id: str, organization_id: str | None = None) -> list[Action]:
    """Return a task's actions in creation order, validated into Action schemas."""
    stmt = (
        select(ActionModel)
        .filter(ActionModel.organization_id == organization_id)
        .filter(ActionModel.task_id == task_id)
        .order_by(ActionModel.created_at)
    )
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [Action.model_validate(row) for row in rows]
|
||||
|
||||
@db_operation("get_task_actions_hydrated")
async def get_task_actions_hydrated(self, task_id: str, organization_id: str | None = None) -> list[Action]:
    """Return a task's actions in creation order, hydrated via hydrate_action."""
    stmt = (
        select(ActionModel)
        .filter(ActionModel.organization_id == organization_id)
        .filter(ActionModel.task_id == task_id)
        .order_by(ActionModel.created_at)
    )
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [hydrate_action(row) for row in rows]
|
||||
|
||||
@db_operation("get_tasks_actions")
async def get_tasks_actions(self, task_ids: list[str], organization_id: str | None = None) -> list[Action]:
    """Return hydrated actions across several tasks, newest first."""
    stmt = (
        select(ActionModel)
        .filter(ActionModel.organization_id == organization_id)
        .filter(ActionModel.task_id.in_(task_ids))
        .order_by(ActionModel.created_at.desc())
    )
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [hydrate_action(row) for row in rows]
|
||||
|
||||
@db_operation("get_action_count_for_step")
async def get_action_count_for_step(self, step_id: str, task_id: str, organization_id: str) -> int:
    """Get count of actions for a step. Uses composite index for efficiency."""
    stmt = (
        select(func.count())
        .select_from(ActionModel)
        .where(ActionModel.organization_id == organization_id)
        .where(ActionModel.task_id == task_id)
        .where(ActionModel.step_id == step_id)
    )
    async with self.Session() as session:
        count = await session.scalar(stmt)
        return count or 0
|
||||
|
||||
@db_operation("get_first_step")
async def get_first_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
    """Return the task's first step (lowest order, then lowest retry_index), or None."""
    async with self.Session() as session:
        if step := (
            await session.scalars(
                select(StepModel)
                .filter_by(task_id=task_id)
                .filter_by(organization_id=organization_id)
                .order_by(StepModel.order.asc())
                .order_by(StepModel.retry_index.asc())
            )
        ).first():
            return convert_to_step(step, debug_enabled=self.debug_enabled)
        else:
            # Fix: this previously logged "Latest step not found" — a copy-paste
            # from get_latest_step that made the logs misleading.
            LOG.info(
                "First step not found",
                task_id=task_id,
                organization_id=organization_id,
            )
            return None
|
||||
|
||||
@db_operation("get_latest_step")
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
    """Return the task's most recent non-canceled step (highest order, then retry_index)."""
    stmt = (
        select(StepModel)
        .filter_by(task_id=task_id)
        .filter_by(organization_id=organization_id)
        .filter(StepModel.status != StepStatus.canceled)
        .order_by(StepModel.order.desc())
        .order_by(StepModel.retry_index.desc())
    )
    async with self.Session() as session:
        row = (await session.scalars(stmt)).first()
        if row is not None:
            return convert_to_step(row, debug_enabled=self.debug_enabled)
        LOG.info(
            "Latest step not found",
            task_id=task_id,
            organization_id=organization_id,
        )
        return None
|
||||
|
||||
@db_operation("update_step")
async def update_step(
    self,
    task_id: str,
    step_id: str,
    status: StepStatus | None = None,
    output: AgentStepOutput | None = None,
    is_last: bool | None = None,
    retry_index: int | None = None,
    organization_id: str | None = None,
    incremental_cost: float | None = None,
    incremental_input_tokens: int | None = None,
    incremental_output_tokens: int | None = None,
    incremental_reasoning_tokens: int | None = None,
    incremental_cached_tokens: int | None = None,
    created_by: str | None = None,
) -> Step:
    """Partially update a step: None arguments are left untouched.

    The `incremental_*` arguments are *added* to the stored counters, not
    assigned, so concurrent callers each contribute their delta.

    Raises:
        NotFoundError: if no step matches (task_id, step_id, organization_id),
            or if the re-read after commit finds nothing.
    """
    async with self.Session() as session:
        if step := (
            await session.scalars(
                select(StepModel)
                .filter_by(task_id=task_id)
                .filter_by(step_id=step_id)
                .filter_by(organization_id=organization_id)
            )
        ).first():
            if status is not None:
                step.status = status

                # Stamp finished_at only once, on the first terminal transition.
                if status.is_terminal() and step.finished_at is None:
                    step.finished_at = datetime.now(timezone.utc)
            if output is not None:
                step.output = output.model_dump(exclude_none=True)
            if is_last is not None:
                step.is_last = is_last
            if retry_index is not None:
                step.retry_index = retry_index
            # Accumulate cost/token deltas onto whatever is already stored
            # (treating a missing stored value as zero).
            if incremental_cost is not None:
                step.step_cost = incremental_cost + float(step.step_cost or 0)
            if incremental_input_tokens is not None:
                step.input_token_count = incremental_input_tokens + (step.input_token_count or 0)
            if incremental_output_tokens is not None:
                step.output_token_count = incremental_output_tokens + (step.output_token_count or 0)
            if incremental_reasoning_tokens is not None:
                step.reasoning_token_count = incremental_reasoning_tokens + (step.reasoning_token_count or 0)
            if incremental_cached_tokens is not None:
                step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0)
            if created_by is not None:
                step.created_by = created_by

            await session.commit()
            # Re-read through get_step so the caller receives the converted,
            # post-commit state.
            updated_step = await self.get_step(step_id, organization_id)
            if not updated_step:
                raise NotFoundError("Step not found")
            return updated_step
        else:
            raise NotFoundError("Step not found")
|
||||
|
||||
@db_operation("clear_task_failure_reason")
async def clear_task_failure_reason(self, organization_id: str, task_id: str) -> Task:
    """Reset a task's failure_reason to None and return the refreshed task.

    Raises:
        NotFoundError: if the task does not exist in this organization.
    """
    async with self.Session() as session:
        row = (
            await session.scalars(
                select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
            )
        ).first()
        if row is None:
            raise NotFoundError("Task not found")
        row.failure_reason = None
        await session.commit()
        await session.refresh(row)
        return convert_to_task(row, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("update_task")
async def update_task(
    self,
    task_id: str,
    status: TaskStatus | None = None,
    extracted_information: dict[str, Any] | list | str | None = None,
    webhook_failure_reason: str | None = None,
    failure_reason: str | None = None,
    errors: list[dict[str, Any]] | None = None,
    max_steps_per_run: int | None = None,
    organization_id: str | None = None,
) -> Task:
    """Partially update a task: None arguments are left untouched.

    Status transitions also stamp queued_at / started_at / finished_at
    exactly once. `errors` is appended to the stored list, not replaced.

    Raises:
        ValueError: if every updatable argument is None.
        NotFoundError: if the task does not exist (before or after commit).
    """
    if (
        status is None
        and extracted_information is None
        and failure_reason is None
        and errors is None
        and max_steps_per_run is None
        and webhook_failure_reason is None
    ):
        # Fix: the old message only mentioned three of the six accepted fields.
        raise ValueError(
            "At least one of status, extracted_information, failure_reason, errors, "
            "max_steps_per_run, or webhook_failure_reason must be provided to update the task"
        )
    async with self.Session() as session:
        if task := (
            await session.scalars(
                select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
            )
        ).first():
            if status is not None:
                task.status = status
                # Each lifecycle timestamp is written only on its first transition.
                if status == TaskStatus.queued and task.queued_at is None:
                    task.queued_at = datetime.now(timezone.utc)
                if status == TaskStatus.running and task.started_at is None:
                    task.started_at = datetime.now(timezone.utc)
                if status.is_final() and task.finished_at is None:
                    task.finished_at = datetime.now(timezone.utc)
            if extracted_information is not None:
                task.extracted_information = extracted_information
            if failure_reason is not None:
                task.failure_reason = failure_reason
            if errors is not None:
                # Append rather than overwrite so earlier errors are preserved.
                task.errors = (task.errors or []) + errors
            if max_steps_per_run is not None:
                task.max_steps_per_run = max_steps_per_run
            if webhook_failure_reason is not None:
                task.webhook_failure_reason = webhook_failure_reason
            await session.commit()
            # Re-read through get_task so the caller receives converted state.
            updated_task = await self.get_task(task_id, organization_id=organization_id)
            if not updated_task:
                raise NotFoundError("Task not found")
            return updated_task
        else:
            raise NotFoundError("Task not found")
|
||||
|
||||
@db_operation("update_task_2fa_state")
async def update_task_2fa_state(
    self,
    task_id: str,
    organization_id: str,
    waiting_for_verification_code: bool,
    verification_code_identifier: str | None = None,
    verification_code_polling_started_at: datetime | None = None,
) -> Task:
    """Update task 2FA verification code waiting state.

    When `waiting_for_verification_code` is False, the identifier and
    polling-start timestamp are cleared regardless of the other arguments
    (the clearing happens after the optional assignments, so it wins).

    Raises:
        NotFoundError: if the task does not exist (before or after commit).
    """
    async with self.Session() as session:
        if task := (
            await session.scalars(
                select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
            )
        ).first():
            task.waiting_for_verification_code = waiting_for_verification_code
            if verification_code_identifier is not None:
                task.verification_code_identifier = verification_code_identifier
            if verification_code_polling_started_at is not None:
                task.verification_code_polling_started_at = verification_code_polling_started_at
            if not waiting_for_verification_code:
                # Clear identifiers when no longer waiting
                # (note: this intentionally overrides the assignments above).
                task.verification_code_identifier = None
                task.verification_code_polling_started_at = None
            await session.commit()
            updated_task = await self.get_task(task_id, organization_id=organization_id)
            if not updated_task:
                raise NotFoundError("Task not found")
            return updated_task
        else:
            raise NotFoundError("Task not found")
|
||||
|
||||
@db_operation("bulk_update_tasks")
async def bulk_update_tasks(
    self,
    task_ids: list[str],
    status: TaskStatus | None = None,
    failure_reason: str | None = None,
) -> None:
    """Bulk update tasks by their IDs.

    Args:
        task_ids: List of task IDs to update
        status: Optional status to set for all tasks
        failure_reason: Optional failure reason to set for all tasks

    NOTE(review): unlike the single-task update methods, this statement is
    not scoped by organization_id — callers must pass trusted task_ids.
    Also, the truthiness checks mean an empty-string failure_reason is
    silently skipped — TODO confirm that is intended.
    """
    # No-op on an empty id list to avoid issuing `IN ()`.
    if not task_ids:
        return

    async with self.Session() as session:
        update_values = {}
        if status:
            update_values["status"] = status.value
        if failure_reason:
            update_values["failure_reason"] = failure_reason

        # Only hit the database when there is at least one column to change.
        if update_values:
            update_stmt = update(TaskModel).where(TaskModel.task_id.in_(task_ids)).values(**update_values)
            await session.execute(update_stmt)
            await session.commit()
|
||||
|
||||
@db_operation("get_tasks")
async def get_tasks(
    self,
    page: int = 1,
    page_size: int = 10,
    task_status: list[TaskStatus] | None = None,
    workflow_run_id: str | None = None,
    organization_id: str | None = None,
    only_standalone_tasks: bool = False,
    application: str | None = None,
    order_by_column: OrderBy = OrderBy.created_at,
    order: SortDirection = SortDirection.desc,
) -> list[Task]:
    """
    Get a page of tasks, outer-joined with their workflow run (if any) so each
    returned Task carries its workflow_permanent_id.

    :param page: Starts at 1
    :param page_size: Maximum number of tasks returned
    :param task_status: If given, only tasks in one of these statuses
    :param workflow_run_id: If given, only tasks of this workflow run
    :param organization_id: Organization to scope the query to
    :param only_standalone_tasks: If True, only tasks with no workflow run
    :param application: If given, only tasks for this application
    :param order_by_column: Column to sort by
    :param order: Sort direction
    :return: The requested page of converted Task objects
    :raises ValueError: if page < 1
    """
    if page < 1:
        raise ValueError(f"Page must be greater than 0, got {page}")

    async with self.Session() as session:
        db_page = page - 1  # offset logic is 0 based
        # Outer join so standalone tasks (workflow_run_id IS NULL) still appear,
        # with workflow_permanent_id as None.
        query = (
            select(TaskModel, WorkflowRunModel.workflow_permanent_id)
            .join(WorkflowRunModel, TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id, isouter=True)
            .filter(TaskModel.organization_id == organization_id)
        )
        if task_status:
            query = query.filter(TaskModel.status.in_(task_status))
        if workflow_run_id:
            query = query.filter(TaskModel.workflow_run_id == workflow_run_id)
        if only_standalone_tasks:
            query = query.filter(TaskModel.workflow_run_id.is_(None))
        if application:
            query = query.filter(TaskModel.application == application)
        # OrderBy enum member names mirror TaskModel column names.
        order_by_col = getattr(TaskModel, order_by_column)
        query = (
            query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc())
            .limit(page_size)
            .offset(db_page * page_size)
        )

        results = (await session.execute(query)).all()

        return [
            convert_to_task(task, debug_enabled=self.debug_enabled, workflow_permanent_id=workflow_permanent_id)
            for task, workflow_permanent_id in results
        ]
|
||||
|
||||
@db_operation("get_tasks_count")
async def get_tasks_count(
    self,
    organization_id: str,
    task_status: list[TaskStatus] | None = None,
    workflow_run_id: str | None = None,
    only_standalone_tasks: bool = False,
    application: str | None = None,
) -> int:
    """Count tasks under the same filters as get_tasks (without pagination)."""
    stmt = select(func.count()).select_from(TaskModel).filter(TaskModel.organization_id == organization_id)
    if task_status:
        stmt = stmt.filter(TaskModel.status.in_(task_status))
    if workflow_run_id:
        stmt = stmt.filter(TaskModel.workflow_run_id == workflow_run_id)
    if only_standalone_tasks:
        stmt = stmt.filter(TaskModel.workflow_run_id.is_(None))
    if application:
        stmt = stmt.filter(TaskModel.application == application)
    async with self.Session() as session:
        return (await session.execute(stmt)).scalar_one()
|
||||
|
||||
@db_operation("get_running_tasks_info_globally")
async def get_running_tasks_info_globally(
    self,
    stale_threshold_hours: int = 24,
) -> tuple[int, int]:
    """
    Get information about running tasks across all organizations.
    Used by cleanup service to determine if cleanup should be skipped.

    "Running" here means status in {created, queued, running}; the two counts
    partition that set by modified_at relative to the staleness cutoff.

    Args:
        stale_threshold_hours: Tasks not updated for this many hours are considered stale.

    Returns:
        Tuple of (active_task_count, stale_task_count).
        Active tasks are those updated within the threshold.
        Stale tasks are those not updated within the threshold but still in running status.
    """
    async with self.Session() as session:
        running_statuses = [TaskStatus.created, TaskStatus.queued, TaskStatus.running]
        stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=stale_threshold_hours)

        # Count active tasks (recently updated)
        active_query = (
            select(func.count())
            .select_from(TaskModel)
            .filter(TaskModel.status.in_(running_statuses))
            .filter(TaskModel.modified_at >= stale_cutoff)
        )
        active_count = (await session.execute(active_query)).scalar_one()

        # Count stale tasks (not updated for a long time)
        stale_query = (
            select(func.count())
            .select_from(TaskModel)
            .filter(TaskModel.status.in_(running_statuses))
            .filter(TaskModel.modified_at < stale_cutoff)
        )
        stale_count = (await session.execute(stale_query)).scalar_one()

        return (active_count, stale_count)
|
||||
|
||||
@db_operation("get_latest_task_by_workflow_id")
async def get_latest_task_by_workflow_id(
    self,
    organization_id: str,
    workflow_id: str,
    before: datetime | None = None,
) -> Task | None:
    """Return the newest task of a workflow, optionally created before `before`."""
    stmt = select(TaskModel).filter_by(organization_id=organization_id).filter_by(workflow_id=workflow_id)
    if before:
        stmt = stmt.filter(TaskModel.created_at < before)
    stmt = stmt.order_by(TaskModel.created_at.desc())
    async with self.Session() as session:
        row = (await session.scalars(stmt)).first()
        if row is None:
            return None
        return convert_to_task(row, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("get_last_task_for_workflow_run")
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
    """Return the most recently created task of a workflow run, or None."""
    stmt = select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at.desc())
    async with self.Session() as session:
        row = (await session.scalars(stmt)).first()
        if row is None:
            return None
        return convert_to_task(row, debug_enabled=self.debug_enabled)
|
||||
|
||||
@db_operation("get_tasks_by_workflow_run_id")
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
    """Return every task of a workflow run, oldest first."""
    stmt = select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [convert_to_task(row, debug_enabled=self.debug_enabled) for row in rows]
|
||||
|
||||
@db_operation("delete_task_steps")
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
    """Hard-delete all step rows of a task within an organization."""
    async with self.Session() as session:
        # delete steps by filtering organization_id and task_id
        # (previous comment incorrectly said "artifacts")
        stmt = delete(StepModel).where(
            and_(
                StepModel.organization_id == organization_id,
                StepModel.task_id == task_id,
            )
        )
        await session.execute(stmt)
        await session.commit()
|
||||
|
||||
@db_operation("get_previous_actions_for_task")
async def get_previous_actions_for_task(self, task_id: str) -> list[Action]:
    """Return a task's actions ordered by step order, action order, then creation time."""
    stmt = (
        select(ActionModel)
        .filter_by(task_id=task_id)
        .order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
    )
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [Action.model_validate(row) for row in rows]
|
||||
|
||||
@db_operation("delete_task_actions")
async def delete_task_actions(self, organization_id: str, task_id: str) -> None:
    """Hard-delete all action rows of a task within an organization."""
    # delete actions by filtering organization_id and task_id
    stmt = delete(ActionModel).where(
        and_(
            ActionModel.organization_id == organization_id,
            ActionModel.task_id == task_id,
        )
    )
    async with self.Session() as session:
        await session.execute(stmt)
        await session.commit()
|
||||
699
skyvern/forge/sdk/db/repositories/workflow_parameters.py
Normal file
699
skyvern/forge/sdk/db/repositories/workflow_parameters.py
Normal file
|
|
@ -0,0 +1,699 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import select
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
ActionModel,
|
||||
AISuggestionModel,
|
||||
AWSSecretParameterModel,
|
||||
AzureVaultCredentialParameterModel,
|
||||
Base,
|
||||
BitwardenCreditCardDataParameterModel,
|
||||
BitwardenLoginCredentialParameterModel,
|
||||
BitwardenSensitiveInformationParameterModel,
|
||||
CredentialParameterModel,
|
||||
OnePasswordCredentialParameterModel,
|
||||
OutputParameterModel,
|
||||
TaskGenerationModel,
|
||||
TaskModel,
|
||||
TaskRunModel,
|
||||
WorkflowCopilotChatMessageModel,
|
||||
WorkflowCopilotChatModel,
|
||||
WorkflowParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
convert_to_aws_secret_parameter,
|
||||
convert_to_output_parameter,
|
||||
convert_to_workflow_copilot_chat_message,
|
||||
convert_to_workflow_parameter,
|
||||
hydrate_action,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.runs import Run
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.workflow_copilot import (
|
||||
WorkflowCopilotChat,
|
||||
WorkflowCopilotChatMessage,
|
||||
WorkflowCopilotChatSender,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
PARAMETER_TYPE,
|
||||
AWSSecretParameter,
|
||||
AzureVaultCredentialParameter,
|
||||
BitwardenCreditCardDataParameter,
|
||||
BitwardenLoginCredentialParameter,
|
||||
BitwardenSensitiveInformationParameter,
|
||||
ContextParameter,
|
||||
CredentialParameter,
|
||||
OnePasswordCredentialParameter,
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
WorkflowParameterType,
|
||||
)
|
||||
from skyvern.schemas.runs import RunType
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
|
||||
from . import _UNSET
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class WorkflowParametersRepository(BaseRepository):
|
||||
"""Database operations for workflow parameters, copilot chat, task generation, actions, and runs."""
|
||||
|
||||
@db_operation("create_workflow_parameter")
async def create_workflow_parameter(
    self,
    workflow_id: str,
    workflow_parameter_type: WorkflowParameterType,
    key: str,
    default_value: Any,
    description: str | None = None,
) -> WorkflowParameter:
    """Persist a workflow parameter, serializing its default by parameter type."""
    # Serialize the default: JSON params are json-dumped, everything else
    # stringified; a None default is stored as-is.
    if default_value is not None:
        if workflow_parameter_type == WorkflowParameterType.JSON:
            default_value = json.dumps(default_value)
        else:
            default_value = str(default_value)
    row = WorkflowParameterModel(
        workflow_id=workflow_id,
        workflow_parameter_type=workflow_parameter_type,
        key=key,
        default_value=default_value,
        description=description,
    )
    async with self.Session() as session:
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return convert_to_workflow_parameter(row, self.debug_enabled)
|
||||
|
||||
@db_operation("create_aws_secret_parameter")
async def create_aws_secret_parameter(
    self,
    workflow_id: str,
    key: str,
    aws_key: str,
    description: str | None = None,
) -> AWSSecretParameter:
    """Persist an AWS secret parameter referencing a secret by its AWS key."""
    row = AWSSecretParameterModel(
        workflow_id=workflow_id,
        key=key,
        aws_key=aws_key,
        description=description,
    )
    async with self.Session() as session:
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return convert_to_aws_secret_parameter(row)
|
||||
|
||||
@db_operation("create_output_parameter")
async def create_output_parameter(
    self,
    workflow_id: str,
    key: str,
    description: str | None = None,
) -> OutputParameter:
    """Persist an output parameter for a workflow."""
    row = OutputParameterModel(
        key=key,
        description=description,
        workflow_id=workflow_id,
    )
    async with self.Session() as session:
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return convert_to_output_parameter(row)
|
||||
|
||||
@staticmethod
def _convert_parameter_to_model(parameter: PARAMETER_TYPE) -> Base:
    """Convert a parameter object to its corresponding SQLAlchemy model.

    Dispatches on the concrete parameter class and copies its fields (including
    the id and deleted_at) into the matching model, so saved rows round-trip.

    Raises:
        ValueError: for parameter types with no model mapping here
            (note: ContextParameter is intentionally not handled — callers
            filter it out before persisting).
    """
    if isinstance(parameter, WorkflowParameter):
        # Same default-value serialization as create_workflow_parameter:
        # JSON params are dumped, other types stringified, None kept as None.
        if parameter.default_value is None:
            default_value = None
        elif parameter.workflow_parameter_type == WorkflowParameterType.JSON:
            default_value = json.dumps(parameter.default_value)
        else:
            default_value = str(parameter.default_value)
        return WorkflowParameterModel(
            workflow_parameter_id=parameter.workflow_parameter_id,
            workflow_parameter_type=parameter.workflow_parameter_type.value,
            key=parameter.key,
            description=parameter.description,
            workflow_id=parameter.workflow_id,
            default_value=default_value,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, OutputParameter):
        return OutputParameterModel(
            output_parameter_id=parameter.output_parameter_id,
            key=parameter.key,
            description=parameter.description,
            workflow_id=parameter.workflow_id,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, AWSSecretParameter):
        return AWSSecretParameterModel(
            aws_secret_parameter_id=parameter.aws_secret_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            aws_key=parameter.aws_key,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, BitwardenLoginCredentialParameter):
        return BitwardenLoginCredentialParameterModel(
            bitwarden_login_credential_parameter_id=parameter.bitwarden_login_credential_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
            bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
            bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
            bitwarden_collection_id=parameter.bitwarden_collection_id,
            bitwarden_item_id=parameter.bitwarden_item_id,
            url_parameter_key=parameter.url_parameter_key,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, BitwardenSensitiveInformationParameter):
        return BitwardenSensitiveInformationParameterModel(
            bitwarden_sensitive_information_parameter_id=parameter.bitwarden_sensitive_information_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
            bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
            bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
            bitwarden_collection_id=parameter.bitwarden_collection_id,
            bitwarden_identity_key=parameter.bitwarden_identity_key,
            bitwarden_identity_fields=parameter.bitwarden_identity_fields,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, BitwardenCreditCardDataParameter):
        return BitwardenCreditCardDataParameterModel(
            bitwarden_credit_card_data_parameter_id=parameter.bitwarden_credit_card_data_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
            bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
            bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
            bitwarden_collection_id=parameter.bitwarden_collection_id,
            bitwarden_item_id=parameter.bitwarden_item_id,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, CredentialParameter):
        return CredentialParameterModel(
            credential_parameter_id=parameter.credential_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            credential_id=parameter.credential_id,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, OnePasswordCredentialParameter):
        return OnePasswordCredentialParameterModel(
            onepassword_credential_parameter_id=parameter.onepassword_credential_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            vault_id=parameter.vault_id,
            item_id=parameter.item_id,
            deleted_at=parameter.deleted_at,
        )
    elif isinstance(parameter, AzureVaultCredentialParameter):
        return AzureVaultCredentialParameterModel(
            azure_vault_credential_parameter_id=parameter.azure_vault_credential_parameter_id,
            workflow_id=parameter.workflow_id,
            key=parameter.key,
            description=parameter.description,
            vault_name=parameter.vault_name,
            username_key=parameter.username_key,
            password_key=parameter.password_key,
            totp_secret_key=parameter.totp_secret_key,
            deleted_at=parameter.deleted_at,
        )
    else:
        raise ValueError(f"Unsupported workflow definition parameter type: {type(parameter).__name__}")
|
||||
|
||||
@db_operation("save_workflow_definition_parameters")
async def save_workflow_definition_parameters(self, parameters: list[PARAMETER_TYPE]) -> None:
    """Persist a batch of workflow definition parameters in one transaction.

    ContextParameter instances are filtered out because they are never persisted.
    No-op when nothing remains to save.
    """
    persistable = [param for param in parameters if not isinstance(param, ContextParameter)]
    if not persistable:
        return

    async with self.Session() as session:
        session.add_all(self._convert_parameter_to_model(param) for param in persistable)
        await session.commit()
|
||||
|
||||
@db_operation("get_workflow_output_parameters")
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
    """Return every output parameter defined for the given workflow."""
    async with self.Session() as session:
        rows = (
            await session.scalars(select(OutputParameterModel).filter_by(workflow_id=workflow_id))
        ).all()
    return [convert_to_output_parameter(row) for row in rows]
|
||||
|
||||
@db_operation("get_workflow_output_parameters_by_ids")
async def get_workflow_output_parameters_by_ids(self, output_parameter_ids: list[str]) -> list[OutputParameter]:
    """Fetch the output parameters whose IDs appear in ``output_parameter_ids``."""
    query = select(OutputParameterModel).filter(
        OutputParameterModel.output_parameter_id.in_(output_parameter_ids)
    )
    async with self.Session() as session:
        rows = (await session.scalars(query)).all()
    return [convert_to_output_parameter(row) for row in rows]
|
||||
|
||||
@db_operation("get_workflow_parameters")
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
    """Return all workflow parameters attached to ``workflow_id``."""
    async with self.Session() as session:
        rows = (
            await session.scalars(select(WorkflowParameterModel).filter_by(workflow_id=workflow_id))
        ).all()
    return [convert_to_workflow_parameter(row) for row in rows]
|
||||
|
||||
@db_operation("get_workflow_parameter")
async def get_workflow_parameter(
    self, workflow_parameter_id: str, organization_id: str | None = None
) -> WorkflowParameter | None:
    """Look up a single workflow parameter by its primary ID.

    NOTE(review): ``organization_id`` is accepted but not applied as a filter in
    this query — confirm whether org scoping is intended here.
    """
    query = select(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id)
    async with self.Session() as session:
        row = (await session.scalars(query)).first()
    if row is None:
        return None
    return convert_to_workflow_parameter(row, self.debug_enabled)
|
||||
|
||||
@db_operation("create_task_generation")
async def create_task_generation(
    self,
    organization_id: str,
    user_prompt: str,
    user_prompt_hash: str,
    url: str | None = None,
    navigation_goal: str | None = None,
    navigation_payload: dict[str, Any] | None = None,
    data_extraction_goal: str | None = None,
    extracted_information_schema: dict[str, Any] | None = None,
    suggested_title: str | None = None,
    llm: str | None = None,
    llm_prompt: str | None = None,
    llm_response: str | None = None,
    source_task_generation_id: str | None = None,
) -> TaskGeneration:
    """Insert a task-generation record and return the validated schema object.

    ``user_prompt_hash`` is stored alongside the raw prompt — presumably to let
    get_task_generation_by_prompt_hash serve cached generations (see that method).
    ``source_task_generation_id`` links a regeneration back to the original record.
    """
    async with self.Session() as session:
        new_task_generation = TaskGenerationModel(
            organization_id=organization_id,
            user_prompt=user_prompt,
            user_prompt_hash=user_prompt_hash,
            url=url,
            navigation_goal=navigation_goal,
            navigation_payload=navigation_payload,
            data_extraction_goal=data_extraction_goal,
            extracted_information_schema=extracted_information_schema,
            llm=llm,
            llm_prompt=llm_prompt,
            llm_response=llm_response,
            suggested_title=suggested_title,
            source_task_generation_id=source_task_generation_id,
        )
        session.add(new_task_generation)
        await session.commit()
        # Refresh so DB-generated fields (primary key, timestamps) are populated.
        await session.refresh(new_task_generation)
        return TaskGeneration.model_validate(new_task_generation)
|
||||
|
||||
@db_operation("create_ai_suggestion")
async def create_ai_suggestion(
    self,
    organization_id: str,
    ai_suggestion_type: str,
) -> AISuggestion:
    """Insert a new AI suggestion row and return the validated record."""
    suggestion = AISuggestionModel(
        organization_id=organization_id,
        ai_suggestion_type=ai_suggestion_type,
    )
    async with self.Session() as session:
        session.add(suggestion)
        await session.commit()
        await session.refresh(suggestion)
        return AISuggestion.model_validate(suggestion)
|
||||
|
||||
@db_operation("create_workflow_copilot_chat")
async def create_workflow_copilot_chat(
    self,
    organization_id: str,
    workflow_permanent_id: str,
) -> WorkflowCopilotChat:
    """Create a copilot chat bound to a workflow's permanent ID."""
    chat = WorkflowCopilotChatModel(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
    )
    async with self.Session() as session:
        session.add(chat)
        await session.commit()
        await session.refresh(chat)
        return WorkflowCopilotChat.model_validate(chat)
|
||||
|
||||
@db_operation("update_workflow_copilot_chat")
async def update_workflow_copilot_chat(
    self,
    organization_id: str,
    workflow_copilot_chat_id: str,
    proposed_workflow: dict | None | object = _UNSET,
    auto_accept: bool | None = None,
) -> WorkflowCopilotChat | None:
    """Update mutable fields of a copilot chat; return None when not found.

    ``proposed_workflow`` uses the ``_UNSET`` sentinel so that an explicit
    ``None`` can clear the stored value, while "argument not passed" leaves
    the column untouched. ``auto_accept`` only updates when not None.
    """
    async with self.Session() as session:
        chat = (
            await session.scalars(
                select(WorkflowCopilotChatModel)
                .where(WorkflowCopilotChatModel.organization_id == organization_id)
                .where(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
            )
        ).first()
        if not chat:
            return None

        if proposed_workflow is not _UNSET:
            chat.proposed_workflow = proposed_workflow
        if auto_accept is not None:
            chat.auto_accept = auto_accept

        await session.commit()
        await session.refresh(chat)
        return WorkflowCopilotChat.model_validate(chat)
|
||||
|
||||
@db_operation("create_workflow_copilot_chat_message")
async def create_workflow_copilot_chat_message(
    self,
    organization_id: str,
    workflow_copilot_chat_id: str,
    sender: WorkflowCopilotChatSender,
    content: str,
    global_llm_context: str | None = None,
) -> WorkflowCopilotChatMessage:
    """Append a message to a copilot chat and return the converted record."""
    message = WorkflowCopilotChatMessageModel(
        workflow_copilot_chat_id=workflow_copilot_chat_id,
        organization_id=organization_id,
        sender=sender,
        content=content,
        global_llm_context=global_llm_context,
    )
    async with self.Session() as session:
        session.add(message)
        await session.commit()
        await session.refresh(message)
        return convert_to_workflow_copilot_chat_message(message, self.debug_enabled)
|
||||
|
||||
@db_operation("get_workflow_copilot_chat_messages")
async def get_workflow_copilot_chat_messages(
    self,
    workflow_copilot_chat_id: str,
) -> list[WorkflowCopilotChatMessage]:
    """Return all messages of a chat, oldest first (ordered by message ID)."""
    query = (
        select(WorkflowCopilotChatMessageModel)
        .filter(WorkflowCopilotChatMessageModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
        .order_by(WorkflowCopilotChatMessageModel.workflow_copilot_chat_message_id.asc())
    )
    async with self.Session() as session:
        rows = (await session.scalars(query)).all()
    return [convert_to_workflow_copilot_chat_message(row, self.debug_enabled) for row in rows]
|
||||
|
||||
@db_operation("get_workflow_copilot_chat_by_id")
async def get_workflow_copilot_chat_by_id(
    self,
    organization_id: str,
    workflow_copilot_chat_id: str,
) -> WorkflowCopilotChat | None:
    """Fetch an org-scoped copilot chat by its ID, or None when absent."""
    query = (
        select(WorkflowCopilotChatModel)
        .filter(WorkflowCopilotChatModel.organization_id == organization_id)
        .filter(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
        .order_by(WorkflowCopilotChatModel.created_at.desc())
        .limit(1)
    )
    async with self.Session() as session:
        chat = (await session.scalars(query)).first()
    return WorkflowCopilotChat.model_validate(chat) if chat else None
|
||||
|
||||
@db_operation("get_latest_workflow_copilot_chat")
async def get_latest_workflow_copilot_chat(
    self,
    organization_id: str,
    workflow_permanent_id: str,
) -> WorkflowCopilotChat | None:
    """Return the most recently created chat for a workflow, or None."""
    query = (
        select(WorkflowCopilotChatModel)
        .filter(WorkflowCopilotChatModel.organization_id == organization_id)
        .filter(WorkflowCopilotChatModel.workflow_permanent_id == workflow_permanent_id)
        .order_by(WorkflowCopilotChatModel.created_at.desc())
        .limit(1)
    )
    async with self.Session() as session:
        chat = (await session.scalars(query)).first()
    return WorkflowCopilotChat.model_validate(chat) if chat else None
|
||||
|
||||
@db_operation("get_task_generation_by_prompt_hash")
async def get_task_generation_by_prompt_hash(
    self,
    user_prompt_hash: str,
    query_window_hours: int = settings.PROMPT_CACHE_WINDOW_HOURS,
) -> TaskGeneration | None:
    """Return a recent task generation matching the prompt hash, or None.

    Only rows created inside the cache window and with a non-null ``llm``
    column are considered (i.e. generations that actually ran through an LLM).
    """
    window_start = datetime.now(timezone.utc) - timedelta(hours=query_window_hours)
    query = (
        select(TaskGenerationModel)
        .filter_by(user_prompt_hash=user_prompt_hash)
        .filter(TaskGenerationModel.llm.is_not(None))
        .filter(TaskGenerationModel.created_at > window_start)
    )
    async with self.Session() as session:
        row = (await session.scalars(query)).first()
    return TaskGeneration.model_validate(row) if row else None
|
||||
|
||||
@db_operation("create_action")
async def create_action(self, action: Action) -> Action:
    """Persist an Action and return it rehydrated from the stored row.

    The full action is additionally serialized into the ``action_json`` column
    via ``model_dump()``, so ``hydrate_action`` can reconstruct the concrete
    action subtype when reading back.
    """
    async with self.Session() as session:
        new_action = ActionModel(
            action_type=action.action_type,
            source_action_id=action.source_action_id,
            organization_id=action.organization_id,
            workflow_run_id=action.workflow_run_id,
            task_id=action.task_id,
            step_id=action.step_id,
            step_order=action.step_order,
            action_order=action.action_order,
            status=action.status,
            reasoning=action.reasoning,
            intention=action.intention,
            response=action.response,
            element_id=action.element_id,
            skyvern_element_hash=action.skyvern_element_hash,
            skyvern_element_data=action.skyvern_element_data,
            screenshot_artifact_id=action.screenshot_artifact_id,
            # Full serialized action, used for round-tripping the concrete subtype.
            action_json=action.model_dump(),
            confidence_float=action.confidence_float,
            created_by=action.created_by,
        )
        session.add(new_action)
        await session.commit()
        await session.refresh(new_action)
        return hydrate_action(new_action)
|
||||
|
||||
@db_operation("update_action_reasoning")
async def update_action_reasoning(
    self,
    organization_id: str,
    action_id: str,
    reasoning: str,
) -> Action:
    """Overwrite the reasoning text of an org-scoped action.

    Raises NotFoundError when no matching action exists.
    """
    query = select(ActionModel).filter_by(action_id=action_id).filter_by(organization_id=organization_id)
    async with self.Session() as session:
        action = (await session.scalars(query)).first()
        if action is None:
            raise NotFoundError(f"Action {action_id}")
        action.reasoning = reasoning
        await session.commit()
        await session.refresh(action)
        return Action.model_validate(action)
|
||||
|
||||
@db_operation("retrieve_action_plan")
async def retrieve_action_plan(self, task: Task) -> list[Action]:
    """Return the action history of the latest completed task matching
    ``task``'s URL and navigation goal, ordered by execution order.

    NOTE(review): the subquery is not filtered by organization_id, so a match
    could come from another organization's task — confirm this is intended.
    """
    async with self.Session() as session:
        # Most recent completed task with identical URL + navigation goal.
        subquery = (
            select(TaskModel.task_id)
            .filter(TaskModel.url == task.url)
            .filter(TaskModel.navigation_goal == task.navigation_goal)
            .filter(TaskModel.status == TaskStatus.completed)
            .order_by(TaskModel.created_at.desc())
            .limit(1)
            .subquery()
        )

        # All actions of that task in step/action/creation order.
        query = (
            select(ActionModel)
            .filter(ActionModel.task_id == subquery.c.task_id)
            .order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
        )

        actions = (await session.scalars(query)).all()
        return [Action.model_validate(action) for action in actions]
|
||||
|
||||
@db_operation("create_task_run")
async def create_task_run(
    self,
    task_run_type: RunType,
    organization_id: str,
    run_id: str,
    title: str | None = None,
    url: str | None = None,
    url_hash: str | None = None,
) -> Run:
    """Insert a task run record and return the validated Run.

    ``url_hash`` is stored alongside ``url`` — presumably the cache-lookup key
    used by get_cached_task_run; confirm against callers.
    """
    async with self.Session() as session:
        task_run = TaskRunModel(
            task_run_type=task_run_type,
            organization_id=organization_id,
            run_id=run_id,
            title=title,
            url=url,
            url_hash=url_hash,
        )
        session.add(task_run)
        await session.commit()
        await session.refresh(task_run)
        return Run.model_validate(task_run)
|
||||
|
||||
@db_operation("update_task_run")
async def update_task_run(
    self,
    organization_id: str,
    run_id: str,
    title: str | None = None,
    url: str | None = None,
    url_hash: str | None = None,
) -> None:
    """Update selected fields of an org-scoped task run.

    Only truthy arguments are applied (empty strings are ignored, matching the
    existing behavior). Raises NotFoundError when the run does not exist.
    """
    query = select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
    async with self.Session() as session:
        task_run = (await session.scalars(query)).first()
        if task_run is None:
            raise NotFoundError(f"TaskRun {run_id} not found")

        for field, value in (("title", title), ("url", url), ("url_hash", url_hash)):
            if value:
                setattr(task_run, field, value)
        await session.commit()
|
||||
|
||||
@db_operation("update_job_run_compute_cost")
async def update_job_run_compute_cost(
    self,
    organization_id: str,
    run_id: str,
    instance_type: str | None = None,
    vcpu_millicores: int | None = None,
    memory_mb: int | None = None,
    duration_ms: int | None = None,
    compute_cost: float | None = None,
) -> None:
    """Update compute cost metrics for a job run.

    Fields left as None are not touched. A missing run is logged as a warning
    and silently ignored (best-effort, matching the original behavior).
    """
    query = select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
    async with self.Session() as session:
        task_run = (await session.scalars(query)).first()
        if task_run is None:
            LOG.warning(
                "TaskRun not found for compute cost update",
                run_id=run_id,
                organization_id=organization_id,
            )
            return

        metric_updates = {
            "instance_type": instance_type,
            "vcpu_millicores": vcpu_millicores,
            "memory_mb": memory_mb,
            "duration_ms": duration_ms,
            "compute_cost": compute_cost,
        }
        for field, value in metric_updates.items():
            if value is not None:
                setattr(task_run, field, value)
        await session.commit()
|
||||
|
||||
@db_operation("cache_task_run")
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
    """Mark a task run as cached and return the updated record.

    Raises NotFoundError when the run does not exist.
    """
    query = select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
    async with self.Session() as session:
        task_run = (await session.scalars(query)).first()
        if task_run is None:
            raise NotFoundError(f"Run {run_id} not found")
        task_run.cached = True
        await session.commit()
        await session.refresh(task_run)
        return Run.model_validate(task_run)
|
||||
|
||||
@db_operation("get_cached_task_run")
async def get_cached_task_run(
    self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
) -> Run | None:
    """Return the newest run marked as cached that matches the given filters."""
    query = select(TaskRunModel)
    if task_run_type:
        query = query.filter_by(task_run_type=task_run_type)
    if url_hash:
        query = query.filter_by(url_hash=url_hash)
    if organization_id:
        query = query.filter_by(organization_id=organization_id)
    # Only cached runs, newest first.
    query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
    async with self.Session() as session:
        hit = (await session.scalars(query)).first()
    return Run.model_validate(hit) if hit else None
|
||||
|
||||
@db_operation("get_run")
async def get_run(
    self,
    run_id: str,
    organization_id: str | None = None,
) -> Run | None:
    """Look up a task run by ``run_id``, optionally scoped to an organization."""
    query = select(TaskRunModel).filter_by(run_id=run_id)
    if organization_id:
        query = query.filter_by(organization_id=organization_id)
    async with self.Session() as session:
        found = (await session.scalars(query)).first()
    return Run.model_validate(found) if found else None
|
||||
==== New file: skyvern/forge/sdk/db/repositories/workflow_runs.py (903 lines, @@ -0,0 +1,903 @@) ====
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import Text, and_, cast, exists, func, literal, literal_column, or_, select, update
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from skyvern.exceptions import WorkflowParameterNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import read_retry
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.enums import WorkflowRunTriggerType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunBlockModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowRunOutputParameterModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.protocols import WorkflowParameterReader
|
||||
from skyvern.forge.sdk.db.utils import (
|
||||
convert_to_task,
|
||||
convert_to_workflow_run,
|
||||
convert_to_workflow_run_output_parameter,
|
||||
convert_to_workflow_run_parameter,
|
||||
serialize_proxy_location,
|
||||
)
|
||||
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter
|
||||
from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
WorkflowRun,
|
||||
WorkflowRunOutputParameter,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from skyvern.schemas.runs import ProxyLocationInput
|
||||
|
||||
from . import _UNSET
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class WorkflowRunsRepository(BaseRepository):
|
||||
"""Database operations for workflow runs."""
|
||||
|
||||
def __init__(
    self,
    session_factory: _SessionFactory,
    debug_enabled: bool = False,
    is_retryable_error_fn: Callable[[SQLAlchemyError], bool] | None = None,
    workflow_parameter_reader: WorkflowParameterReader | None = None,
    dialect_name: str = "postgresql",
) -> None:
    """Initialize the repository.

    Args:
        session_factory: Async session factory shared with the other repositories.
        debug_enabled: Passed through to converters that emit extra debug fields.
        is_retryable_error_fn: Optional predicate deciding which SQLAlchemy
            errors trigger a retry (handled by the base class).
        workflow_parameter_reader: Collaborator used to resolve workflow
            parameter definitions owned by another repository.
        dialect_name: SQL dialect name — presumably used by dialect-specific
            queries elsewhere in this class (not visible here); defaults to
            postgresql.
    """
    super().__init__(session_factory, debug_enabled, is_retryable_error_fn)
    self._workflow_parameter_reader = workflow_parameter_reader
    self._dialect_name = dialect_name
|
||||
|
||||
@db_operation("get_running_workflow_runs_info_globally")
async def get_running_workflow_runs_info_globally(
    self,
    stale_threshold_hours: int = 24,
) -> tuple[int, int]:
    """
    Get information about running workflow runs across all organizations.
    Used by cleanup service to determine if cleanup should be skipped.

    Args:
        stale_threshold_hours: Workflow runs not updated for this many hours are considered stale.

    Returns:
        Tuple of (active_workflow_count, stale_workflow_count).
        Active workflows are those updated within the threshold.
        Stale workflows are those not updated within the threshold but still in running status.
    """
    async with self.Session() as session:
        # All statuses that count as "still running" for cleanup purposes.
        running_statuses = [
            WorkflowRunStatus.created,
            WorkflowRunStatus.queued,
            WorkflowRunStatus.running,
            WorkflowRunStatus.paused,
        ]
        stale_cutoff = datetime.now(timezone.utc) - timedelta(hours=stale_threshold_hours)

        # Count active workflow runs (recently updated)
        active_query = (
            select(func.count())
            .select_from(WorkflowRunModel)
            .filter(WorkflowRunModel.status.in_(running_statuses))
            .filter(WorkflowRunModel.modified_at >= stale_cutoff)
        )
        active_count = (await session.execute(active_query)).scalar_one()

        # Count stale workflow runs (not updated for a long time)
        stale_query = (
            select(func.count())
            .select_from(WorkflowRunModel)
            .filter(WorkflowRunModel.status.in_(running_statuses))
            .filter(WorkflowRunModel.modified_at < stale_cutoff)
        )
        stale_count = (await session.execute(stale_query)).scalar_one()

        return (active_count, stale_count)
|
||||
|
||||
@db_operation("create_workflow_run")
async def create_workflow_run(
    self,
    workflow_permanent_id: str,
    workflow_id: str,
    organization_id: str,
    browser_session_id: str | None = None,
    browser_profile_id: str | None = None,
    proxy_location: ProxyLocationInput = None,
    webhook_callback_url: str | None = None,
    totp_verification_url: str | None = None,
    totp_identifier: str | None = None,
    parent_workflow_run_id: str | None = None,
    max_screenshot_scrolling_times: int | None = None,
    extra_http_headers: dict[str, str] | None = None,
    browser_address: str | None = None,
    sequential_key: str | None = None,
    run_with: str | None = None,
    debug_session_id: str | None = None,
    ai_fallback: bool | None = None,
    code_gen: bool | None = None,
    workflow_run_id: str | None = None,
    trigger_type: WorkflowRunTriggerType | None = None,
    workflow_schedule_id: str | None = None,
) -> WorkflowRun:
    """Insert a workflow run in the "created" status and return it.

    If ``workflow_run_id`` is provided, it overrides the DB-generated ID —
    passed via ``kwargs`` so that omitting it lets the model default apply.
    ``trigger_type`` is stored as its string value; ``proxy_location`` is
    serialized before persisting.
    """
    async with self.Session() as session:
        kwargs: dict[str, Any] = {}
        if workflow_run_id is not None:
            kwargs["workflow_run_id"] = workflow_run_id
        workflow_run = WorkflowRunModel(
            workflow_permanent_id=workflow_permanent_id,
            workflow_id=workflow_id,
            organization_id=organization_id,
            browser_session_id=browser_session_id,
            browser_profile_id=browser_profile_id,
            proxy_location=serialize_proxy_location(proxy_location),
            status="created",
            webhook_callback_url=webhook_callback_url,
            totp_verification_url=totp_verification_url,
            totp_identifier=totp_identifier,
            parent_workflow_run_id=parent_workflow_run_id,
            max_screenshot_scrolling_times=max_screenshot_scrolling_times,
            extra_http_headers=extra_http_headers,
            browser_address=browser_address,
            sequential_key=sequential_key,
            run_with=run_with,
            debug_session_id=debug_session_id,
            ai_fallback=ai_fallback,
            code_gen=code_gen,
            trigger_type=trigger_type.value if trigger_type else None,
            workflow_schedule_id=workflow_schedule_id,
            **kwargs,
        )
        session.add(workflow_run)
        await session.commit()
        await session.refresh(workflow_run)
        return convert_to_workflow_run(workflow_run)
|
||||
|
||||
@db_operation("update_workflow_run")
async def update_workflow_run(
    self,
    workflow_run_id: str,
    status: WorkflowRunStatus | None = None,
    failure_reason: str | None = None,
    webhook_failure_reason: str | None = None,
    ai_fallback_triggered: bool | None = None,
    job_id: str | None = None,
    run_with: str | None = None,
    sequential_key: str | None = None,
    ai_fallback: bool | None = None,
    depends_on_workflow_run_id: str | None = None,
    browser_session_id: str | None = None,
    waiting_for_verification_code: bool | None = None,
    verification_code_identifier: str | None = None,
    verification_code_polling_started_at: datetime | None = None,
    browser_profile_id: str | None | object = _UNSET,
) -> WorkflowRun:
    """Apply partial updates to a workflow run and return the updated record.

    Status transitions also stamp the matching lifecycle timestamp exactly
    once (queued_at / started_at / finished_at are only set when still None).
    ``browser_profile_id`` uses the ``_UNSET`` sentinel so an explicit None
    can clear the field. Logs for the run are flushed after each update via
    ``save_workflow_run_logs``. Raises WorkflowRunNotFound when the ID does
    not exist.
    """
    async with self.Session() as session:
        workflow_run = (
            await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
        ).first()
        if workflow_run:
            if status:
                workflow_run.status = status
            # Stamp lifecycle timestamps only on the first transition into each state.
            if status and status == WorkflowRunStatus.queued and workflow_run.queued_at is None:
                workflow_run.queued_at = datetime.now(timezone.utc)
            if status and status == WorkflowRunStatus.running and workflow_run.started_at is None:
                workflow_run.started_at = datetime.now(timezone.utc)
            if status and status.is_final() and workflow_run.finished_at is None:
                workflow_run.finished_at = datetime.now(timezone.utc)
            if failure_reason:
                workflow_run.failure_reason = failure_reason
            if webhook_failure_reason is not None:
                workflow_run.webhook_failure_reason = webhook_failure_reason
            if ai_fallback_triggered is not None:
                # NOTE(review): this overwrites the whole script_run dict rather
                # than merging — confirm no other keys are expected to survive.
                workflow_run.script_run = {"ai_fallback_triggered": ai_fallback_triggered}
            if job_id:
                workflow_run.job_id = job_id
            if run_with:
                workflow_run.run_with = run_with
            if sequential_key:
                workflow_run.sequential_key = sequential_key
            if ai_fallback is not None:
                workflow_run.ai_fallback = ai_fallback
            if depends_on_workflow_run_id:
                workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
            if browser_session_id:
                workflow_run.browser_session_id = browser_session_id
            # 2FA verification code waiting state updates
            if waiting_for_verification_code is not None:
                workflow_run.waiting_for_verification_code = waiting_for_verification_code
            if verification_code_identifier is not None:
                workflow_run.verification_code_identifier = verification_code_identifier
            if verification_code_polling_started_at is not None:
                workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
            if waiting_for_verification_code is not None and not waiting_for_verification_code:
                # Clear related fields when waiting is set to False
                workflow_run.verification_code_identifier = None
                workflow_run.verification_code_polling_started_at = None
            if browser_profile_id is not _UNSET:
                workflow_run.browser_profile_id = browser_profile_id
            await session.commit()
            await save_workflow_run_logs(workflow_run_id)
            await session.refresh(workflow_run)
            return convert_to_workflow_run(workflow_run)
        else:
            raise WorkflowRunNotFound(workflow_run_id)
|
||||
|
||||
@db_operation("bulk_update_workflow_runs")
async def bulk_update_workflow_runs(
    self,
    workflow_run_ids: list[str],
    status: WorkflowRunStatus | None = None,
    failure_reason: str | None = None,
) -> None:
    """Apply the same status and/or failure reason to many workflow runs.

    Args:
        workflow_run_ids: List of workflow run IDs to update.
        status: Optional status to set for all workflow runs.
        failure_reason: Optional failure reason to set for all workflow runs.

    No-op when the ID list is empty or neither field is provided.
    """
    if not workflow_run_ids:
        return

    values: dict = {}
    if status:
        values["status"] = status.value
    if failure_reason:
        values["failure_reason"] = failure_reason

    async with self.Session() as session:
        if values:
            stmt = (
                update(WorkflowRunModel)
                .where(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
                .values(**values)
            )
            await session.execute(stmt)
            await session.commit()
|
||||
|
||||
@db_operation("clear_workflow_run_failure_reason")
async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun:
    """Reset the failure reason of an org-scoped workflow run to None.

    Raises NotFoundError when the run does not exist.
    """
    query = (
        select(WorkflowRunModel)
        .filter_by(workflow_run_id=workflow_run_id)
        .filter_by(organization_id=organization_id)
    )
    async with self.Session() as session:
        workflow_run = (await session.scalars(query)).first()
        if workflow_run is None:
            raise NotFoundError("Workflow run not found")
        workflow_run.failure_reason = None
        await session.commit()
        await session.refresh(workflow_run)
        return convert_to_workflow_run(workflow_run)
|
||||
|
||||
@db_operation("get_all_runs")
async def get_all_runs(
    self,
    organization_id: str,
    page: int = 1,
    page_size: int = 10,
    status: list[WorkflowRunStatus] | None = None,
    include_debugger_runs: bool = False,
    search_key: str | None = None,
) -> list[WorkflowRun | Task]:
    """Return a merged, paginated feed of top-level workflow runs and
    standalone tasks for an organization, newest first.

    Both sources are over-fetched up to ``page * page_size`` rows, merged in
    memory, sorted by created_at descending, and then sliced to the requested
    page. ``search_key`` matches the run ID, parameter keys/descriptions,
    parameter values, or extra HTTP headers (workflow runs only).
    """
    async with self.Session() as session:
        # temporary limit to 10 pages
        if page > 10:
            return []

        # Over-fetch from each source so the in-memory merge can fill the page.
        limit = page * page_size

        # Top-level workflow runs only (child runs have a parent_workflow_run_id).
        workflow_run_query = (
            select(WorkflowRunModel, WorkflowModel.title)
            .join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
            .filter(WorkflowRunModel.organization_id == organization_id)
            .filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
        )

        if not include_debugger_runs:
            workflow_run_query = workflow_run_query.filter(WorkflowRunModel.debug_session_id.is_(None))

        if search_key:
            key_like = f"%{search_key}%"
            # Match workflow_run_id directly
            id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like)
            # Match parameter key or description (only for non-deleted parameter definitions)
            param_key_desc_exists = exists(
                select(1)
                .select_from(WorkflowRunParameterModel)
                .join(
                    WorkflowParameterModel,
                    WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id,
                )
                .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .where(WorkflowParameterModel.deleted_at.is_(None))
                .where(
                    or_(
                        WorkflowParameterModel.key.ilike(key_like),
                        WorkflowParameterModel.description.ilike(key_like),
                    )
                )
            )
            # Match run parameter value directly (searches all values regardless of parameter definition status)
            param_value_exists = exists(
                select(1)
                .select_from(WorkflowRunParameterModel)
                .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .where(WorkflowRunParameterModel.value.ilike(key_like))
            )
            # Match extra HTTP headers (cast JSON to text for search, skip NULLs)
            extra_headers_match = and_(
                WorkflowRunModel.extra_http_headers.isnot(None),
                func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like),
            )
            workflow_run_query = workflow_run_query.where(
                or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match)
            )

        if status:
            workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status))
        workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit)
        workflow_run_query_result = (await session.execute(workflow_run_query)).all()
        workflow_runs = [
            convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
            for run, title in workflow_run_query_result
        ]

        # Standalone tasks only (tasks attached to a workflow run are excluded).
        task_query = (
            select(TaskModel)
            .filter(TaskModel.organization_id == organization_id)
            .filter(TaskModel.workflow_run_id.is_(None))
        )
        if status:
            task_query = task_query.filter(TaskModel.status.in_(status))
        task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
        task_query_result = (await session.scalars(task_query)).all()
        tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]

        # Merge both sources and re-sort so pagination is consistent across them.
        runs = workflow_runs + tasks

        runs.sort(key=lambda x: x.created_at, reverse=True)

        lower = (page - 1) * page_size
        upper = page * page_size

        return runs[lower:upper]
|
||||
|
||||
@read_retry()
@db_operation("get_workflow_run", log_errors=False)
async def get_workflow_run(
    self,
    workflow_run_id: str,
    organization_id: str | None = None,
    job_id: str | None = None,
    status: WorkflowRunStatus | None = None,
) -> WorkflowRun | None:
    """Fetch a single workflow run by id, optionally narrowed by org, job id, and status.

    Returns ``None`` when no row satisfies every supplied filter.
    """
    # Collect the equality filters up front, then issue one filter_by call.
    criteria: dict[str, Any] = {"workflow_run_id": workflow_run_id}
    if organization_id:
        criteria["organization_id"] = organization_id
    if job_id:
        criteria["job_id"] = job_id
    if status:
        criteria["status"] = status.value
    stmt = select(WorkflowRunModel).filter_by(**criteria)
    async with self.Session() as session:
        row = (await session.scalars(stmt)).first()
        return convert_to_workflow_run(row) if row else None
|
||||
|
||||
async def get_run(self, run_id: str, organization_id: str | None = None) -> WorkflowRun | None:
    """RunReader-protocol alias: delegates straight to :meth:`get_workflow_run`."""
    run = await self.get_workflow_run(run_id, organization_id=organization_id)
    return run
|
||||
|
||||
@db_operation("get_last_queued_workflow_run")
async def get_last_queued_workflow_run(
    self,
    workflow_permanent_id: str,
    organization_id: str | None = None,
    sequential_key: str | None = None,
) -> WorkflowRun | None:
    """Return the most recently modified *queued* run of a workflow.

    Runs pinned to a browser session are excluded — those are sequenced
    through the browser-session path instead.
    """
    stmt = (
        select(WorkflowRunModel)
        .filter_by(workflow_permanent_id=workflow_permanent_id)
        .filter(WorkflowRunModel.browser_session_id.is_(None))
        .filter_by(status=WorkflowRunStatus.queued)
    )
    if organization_id:
        stmt = stmt.filter_by(organization_id=organization_id)
    if sequential_key:
        stmt = stmt.filter_by(sequential_key=sequential_key)
    stmt = stmt.order_by(WorkflowRunModel.modified_at.desc())
    async with self.Session() as session:
        latest = (await session.scalars(stmt)).first()
        return convert_to_workflow_run(latest) if latest else None
|
||||
|
||||
@db_operation("get_workflow_runs_by_ids")
async def get_workflow_runs_by_ids(
    self,
    workflow_run_ids: list[str],
    workflow_permanent_id: str | None = None,
    organization_id: str | None = None,
) -> list[WorkflowRun]:
    """Fetch every run whose id is in *workflow_run_ids*, optionally scoped.

    Missing ids are silently skipped; no ordering is guaranteed.
    """
    stmt = select(WorkflowRunModel).filter(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
    # Apply the optional scoping filters only when a value was supplied.
    for column, value in (
        ("workflow_permanent_id", workflow_permanent_id),
        ("organization_id", organization_id),
    ):
        if value:
            stmt = stmt.filter_by(**{column: value})
    async with self.Session() as session:
        rows = (await session.scalars(stmt)).all()
        return [convert_to_workflow_run(row) for row in rows]
|
||||
|
||||
@db_operation("get_last_running_workflow_run")
async def get_last_running_workflow_run(
    self,
    workflow_permanent_id: str,
    organization_id: str | None = None,
    sequential_key: str | None = None,
) -> WorkflowRun | None:
    """Return the *running* run with the newest ``started_at`` for a workflow.

    Runs that never recorded a ``started_at`` timestamp and runs pinned to
    a browser session are excluded.
    """
    stmt = (
        select(WorkflowRunModel)
        .filter_by(workflow_permanent_id=workflow_permanent_id)
        .filter(WorkflowRunModel.browser_session_id.is_(None))
        .filter_by(status=WorkflowRunStatus.running)
        # a run without started_at cannot be ordered meaningfully
        .filter(WorkflowRunModel.started_at.isnot(None))
    )
    if organization_id:
        stmt = stmt.filter_by(organization_id=organization_id)
    if sequential_key:
        stmt = stmt.filter_by(sequential_key=sequential_key)
    stmt = stmt.order_by(WorkflowRunModel.started_at.desc())
    async with self.Session() as session:
        newest = (await session.scalars(stmt)).first()
        return convert_to_workflow_run(newest) if newest else None
|
||||
|
||||
async def _get_last_workflow_run_by_filter(
    self,
    organization_id: str | None = None,
    **filters: str,
) -> WorkflowRun | None:
    """Get the last queued or running workflow run matching the given column filters.

    Used for browser_session_id and browser_address sequential execution.

    Queued runs take priority over running ones: the most recently modified
    queued run is returned first, and only when none is queued do we fall
    back to the running run with the newest ``started_at``.
    """
    async with self.Session() as session:
        # Base query: equality filters on arbitrary WorkflowRunModel columns.
        query = select(WorkflowRunModel).filter_by(**filters)
        if organization_id:
            query = query.filter_by(organization_id=organization_id)

        # check if there's a queued run (newest modification wins)
        queue_query = query.filter_by(status=WorkflowRunStatus.queued)
        queue_query = queue_query.order_by(WorkflowRunModel.modified_at.desc())
        workflow_run = (await session.scalars(queue_query)).first()
        if workflow_run:
            return convert_to_workflow_run(workflow_run)

        # check if there's a running run; it must have actually started
        running_query = query.filter_by(status=WorkflowRunStatus.running)
        running_query = running_query.filter(WorkflowRunModel.started_at.isnot(None))
        running_query = running_query.order_by(WorkflowRunModel.started_at.desc())
        workflow_run = (await session.scalars(running_query)).first()
        if workflow_run:
            return convert_to_workflow_run(workflow_run)
        return None
|
||||
|
||||
async def get_last_workflow_run_for_browser_session(
    self,
    browser_session_id: str,
    organization_id: str | None = None,
) -> WorkflowRun | None:
    """Latest queued-or-running run pinned to the given browser session."""
    result = await self._get_last_workflow_run_by_filter(
        organization_id=organization_id,
        browser_session_id=browser_session_id,
    )
    return result
|
||||
|
||||
async def get_last_workflow_run_for_browser_address(
    self,
    browser_address: str,
    organization_id: str | None = None,
) -> WorkflowRun | None:
    """Latest queued-or-running run pinned to the given browser address."""
    result = await self._get_last_workflow_run_by_filter(
        organization_id=organization_id,
        browser_address=browser_address,
    )
    return result
|
||||
|
||||
@db_operation("get_workflows_depending_on")
async def get_workflows_depending_on(
    self,
    workflow_run_id: str,
) -> list[WorkflowRun]:
    """
    Get all workflow runs that depend on the given workflow_run_id.

    Used to find workflows that should be signaled when a workflow completes,
    for sequential workflow dependency handling.

    Args:
        workflow_run_id: The workflow_run_id to find dependents for

    Returns:
        List of WorkflowRun objects that have depends_on_workflow_run_id set to workflow_run_id
    """
    stmt = select(WorkflowRunModel).where(
        WorkflowRunModel.depends_on_workflow_run_id == workflow_run_id
    )
    async with self.Session() as session:
        dependents = (await session.scalars(stmt)).all()
        return [convert_to_workflow_run(dependent) for dependent in dependents]
|
||||
|
||||
@staticmethod
def _apply_search_key_filter(query, search_key: str | None):  # type: ignore[no-untyped-def]
    """Narrow *query* to runs matching *search_key* (case-insensitive substring).

    A run matches when the key appears in its run id, in a non-deleted
    parameter definition's key or description, in any run parameter value,
    or in its extra HTTP headers (JSON cast to text). Returns *query*
    unchanged when no key is supplied.
    """
    if not search_key:
        return query
    key_like = f"%{search_key}%"
    # Match workflow_run_id directly
    id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like)
    # Match parameter key or description (only for non-deleted parameter definitions)
    # Use EXISTS to avoid duplicate rows and to keep pagination correct
    param_key_desc_exists = exists(
        select(1)
        .select_from(WorkflowRunParameterModel)
        .join(
            WorkflowParameterModel,
            WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id,
        )
        .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
        .where(WorkflowParameterModel.deleted_at.is_(None))
        .where(
            or_(
                WorkflowParameterModel.key.ilike(key_like),
                WorkflowParameterModel.description.ilike(key_like),
            )
        )
    )
    # Match run parameter value directly (searches all values regardless of parameter definition status)
    param_value_exists = exists(
        select(1)
        .select_from(WorkflowRunParameterModel)
        .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
        .where(WorkflowRunParameterModel.value.ilike(key_like))
    )
    # Match extra HTTP headers (cast JSON to text for search, skip NULLs)
    extra_headers_match = and_(
        WorkflowRunModel.extra_http_headers.isnot(None),
        func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like),
    )
    return query.where(or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match))
|
||||
|
||||
def _apply_error_code_filter(self, query, error_code: str | None):  # type: ignore[no-untyped-def]
    """Narrow *query* to runs whose tasks or blocks recorded *error_code*.

    Task errors are stored as a JSON array of objects carrying an
    ``error_code`` field; block error codes are a flat JSON array of
    strings. SQLite lacks JSONB containment, so each dialect takes its
    own path. Returns *query* unchanged when no code is supplied.
    """
    if not error_code:
        return query

    if self._dialect_name == "sqlite":
        # Task errors: array of objects like [{"error_code": "timeout", ...}]
        # Use json_each to iterate + json_extract to match the error_code field
        error_code_in_tasks = exists(
            select(1)
            .select_from(TaskModel)
            .where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
            .where(
                exists(
                    select(1)
                    .select_from(func.json_each(TaskModel.errors))
                    .where(func.json_extract(literal_column("json_each.value"), "$.error_code") == error_code)
                )
            )
        )
        # Block errors: flat array of strings like ["timeout", "network_error"]
        error_code_in_blocks = exists(
            select(1)
            .select_from(WorkflowRunBlockModel)
            .where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
            .where(
                exists(
                    select(1)
                    .select_from(func.json_each(WorkflowRunBlockModel.error_codes))
                    .where(literal_column("json_each.value") == error_code)
                )
            )
        )
    else:
        # PostgreSQL: native JSONB containment
        error_code_in_tasks = exists(
            select(1)
            .select_from(TaskModel)
            .where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
            .where(cast(TaskModel.errors, JSONB).contains(literal([{"error_code": error_code}], type_=JSONB)))
        )
        error_code_in_blocks = exists(
            select(1)
            .select_from(WorkflowRunBlockModel)
            .where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
            .where(cast(WorkflowRunBlockModel.error_codes, JSONB).contains(literal([error_code], type_=JSONB)))
        )
    return query.where(or_(error_code_in_tasks, error_code_in_blocks))
|
||||
|
||||
@db_operation("get_workflow_runs")
async def get_workflow_runs(
    self,
    organization_id: str,
    page: int = 1,
    page_size: int = 10,
    status: list[WorkflowRunStatus] | None = None,
    ordering: tuple[str, str] | None = None,
    search_key: str | None = None,
    error_code: str | None = None,
) -> list[WorkflowRun]:
    """Page through an organization's top-level workflow runs.

    Sub-runs (rows with a parent_workflow_run_id) are excluded. ``ordering``
    is a (field, direction) pair validated against an allow-list; anything
    unrecognized silently falls back to ("created_at", "desc").
    """
    async with self.Session() as session:
        db_page = page - 1  # offset logic is 0 based

        # Join the workflow so each run can carry its workflow title.
        query = (
            select(WorkflowRunModel, WorkflowModel.title)
            .join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
            .filter(WorkflowRunModel.organization_id == organization_id)
            .filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
        )

        query = self._apply_search_key_filter(query, search_key)
        query = self._apply_error_code_filter(query, error_code)

        if status:
            query = query.filter(WorkflowRunModel.status.in_(status))

        # Allow-list prevents ordering by arbitrary client-supplied columns.
        allowed_ordering_fields = {
            "created_at": WorkflowRunModel.created_at,
            "status": WorkflowRunModel.status,
        }

        field, direction = ("created_at", "desc")

        if ordering and isinstance(ordering, tuple) and len(ordering) == 2:
            req_field, req_direction = ordering
            if req_field in allowed_ordering_fields and req_direction in ("asc", "desc"):
                field, direction = req_field, req_direction

        order_column = allowed_ordering_fields[field]

        if direction == "asc":
            query = query.order_by(order_column.asc())
        else:
            query = query.order_by(order_column.desc())

        query = query.limit(page_size).offset(db_page * page_size)

        workflow_runs = (await session.execute(query)).all()

        return [
            convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
            for run, title in workflow_runs
        ]
|
||||
|
||||
@db_operation("get_workflow_runs_count")
async def get_workflow_runs_count(
    self,
    organization_id: str,
    status: list[WorkflowRunStatus] | None = None,
) -> int:
    """Count an organization's workflow runs, optionally restricted by status."""
    stmt = select(func.count()).select_from(WorkflowRunModel)
    stmt = stmt.filter(WorkflowRunModel.organization_id == organization_id)
    if status:
        stmt = stmt.filter(WorkflowRunModel.status.in_(status))
    async with self.Session() as session:
        return (await session.execute(stmt)).scalar_one()
|
||||
|
||||
@db_operation("get_workflow_runs_for_workflow_permanent_id")
async def get_workflow_runs_for_workflow_permanent_id(
    self,
    workflow_permanent_id: str,
    organization_id: str,
    page: int = 1,
    page_size: int = 10,
    status: list[WorkflowRunStatus] | None = None,
    search_key: str | None = None,
    error_code: str | None = None,
) -> list[WorkflowRun]:
    """
    Get runs for a workflow, with optional `search_key` on run ID, parameter key/description/value,
    or extra HTTP headers.

    Results are newest-first and paged; `error_code` restricts to runs
    whose tasks or blocks recorded that code.
    """
    async with self.Session() as session:
        db_page = page - 1  # offset logic is 0 based
        # Join the workflow so each run can carry its workflow title.
        query = (
            select(WorkflowRunModel, WorkflowModel.title)
            .join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
            .filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id)
            .filter(WorkflowRunModel.organization_id == organization_id)
        )
        query = self._apply_search_key_filter(query, search_key)
        query = self._apply_error_code_filter(query, error_code)
        if status:
            query = query.filter(WorkflowRunModel.status.in_(status))
        query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
        workflow_runs_and_titles_tuples = (await session.execute(query)).all()
        workflow_runs = [
            convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
            for run, title in workflow_runs_and_titles_tuples
        ]
        return workflow_runs
|
||||
|
||||
@db_operation("get_workflow_runs_by_parent_workflow_run_id")
async def get_workflow_runs_by_parent_workflow_run_id(
    self,
    parent_workflow_run_id: str,
    organization_id: str | None = None,
) -> list[WorkflowRun]:
    """Return the child runs spawned under *parent_workflow_run_id*."""
    stmt = select(WorkflowRunModel).where(WorkflowRunModel.parent_workflow_run_id == parent_workflow_run_id)
    if organization_id is not None:
        stmt = stmt.where(WorkflowRunModel.organization_id == organization_id)
    async with self.Session() as session:
        children = (await session.scalars(stmt)).all()
        return [convert_to_workflow_run(child) for child in children]
|
||||
|
||||
@db_operation("get_workflow_run_output_parameters")
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
    """Return a run's output parameters, oldest first (by created_at)."""
    stmt = (
        select(WorkflowRunOutputParameterModel)
        .filter_by(workflow_run_id=workflow_run_id)
        .order_by(WorkflowRunOutputParameterModel.created_at)
    )
    async with self.Session() as session:
        records = (await session.scalars(stmt)).all()
        return [convert_to_workflow_run_output_parameter(record, self.debug_enabled) for record in records]
|
||||
|
||||
@db_operation("get_workflow_run_output_parameter_by_id")
async def get_workflow_run_output_parameter_by_id(
    self, workflow_run_id: str, output_parameter_id: str
) -> WorkflowRunOutputParameter | None:
    """Look up one output parameter of a run; ``None`` when no record exists."""
    stmt = (
        select(WorkflowRunOutputParameterModel)
        .filter_by(workflow_run_id=workflow_run_id)
        .filter_by(output_parameter_id=output_parameter_id)
        .order_by(WorkflowRunOutputParameterModel.created_at)
    )
    async with self.Session() as session:
        record = (await session.scalars(stmt)).first()
        if record is None:
            return None
        return convert_to_workflow_run_output_parameter(record, self.debug_enabled)
|
||||
|
||||
@db_operation("create_or_update_workflow_run_output_parameter")
async def create_or_update_workflow_run_output_parameter(
    self,
    workflow_run_id: str,
    output_parameter_id: str,
    value: dict[str, Any] | list | str | None,
) -> WorkflowRunOutputParameter:
    """Upsert the value of a run's output parameter.

    NOTE(review): this is a read-then-write upsert, not an atomic
    INSERT ... ON CONFLICT — concurrent callers for the same
    (run, parameter) pair could both miss the read and attempt duplicate
    inserts; confirm whether callers serialize per workflow run.
    """
    async with self.Session() as session:
        # check if the workflow run output parameter already exists
        # if it does, update the value
        if workflow_run_output_parameter := (
            await session.scalars(
                select(WorkflowRunOutputParameterModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .filter_by(output_parameter_id=output_parameter_id)
            )
        ).first():
            LOG.info(
                "Updating existing workflow run output parameter",
                workflow_run_id=workflow_run_output_parameter.workflow_run_id,
                output_parameter_id=workflow_run_output_parameter.output_parameter_id,
            )
            workflow_run_output_parameter.value = value
            await session.commit()
            await session.refresh(workflow_run_output_parameter)
            return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)

        # if it does not exist, create a new one
        workflow_run_output_parameter = WorkflowRunOutputParameterModel(
            workflow_run_id=workflow_run_id,
            output_parameter_id=output_parameter_id,
            value=value,
        )
        session.add(workflow_run_output_parameter)
        await session.commit()
        await session.refresh(workflow_run_output_parameter)
        return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
|
||||
|
||||
@db_operation("update_workflow_run_output_parameter")
async def update_workflow_run_output_parameter(
    self,
    workflow_run_id: str,
    output_parameter_id: str,
    value: dict[str, Any] | list | str | None,
) -> WorkflowRunOutputParameter:
    """Overwrite an existing output parameter's value.

    Raises:
        NotFoundError: if no record exists for the (run, parameter) pair.
    """
    async with self.Session() as session:
        stmt = (
            select(WorkflowRunOutputParameterModel)
            .filter_by(workflow_run_id=workflow_run_id)
            .filter_by(output_parameter_id=output_parameter_id)
        )
        record = (await session.scalars(stmt)).first()
        if record is None:
            raise NotFoundError(
                f"WorkflowRunOutputParameter not found for {workflow_run_id} and {output_parameter_id}"
            )
        record.value = value
        await session.commit()
        await session.refresh(record)
        return convert_to_workflow_run_output_parameter(record, self.debug_enabled)
|
||||
|
||||
@db_operation("create_workflow_run_parameter")
async def create_workflow_run_parameter(
    self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
) -> WorkflowRunParameter:
    """Record *value* as the run-scoped instantiation of *workflow_parameter*."""
    record = WorkflowRunParameterModel(
        workflow_run_id=workflow_run_id,
        workflow_parameter_id=workflow_parameter.workflow_parameter_id,
        value=value,
    )
    async with self.Session() as session:
        session.add(record)
        await session.commit()
        await session.refresh(record)
        return convert_to_workflow_run_parameter(record, workflow_parameter, self.debug_enabled)
|
||||
|
||||
@db_operation("get_workflow_run_parameters")
async def get_workflow_run_parameters(
    self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
    """Return (definition, run value) pairs for every parameter of a run.

    Raises:
        RuntimeError: if the run has parameters but the
            workflow_parameter_reader dependency was never injected.
        WorkflowParameterNotFound: if a run parameter references a
            parameter definition the reader cannot resolve.
    """
    async with self.Session() as session:
        workflow_run_parameters = (
            await session.scalars(select(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id))
        ).all()
        # The dependency check is loop-invariant: perform it once up front.
        # Guarded on a non-empty result so a run with no parameters still
        # returns [] without the dependency, matching the prior behavior.
        reader = self._workflow_parameter_reader
        if workflow_run_parameters and reader is None:
            raise RuntimeError("workflow_parameter_reader dependency not set")
        results = []
        for workflow_run_parameter in workflow_run_parameters:
            workflow_parameter = await reader.get_workflow_parameter(
                workflow_run_parameter.workflow_parameter_id
            )
            if not workflow_parameter:
                raise WorkflowParameterNotFound(workflow_parameter_id=workflow_run_parameter.workflow_parameter_id)
            results.append(
                (
                    workflow_parameter,
                    convert_to_workflow_run_parameter(
                        workflow_run_parameter,
                        workflow_parameter,
                        self.debug_enabled,
                    ),
                )
            )
        return results
|
||||
|
||||
@db_operation("get_workflow_run_block_errors")
async def get_workflow_run_block_errors(
    self,
    workflow_run_id: str,
    organization_id: str | None = None,
) -> list[tuple[list[str], str | None]]:
    """Return (error_codes, failure_reason) tuples for blocks with non-null error_codes."""
    stmt = select(WorkflowRunBlockModel.error_codes, WorkflowRunBlockModel.failure_reason).filter_by(
        workflow_run_id=workflow_run_id
    )
    if organization_id is not None:
        stmt = stmt.filter_by(organization_id=organization_id)
    stmt = stmt.where(WorkflowRunBlockModel.error_codes.isnot(None))
    async with self.Session() as session:
        result_rows = (await session.execute(stmt)).all()
        return [(result_row.error_codes, result_row.failure_reason) for result_row in result_rows]
|
||||
705
skyvern/forge/sdk/db/repositories/workflows.py
Normal file
705
skyvern/forge/sdk/db/repositories/workflows.py
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import exists, func, or_, select, update
|
||||
|
||||
from skyvern.constants import DEFAULT_SCRIPT_RUN_ID
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db._soft_delete import exclude_deleted
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
AWSSecretParameterModel,
|
||||
AzureVaultCredentialParameterModel,
|
||||
BitwardenCreditCardDataParameterModel,
|
||||
BitwardenLoginCredentialParameterModel,
|
||||
BitwardenSensitiveInformationParameterModel,
|
||||
CredentialParameterModel,
|
||||
FolderModel,
|
||||
OnePasswordCredentialParameterModel,
|
||||
OutputParameterModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
WorkflowRunModel,
|
||||
WorkflowScheduleModel,
|
||||
WorkflowTemplateModel,
|
||||
)
|
||||
from skyvern.forge.sdk.db.utils import convert_to_workflow, serialize_proxy_location
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
||||
from skyvern.schemas.runs import ProxyLocationInput
|
||||
from skyvern.schemas.workflows import WorkflowStatus
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class WorkflowsRepository(BaseRepository):
|
||||
"""Database operations for workflow management."""
|
||||
|
||||
@db_operation("create_workflow")
async def create_workflow(
    self,
    title: str,
    workflow_definition: dict[str, Any],
    organization_id: str | None = None,
    description: str | None = None,
    proxy_location: ProxyLocationInput = None,
    webhook_callback_url: str | None = None,
    max_screenshot_scrolling_times: int | None = None,
    extra_http_headers: dict[str, str] | None = None,
    totp_verification_url: str | None = None,
    totp_identifier: str | None = None,
    persist_browser_session: bool = False,
    model: dict[str, Any] | None = None,
    workflow_permanent_id: str | None = None,
    version: int | None = None,
    is_saved_task: bool = False,
    status: WorkflowStatus = WorkflowStatus.published,
    run_with: str | None = None,
    ai_fallback: bool = True,
    cache_key: str | None = None,
    adaptive_caching: bool = False,
    code_version: int | None = None,
    generate_script_on_terminal: bool = False,
    run_sequentially: bool = False,
    sequential_key: str | None = None,
    folder_id: str | None = None,
) -> Workflow:
    """Create and persist a new workflow row.

    ``workflow_permanent_id`` and ``version`` are only assigned when
    provided; otherwise the model's defaults apply. When ``folder_id`` is
    given, the folder is validated against ``organization_id`` and its
    ``modified_at`` is touched in the same transaction.

    Raises:
        ValueError: if ``folder_id`` does not resolve to a live folder in
            the given organization.
    """
    async with self.Session() as session:
        workflow = WorkflowModel(
            organization_id=organization_id,
            title=title,
            description=description,
            workflow_definition=workflow_definition,
            proxy_location=serialize_proxy_location(proxy_location),
            webhook_callback_url=webhook_callback_url,
            totp_verification_url=totp_verification_url,
            totp_identifier=totp_identifier,
            max_screenshot_scrolling_times=max_screenshot_scrolling_times,
            extra_http_headers=extra_http_headers,
            persist_browser_session=persist_browser_session,
            model=model,
            is_saved_task=is_saved_task,
            status=status,
            run_with=run_with,
            # fall back to the default script run id when no cache key given
            cache_key=cache_key or DEFAULT_SCRIPT_RUN_ID,
            adaptive_caching=adaptive_caching,
            code_version=code_version,
            generate_script_on_terminal=generate_script_on_terminal,
            run_sequentially=run_sequentially,
            sequential_key=sequential_key,
            folder_id=folder_id,
        )
        if workflow_permanent_id:
            workflow.workflow_permanent_id = workflow_permanent_id
        if version:
            workflow.version = version
        session.add(workflow)

        # Update folder's modified_at if folder_id is provided
        if folder_id:
            # Validate folder exists and belongs to the same organization
            folder_stmt = (
                select(FolderModel)
                .where(FolderModel.folder_id == folder_id)
                .where(FolderModel.organization_id == organization_id)
                .where(FolderModel.deleted_at.is_(None))
            )
            folder_model = await session.scalar(folder_stmt)
            if not folder_model:
                # raising before commit aborts the workflow insert too
                raise ValueError(
                    f"Folder {folder_id} not found or does not belong to organization {organization_id}"
                )
            folder_model.modified_at = datetime.now(timezone.utc)

        await session.commit()
        await session.refresh(workflow)
        return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
@db_operation("soft_delete_workflow_by_id")
async def soft_delete_workflow_by_id(self, workflow_id: str, organization_id: str) -> None:
    """Soft-delete a workflow by stamping ``deleted_at``; already-deleted rows are untouched."""
    stamp = datetime.now(timezone.utc)
    stmt = (
        update(WorkflowModel)
        .where(WorkflowModel.workflow_id == workflow_id)
        .where(WorkflowModel.organization_id == organization_id)
        .where(WorkflowModel.deleted_at.is_(None))
        .values(deleted_at=stamp)
    )
    async with self.Session() as session:
        await session.execute(stmt)
        await session.commit()
|
||||
|
||||
@db_operation("get_workflow")
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
    """Fetch a non-deleted workflow by primary id, optionally scoped to an org."""
    stmt = exclude_deleted(select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel)
    if organization_id:
        stmt = stmt.filter_by(organization_id=organization_id)
    async with self.Session() as session:
        workflow = (await session.scalars(stmt)).first()
        if workflow is None:
            return None
        # Template membership is only resolved for org-scoped lookups.
        is_template = False
        if organization_id:
            is_template = await self.is_workflow_template(
                workflow_permanent_id=workflow.workflow_permanent_id,
                organization_id=workflow.organization_id,
            )
        return convert_to_workflow(
            workflow,
            self.debug_enabled,
            is_template=is_template,
        )
|
||||
|
||||
@db_operation("get_workflow_by_permanent_id")
async def get_workflow_by_permanent_id(
    self,
    workflow_permanent_id: str,
    organization_id: str | None = None,
    version: int | None = None,
    ignore_version: int | None = None,
    filter_deleted: bool = True,
) -> Workflow | None:
    """Fetch the newest (highest-version) workflow for a permanent id.

    ``version`` pins an exact version; ``ignore_version`` excludes one.
    NOTE(review): both are truthiness-checked, so a value of 0 would be
    treated as "not provided" — presumably versions start at 1; confirm.
    """
    get_workflow_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
    if filter_deleted:
        get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))
    if organization_id:
        get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
    if version:
        get_workflow_query = get_workflow_query.filter_by(version=version)
    if ignore_version:
        get_workflow_query = get_workflow_query.filter(WorkflowModel.version != ignore_version)
    # Highest version first, so .first() yields the latest matching version.
    get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
    async with self.Session() as session:
        if workflow := (await session.scalars(get_workflow_query)).first():
            # Template membership is only resolved for org-scoped lookups.
            is_template = (
                await self.is_workflow_template(
                    workflow_permanent_id=workflow.workflow_permanent_id,
                    organization_id=workflow.organization_id,
                )
                if organization_id
                else False
            )
            return convert_to_workflow(
                workflow,
                self.debug_enabled,
                is_template=is_template,
            )
        return None
|
||||
|
||||
@db_operation("get_workflow_for_workflow_run")
async def get_workflow_for_workflow_run(
    self,
    workflow_run_id: str,
    organization_id: str | None = None,
    filter_deleted: bool = True,
) -> Workflow | None:
    """Resolve the workflow definition a given run was started from."""
    get_workflow_query = select(WorkflowModel)

    if filter_deleted:
        get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))

    # Join via the run table: the run row knows which workflow_id it used.
    get_workflow_query = get_workflow_query.join(
        WorkflowRunModel,
        WorkflowRunModel.workflow_id == WorkflowModel.workflow_id,
    )

    if organization_id:
        get_workflow_query = get_workflow_query.filter(WorkflowRunModel.organization_id == organization_id)

    get_workflow_query = get_workflow_query.filter(WorkflowRunModel.workflow_run_id == workflow_run_id)
    async with self.Session() as session:
        if workflow := (await session.scalars(get_workflow_query)).first():
            # Template membership is only resolved for org-scoped lookups.
            is_template = (
                await self.is_workflow_template(
                    workflow_permanent_id=workflow.workflow_permanent_id,
                    organization_id=workflow.organization_id,
                )
                if organization_id
                else False
            )
            return convert_to_workflow(
                workflow,
                self.debug_enabled,
                is_template=is_template,
            )
        return None
|
||||
|
||||
@db_operation("get_workflow_versions_by_permanent_id")
async def get_workflow_versions_by_permanent_id(
    self,
    workflow_permanent_id: str,
    organization_id: str | None = None,
    filter_deleted: bool = True,
) -> list[Workflow]:
    """
    Get all versions of a workflow by its permanent ID, ordered by version descending (newest first).
    """
    stmt = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
    if filter_deleted:
        stmt = stmt.filter(WorkflowModel.deleted_at.is_(None))
    if organization_id:
        stmt = stmt.filter_by(organization_id=organization_id)
    stmt = stmt.order_by(WorkflowModel.version.desc())

    async with self.Session() as session:
        versions = (await session.scalars(stmt)).all()
        # Resolve template membership once for the whole org, not per row.
        template_ids: set[str] = set()
        if versions and organization_id:
            template_ids = await self.get_org_template_permanent_ids(organization_id)

        return [
            convert_to_workflow(
                version,
                self.debug_enabled,
                is_template=version.workflow_permanent_id in template_ids,
            )
            for version in versions
        ]
|
||||
|
||||
@db_operation("get_workflows_by_permanent_ids")
|
||||
async def get_workflows_by_permanent_ids(
|
||||
self,
|
||||
workflow_permanent_ids: list[str],
|
||||
organization_id: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
title: str = "",
|
||||
statuses: list[WorkflowStatus] | None = None,
|
||||
) -> list[Workflow]:
|
||||
"""
|
||||
Get all workflows with the latest version for the organization.
|
||||
"""
|
||||
if page < 1:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
db_page = page - 1
|
||||
async with self.Session() as session:
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
main_query = select(WorkflowModel).join(
|
||||
subquery,
|
||||
(WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
if organization_id:
|
||||
main_query = main_query.where(WorkflowModel.organization_id == organization_id)
|
||||
if title:
|
||||
main_query = main_query.where(WorkflowModel.title.ilike(f"%{title}%"))
|
||||
if statuses:
|
||||
main_query = main_query.where(WorkflowModel.status.in_(statuses))
|
||||
main_query = (
|
||||
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
|
||||
)
|
||||
workflows = (await session.scalars(main_query)).all()
|
||||
|
||||
# Map template status by permanent_id so API responses surface is_template
|
||||
template_permanent_ids: set[str] = set()
|
||||
if workflows and organization_id:
|
||||
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
|
||||
|
||||
return [
|
||||
convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=workflow.workflow_permanent_id in template_permanent_ids,
|
||||
)
|
||||
for workflow in workflows
|
||||
]
|
||||
|
||||
@db_operation("get_workflows_by_organization_id")
|
||||
async def get_workflows_by_organization_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
only_saved_tasks: bool = False,
|
||||
only_workflows: bool = False,
|
||||
only_templates: bool = False,
|
||||
search_key: str | None = None,
|
||||
folder_id: str | None = None,
|
||||
statuses: list[WorkflowStatus] | None = None,
|
||||
) -> list[Workflow]:
|
||||
"""
|
||||
Get all workflows with the latest version for the organization.
|
||||
|
||||
Search semantics:
|
||||
- If `search_key` is provided, its value is used as a unified search term for
|
||||
`workflows.title`, `folders.title`, and workflow parameter metadata (key, description, and default_value).
|
||||
- If `search_key` is not provided, no search filtering is applied.
|
||||
- Parameter metadata search excludes soft-deleted parameter rows across parameter tables.
|
||||
"""
|
||||
if page < 1:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
db_page = page - 1
|
||||
async with self.Session() as session:
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
main_query = (
|
||||
select(WorkflowModel)
|
||||
.join(
|
||||
subquery,
|
||||
(WorkflowModel.organization_id == subquery.c.organization_id)
|
||||
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
.outerjoin(
|
||||
FolderModel,
|
||||
(WorkflowModel.folder_id == FolderModel.folder_id)
|
||||
& (FolderModel.organization_id == WorkflowModel.organization_id),
|
||||
)
|
||||
)
|
||||
if only_saved_tasks:
|
||||
main_query = main_query.where(WorkflowModel.is_saved_task.is_(True))
|
||||
elif only_workflows:
|
||||
main_query = main_query.where(WorkflowModel.is_saved_task.is_(False))
|
||||
if only_templates:
|
||||
# Filter by workflow_templates table (templates at permanent_id level)
|
||||
template_subquery = select(WorkflowTemplateModel.workflow_permanent_id).where(
|
||||
WorkflowTemplateModel.organization_id == organization_id,
|
||||
WorkflowTemplateModel.deleted_at.is_(None),
|
||||
)
|
||||
main_query = main_query.where(WorkflowModel.workflow_permanent_id.in_(template_subquery))
|
||||
if statuses:
|
||||
main_query = main_query.where(WorkflowModel.status.in_(statuses))
|
||||
if folder_id:
|
||||
main_query = main_query.where(WorkflowModel.folder_id == folder_id)
|
||||
if search_key:
|
||||
search_like = f"%{search_key}%"
|
||||
title_like = WorkflowModel.title.ilike(search_like)
|
||||
folder_title_like = FolderModel.title.ilike(search_like)
|
||||
|
||||
parameter_filters = [
|
||||
# WorkflowParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(WorkflowParameterModel)
|
||||
.where(WorkflowParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(WorkflowParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
WorkflowParameterModel.key.ilike(search_like),
|
||||
WorkflowParameterModel.description.ilike(search_like),
|
||||
WorkflowParameterModel.default_value.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# OutputParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(OutputParameterModel)
|
||||
.where(OutputParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(OutputParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
OutputParameterModel.key.ilike(search_like),
|
||||
OutputParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# AWSSecretParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(AWSSecretParameterModel)
|
||||
.where(AWSSecretParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(AWSSecretParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
AWSSecretParameterModel.key.ilike(search_like),
|
||||
AWSSecretParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# BitwardenLoginCredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(BitwardenLoginCredentialParameterModel)
|
||||
.where(BitwardenLoginCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(BitwardenLoginCredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
BitwardenLoginCredentialParameterModel.key.ilike(search_like),
|
||||
BitwardenLoginCredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# BitwardenSensitiveInformationParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(BitwardenSensitiveInformationParameterModel)
|
||||
.where(BitwardenSensitiveInformationParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(BitwardenSensitiveInformationParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
BitwardenSensitiveInformationParameterModel.key.ilike(search_like),
|
||||
BitwardenSensitiveInformationParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# BitwardenCreditCardDataParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(BitwardenCreditCardDataParameterModel)
|
||||
.where(BitwardenCreditCardDataParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(BitwardenCreditCardDataParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
BitwardenCreditCardDataParameterModel.key.ilike(search_like),
|
||||
BitwardenCreditCardDataParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# OnePasswordCredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(OnePasswordCredentialParameterModel)
|
||||
.where(OnePasswordCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(OnePasswordCredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
OnePasswordCredentialParameterModel.key.ilike(search_like),
|
||||
OnePasswordCredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# AzureVaultCredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(AzureVaultCredentialParameterModel)
|
||||
.where(AzureVaultCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(AzureVaultCredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
AzureVaultCredentialParameterModel.key.ilike(search_like),
|
||||
AzureVaultCredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# CredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(CredentialParameterModel)
|
||||
.where(CredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(CredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
CredentialParameterModel.key.ilike(search_like),
|
||||
CredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
]
|
||||
main_query = main_query.where(or_(title_like, folder_title_like, or_(*parameter_filters)))
|
||||
main_query = (
|
||||
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
|
||||
)
|
||||
workflows = (await session.scalars(main_query)).all()
|
||||
template_permanent_ids: set[str] = set()
|
||||
if workflows and organization_id:
|
||||
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
|
||||
|
||||
return [
|
||||
convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=workflow.workflow_permanent_id in template_permanent_ids,
|
||||
)
|
||||
for workflow in workflows
|
||||
]
|
||||
|
||||
@db_operation("update_workflow")
|
||||
async def update_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
organization_id: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
workflow_definition: dict[str, Any] | None = None,
|
||||
version: int | None = None,
|
||||
run_with: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
status: str | None = None,
|
||||
import_error: str | None = None,
|
||||
) -> Workflow:
|
||||
async with self.Session() as session:
|
||||
get_workflow_query = exclude_deleted(
|
||||
select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel
|
||||
)
|
||||
if organization_id:
|
||||
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
|
||||
if workflow := (await session.scalars(get_workflow_query)).first():
|
||||
if title is not None:
|
||||
workflow.title = title
|
||||
if description is not None:
|
||||
workflow.description = description
|
||||
if workflow_definition is not None:
|
||||
workflow.workflow_definition = workflow_definition
|
||||
if version is not None:
|
||||
workflow.version = version
|
||||
if run_with is not None:
|
||||
workflow.run_with = run_with
|
||||
if cache_key is not None:
|
||||
workflow.cache_key = cache_key
|
||||
if status is not None:
|
||||
workflow.status = status
|
||||
if import_error is not None:
|
||||
workflow.import_error = import_error
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
is_template = (
|
||||
await self.is_workflow_template(
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if organization_id
|
||||
else False
|
||||
)
|
||||
return convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=is_template,
|
||||
)
|
||||
else:
|
||||
raise NotFoundError("Workflow not found")
|
||||
|
||||
@db_operation("soft_delete_workflow_and_schedules_by_permanent_id")
|
||||
async def soft_delete_workflow_and_schedules_by_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Soft-delete a workflow and its active schedules in a single DB transaction."""
|
||||
async with self.Session() as session:
|
||||
select_query = (
|
||||
select(WorkflowScheduleModel.workflow_schedule_id)
|
||||
.where(WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowScheduleModel.deleted_at.is_(None))
|
||||
)
|
||||
if organization_id is not None:
|
||||
select_query = select_query.where(WorkflowScheduleModel.organization_id == organization_id)
|
||||
result = await session.execute(select_query)
|
||||
schedule_ids = list(result.scalars().all())
|
||||
|
||||
deleted_at = datetime.now(timezone.utc)
|
||||
if schedule_ids:
|
||||
update_schedules_query = (
|
||||
update(WorkflowScheduleModel)
|
||||
.where(WorkflowScheduleModel.workflow_schedule_id.in_(schedule_ids))
|
||||
.values(deleted_at=deleted_at)
|
||||
)
|
||||
await session.execute(update_schedules_query)
|
||||
|
||||
update_workflow_query = (
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
)
|
||||
if organization_id is not None:
|
||||
update_workflow_query = update_workflow_query.filter_by(organization_id=organization_id)
|
||||
await session.execute(update_workflow_query.values(deleted_at=deleted_at))
|
||||
await session.commit()
|
||||
return schedule_ids
|
||||
|
||||
@db_operation("add_workflow_template")
|
||||
async def add_workflow_template(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
"""Add a workflow to the templates table."""
|
||||
async with self.Session() as session:
|
||||
existing = (
|
||||
await session.scalars(
|
||||
select(WorkflowTemplateModel)
|
||||
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
existing.deleted_at = None
|
||||
await session.commit()
|
||||
return
|
||||
template = WorkflowTemplateModel(
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(template)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("remove_workflow_template")
|
||||
async def remove_workflow_template(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
"""Soft delete a workflow from the templates table."""
|
||||
async with self.Session() as session:
|
||||
update_deleted_at_query = (
|
||||
update(WorkflowTemplateModel)
|
||||
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
.where(WorkflowTemplateModel.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await session.execute(update_deleted_at_query)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("get_org_template_permanent_ids")
|
||||
async def get_org_template_permanent_ids(
|
||||
self,
|
||||
organization_id: str,
|
||||
) -> set[str]:
|
||||
"""Get all workflow_permanent_ids that are templates for an organization."""
|
||||
async with self.Session() as session:
|
||||
result = await session.scalars(
|
||||
select(WorkflowTemplateModel.workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
.where(WorkflowTemplateModel.deleted_at.is_(None))
|
||||
)
|
||||
return set(result.all())
|
||||
|
||||
@db_operation("is_workflow_template")
|
||||
async def is_workflow_template(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
) -> bool:
|
||||
"""Check if a workflow is marked as a template."""
|
||||
async with self.Session() as session:
|
||||
result = (
|
||||
await session.scalars(
|
||||
select(WorkflowTemplateModel)
|
||||
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
.where(WorkflowTemplateModel.deleted_at.is_(None))
|
||||
)
|
||||
).first()
|
||||
return result is not None
|
||||
0
tests/unit/forge/__init__.py
Normal file
0
tests/unit/forge/__init__.py
Normal file
0
tests/unit/forge/sdk/__init__.py
Normal file
0
tests/unit/forge/sdk/__init__.py
Normal file
0
tests/unit/forge/sdk/db/__init__.py
Normal file
0
tests/unit/forge/sdk/db/__init__.py
Normal file
56
tests/unit/forge/sdk/db/test_base_repository.py
Normal file
56
tests/unit/forge/sdk/db/test_base_repository.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
from skyvern.forge.sdk.db.protocols import RunReader, TaskReader, WorkflowParameterReader, WorkflowReader
|
||||
from skyvern.forge.sdk.db.utils import serialize_proxy_location
|
||||
from skyvern.schemas.runs import GeoTarget, ProxyLocation
|
||||
|
||||
|
||||
class TestSerializeProxyLocation:
|
||||
def test_none(self):
|
||||
assert serialize_proxy_location(None) is None
|
||||
|
||||
def test_geo_target(self):
|
||||
geo = GeoTarget(country="US", state="CA")
|
||||
result = serialize_proxy_location(geo)
|
||||
assert result is not None
|
||||
assert "US" in result
|
||||
|
||||
def test_enum(self):
|
||||
result = serialize_proxy_location(ProxyLocation.RESIDENTIAL)
|
||||
assert result == "RESIDENTIAL"
|
||||
|
||||
|
||||
class TestBaseRepository:
|
||||
def test_init_with_defaults(self):
|
||||
mock_session = MagicMock()
|
||||
repo = BaseRepository(session_factory=mock_session, debug_enabled=True)
|
||||
assert repo.Session is mock_session
|
||||
assert repo.debug_enabled is True
|
||||
|
||||
def test_is_retryable_error_default(self):
|
||||
mock_session = MagicMock()
|
||||
repo = BaseRepository(session_factory=mock_session)
|
||||
error = OperationalError("statement", {}, Exception("server closed the connection unexpectedly"))
|
||||
assert repo.is_retryable_error(error) is False # default returns False
|
||||
|
||||
def test_is_retryable_error_custom(self):
|
||||
mock_session = MagicMock()
|
||||
|
||||
def custom_fn(e):
|
||||
return "closed" in str(e).lower()
|
||||
|
||||
repo = BaseRepository(session_factory=mock_session, is_retryable_error_fn=custom_fn)
|
||||
error = OperationalError("statement", {}, Exception("server closed the connection"))
|
||||
assert repo.is_retryable_error(error) is True
|
||||
|
||||
|
||||
class TestProtocols:
|
||||
def test_protocols_importable(self):
|
||||
"""Protocols should be importable and runtime-checkable."""
|
||||
assert TaskReader is not None
|
||||
assert WorkflowReader is not None
|
||||
assert WorkflowParameterReader is not None
|
||||
assert RunReader is not None
|
||||
238
tests/unit/forge/sdk/db/test_repositories.py
Normal file
238
tests/unit/forge/sdk/db/test_repositories.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""Tests for all 14 OSS repository instantiations + dependency injection."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_credential_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = CredentialRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "create_credential")
|
||||
assert hasattr(repo, "get_credential")
|
||||
assert hasattr(repo, "get_credentials")
|
||||
assert hasattr(repo, "update_credential")
|
||||
assert hasattr(repo, "delete_credential")
|
||||
assert hasattr(repo, "create_organization_bitwarden_collection")
|
||||
assert hasattr(repo, "get_organization_bitwarden_collection")
|
||||
|
||||
|
||||
def test_otp_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.otp import OTPRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = OTPRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "get_otp_codes")
|
||||
assert hasattr(repo, "create_otp_code")
|
||||
|
||||
|
||||
def test_debug_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.debug import DebugRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = DebugRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "get_debug_session")
|
||||
assert hasattr(repo, "create_debug_session")
|
||||
assert hasattr(repo, "create_block_run")
|
||||
|
||||
|
||||
def test_organizations_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.organizations import OrganizationsRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = OrganizationsRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "get_organization")
|
||||
assert hasattr(repo, "create_organization")
|
||||
assert hasattr(repo, "create_org_auth_token")
|
||||
assert hasattr(repo, "validate_org_auth_token")
|
||||
|
||||
|
||||
def test_schedules_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = SchedulesRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "create_workflow_schedule")
|
||||
assert hasattr(repo, "get_workflow_schedules")
|
||||
|
||||
|
||||
def test_scripts_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.scripts import ScriptsRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = ScriptsRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "create_script")
|
||||
assert hasattr(repo, "get_scripts")
|
||||
|
||||
|
||||
def test_workflow_parameters_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.workflow_parameters import WorkflowParametersRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = WorkflowParametersRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "get_workflow_parameter")
|
||||
assert hasattr(repo, "create_workflow_parameter")
|
||||
|
||||
|
||||
def test_tasks_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.tasks import TasksRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = TasksRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "create_task")
|
||||
assert hasattr(repo, "get_task")
|
||||
assert hasattr(repo, "create_step")
|
||||
|
||||
|
||||
def test_workflows_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.workflows import WorkflowsRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = WorkflowsRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "get_workflow")
|
||||
assert hasattr(repo, "create_workflow")
|
||||
assert hasattr(repo, "get_workflow_by_permanent_id")
|
||||
|
||||
|
||||
def test_browser_sessions_repository_instantiation():
|
||||
from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
repo = BrowserSessionsRepository(session_factory=mock_session, debug_enabled=False)
|
||||
assert repo.Session is mock_session
|
||||
assert hasattr(repo, "create_browser_profile")
|
||||
assert hasattr(repo, "get_browser_profile")
|
||||
|
||||
|
||||
# ── Cross-dependency repositories ──
|
||||
|
||||
|
||||
def test_workflow_runs_repository_with_dependency():
|
||||
from skyvern.forge.sdk.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_param_reader = MagicMock()
|
||||
repo = WorkflowRunsRepository(
|
||||
session_factory=mock_session,
|
||||
debug_enabled=False,
|
||||
workflow_parameter_reader=mock_param_reader,
|
||||
)
|
||||
assert repo.Session is mock_session
|
||||
assert repo._workflow_parameter_reader is mock_param_reader
|
||||
assert hasattr(repo, "get_workflow_run_parameters")
|
||||
assert hasattr(repo, "create_workflow_run")
|
||||
assert hasattr(repo, "get_workflow_run")
|
||||
|
||||
|
||||
def test_artifacts_repository_with_dependency():
|
||||
from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_run_reader = MagicMock()
|
||||
repo = ArtifactsRepository(
|
||||
session_factory=mock_session,
|
||||
debug_enabled=False,
|
||||
run_reader=mock_run_reader,
|
||||
)
|
||||
assert repo.Session is mock_session
|
||||
assert repo._run_reader is mock_run_reader
|
||||
assert hasattr(repo, "create_artifact")
|
||||
assert hasattr(repo, "get_artifact")
|
||||
|
||||
|
||||
def test_folders_repository_with_dependency():
|
||||
from skyvern.forge.sdk.db.repositories.folders import FoldersRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_workflow_reader = MagicMock()
|
||||
repo = FoldersRepository(
|
||||
session_factory=mock_session,
|
||||
debug_enabled=False,
|
||||
workflow_reader=mock_workflow_reader,
|
||||
)
|
||||
assert repo.Session is mock_session
|
||||
assert repo._workflow_reader is mock_workflow_reader
|
||||
assert hasattr(repo, "create_folder")
|
||||
assert hasattr(repo, "update_workflow_folder")
|
||||
|
||||
|
||||
def test_observer_repository_with_dependency():
|
||||
from skyvern.forge.sdk.db.repositories.observer import ObserverRepository
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_task_reader = MagicMock()
|
||||
repo = ObserverRepository(
|
||||
session_factory=mock_session,
|
||||
debug_enabled=False,
|
||||
task_reader=mock_task_reader,
|
||||
)
|
||||
assert repo.Session is mock_session
|
||||
assert repo._task_reader is mock_task_reader
|
||||
assert hasattr(repo, "create_workflow_run_block")
|
||||
assert hasattr(repo, "get_workflow_run_blocks")
|
||||
|
||||
|
||||
# ── AgentDB composition test ──
|
||||
|
||||
|
||||
def test_agent_db_has_typed_repo_attributes():
|
||||
"""After refactoring, AgentDB should expose typed repository attributes."""
|
||||
from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
|
||||
from skyvern.forge.sdk.db.repositories.tasks import TasksRepository
|
||||
|
||||
with patch("skyvern.forge.sdk.db.agent_db.create_async_engine"):
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
|
||||
db = AgentDB("postgresql+asyncpg://test", debug_enabled=True)
|
||||
assert isinstance(db.tasks, TasksRepository)
|
||||
assert isinstance(db.credentials, CredentialRepository)
|
||||
assert hasattr(db, "get_task") # backward compat delegate
|
||||
assert hasattr(db, "create_workflow")
|
||||
|
||||
|
||||
def test_agent_db_delegates_route_to_repositories():
|
||||
"""Verify delegate methods actually forward to the correct repository."""
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
with mock_patch("skyvern.forge.sdk.db.agent_db.create_async_engine"):
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
|
||||
db = AgentDB("postgresql+asyncpg://test", debug_enabled=False)
|
||||
|
||||
# Patch a method on each major repository and verify the delegate calls it
|
||||
delegates_to_check = [
|
||||
("get_task", "tasks"),
|
||||
("create_workflow", "workflows"),
|
||||
("create_artifact", "artifacts"),
|
||||
("get_organization", "organizations"),
|
||||
("get_credential", "credentials"),
|
||||
("create_workflow_run", "workflow_runs"),
|
||||
]
|
||||
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
for delegate_name, repo_attr in delegates_to_check:
|
||||
repo = getattr(db, repo_attr)
|
||||
mock_method = AsyncMock(return_value="sentinel")
|
||||
original = getattr(repo, delegate_name)
|
||||
setattr(repo, delegate_name, mock_method)
|
||||
try:
|
||||
result = loop.run_until_complete(getattr(db, delegate_name)("arg1", key="val"))
|
||||
mock_method.assert_called_once_with("arg1", key="val")
|
||||
assert result == "sentinel", f"Delegate {delegate_name} did not return repository result"
|
||||
finally:
|
||||
setattr(repo, delegate_name, original)
|
||||
finally:
|
||||
loop.close()
|
||||
|
|
@ -5,13 +5,15 @@ Guards against:
|
|||
- Re-export contracts breaking (ScheduleLimitExceededError used by cloud/routes/)
|
||||
"""
|
||||
|
||||
from skyvern.forge.sdk.db.base_repository import BaseRepository
|
||||
|
||||
|
||||
def test_agent_db_exports_schedule_limit_exceeded_error() -> None:
|
||||
"""ScheduleLimitExceededError must be importable from agent_db (re-export contract)."""
|
||||
from skyvern.forge.sdk.db.agent_db import ScheduleLimitExceededError
|
||||
|
||||
# Verify it's the canonical class, not a shadow
|
||||
from skyvern.forge.sdk.db.mixins.schedules import ScheduleLimitExceededError as Original
|
||||
# Verify it's the canonical class from exceptions.py, not a shadow
|
||||
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError as Original
|
||||
|
||||
assert ScheduleLimitExceededError is Original
|
||||
|
||||
|
|
@ -23,30 +25,35 @@ def test_agent_db_exports_agent_db_class() -> None:
|
|||
assert AgentDB is not None
|
||||
|
||||
|
||||
def test_all_mixins_in_agent_db_mro() -> None:
|
||||
"""All 14 domain mixins must appear in AgentDB's MRO.
|
||||
def test_all_repositories_on_agent_db() -> None:
|
||||
"""All 14 domain repositories must be present as typed attributes on AgentDB.
|
||||
|
||||
If someone re-introduces a late import for any mixin, this test
|
||||
catches it because the mixin won't be in the class hierarchy.
|
||||
After the mixin-to-repository refactor, AgentDB uses composition instead
|
||||
of inheritance. This test verifies every domain repository is wired up
|
||||
by checking that __init__ assigns BaseRepository instances to each expected name.
|
||||
"""
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
|
||||
mro_names = {cls.__name__ for cls in AgentDB.__mro__}
|
||||
expected_mixins = [
|
||||
"TasksMixin",
|
||||
"WorkflowsMixin",
|
||||
"WorkflowRunsMixin",
|
||||
"WorkflowParametersMixin",
|
||||
"SchedulesMixin",
|
||||
"ArtifactsMixin",
|
||||
"BrowserSessionsMixin",
|
||||
"ScriptsMixin",
|
||||
"OTPMixin",
|
||||
"CredentialsMixin",
|
||||
"FoldersMixin",
|
||||
"OrganizationsMixin",
|
||||
"ObserverMixin",
|
||||
"DebugMixin",
|
||||
expected_repos = [
|
||||
"tasks",
|
||||
"workflows",
|
||||
"workflow_runs",
|
||||
"workflow_params",
|
||||
"schedules",
|
||||
"artifacts",
|
||||
"browser_sessions",
|
||||
"scripts",
|
||||
"otp",
|
||||
"credentials",
|
||||
"folders",
|
||||
"organizations",
|
||||
"observer",
|
||||
"debug",
|
||||
]
|
||||
for name in expected_mixins:
|
||||
assert name in mro_names, f"{name} missing from AgentDB MRO"
|
||||
# Instantiate with a dummy database string (sqlite in-memory)
|
||||
db = AgentDB("sqlite+aiosqlite:///", debug_enabled=False)
|
||||
for repo in expected_repos:
|
||||
assert hasattr(db, repo), f"Repository '{repo}' missing from AgentDB instance"
|
||||
assert isinstance(getattr(db, repo), BaseRepository), (
|
||||
f"AgentDB.{repo} should be a BaseRepository, got {type(getattr(db, repo))}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ async def test_db_operation_log_errors_false_suppresses_logging() -> None:
|
|||
@pytest.mark.asyncio
|
||||
async def test_db_operation_schedule_limit_exceeded_is_passthrough() -> None:
|
||||
"""ScheduleLimitExceededError should be treated as business logic, not unexpected."""
|
||||
from skyvern.forge.sdk.db.mixins.schedules import ScheduleLimitExceededError
|
||||
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError
|
||||
|
||||
class ScheduleDB:
|
||||
@db_operation("create_schedule")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue