AgentDB Phase 7: Migrate remaining 8 domains to typed repos (#5366)

Aaron Perez 2026-04-02 19:36:50 -05:00 committed by GitHub
parent 58fed69496
commit 26b8f4d73e
84 changed files with 811 additions and 8804 deletions

View file

@ -0,0 +1,60 @@
#!/usr/bin/env bash
# Detect direct calls to AgentDB backward-compatible delegate methods.
# New code must use repository attributes (e.g. db.tasks.create_task)
# instead of the legacy delegates (e.g. db.create_task).
#
# Called by tests/unit/test_no_direct_db_delegates.py
set -euo pipefail
AGENT_DB="skyvern/forge/sdk/db/agent_db.py"
# Extract delegate method names from agent_db.py.
# These are the "async def <name>" lines inside the delegate section (after line 170).
delegate_methods=$(
awk 'NR > 170 && /^    async def / { gsub(/.*async def /,""); gsub(/\(.*/,""); print }' "$AGENT_DB" \
| sort -u
)
if [ -z "$delegate_methods" ]; then
echo "ERROR: Could not extract delegate methods from $AGENT_DB" >&2
exit 1
fi
# Build a grep alternation pattern for delegate method names.
methods_pattern=$(echo "$delegate_methods" | paste -sd'|' -)
# Search for direct delegate calls on known AgentDB access patterns:
# app.DATABASE.<method>( — should be app.DATABASE.<repo>.<method>(
# REPLICA_DATABASE.<method>( — should be REPLICA_DATABASE.<repo>.<method>(
# Exclude the delegate file itself and tests.
db_pattern="(DATABASE|REPLICA_DATABASE)\.(${methods_pattern})\("
# Legacy files that still use direct delegates (grandfathered in).
ALLOWLIST=(
"$AGENT_DB"
"tests/"
"run_streaming.py"
)
exclude_args=()
for allowed in "${ALLOWLIST[@]}"; do
exclude_args+=(":!${allowed}")
done
violations=$(
git grep -n -E "$db_pattern" -- '*.py' "${exclude_args[@]}" \
2>/dev/null \
|| true
)
if [ -n "$violations" ]; then
echo "Direct AgentDB delegate calls found. Use repository attributes instead."
echo " e.g. app.DATABASE.tasks.create_task(...) not app.DATABASE.create_task(...)"
echo ""
echo "$violations"
exit 1
fi
echo "No direct delegate calls found."
exit 0
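
For reference, a minimal sketch of how the companion unit test mentioned in the script header might drive it. Only the test's filename appears in this commit; the script path and assertion shape below are assumptions for illustration, not the actual test body.

# Hypothetical sketch of tests/unit/test_no_direct_db_delegates.py.
# "scripts/check_no_direct_db_delegates.sh" is an assumed path; the
# script's location is not shown in this diff.
import subprocess

def test_no_direct_db_delegates() -> None:
    result = subprocess.run(
        ["bash", "scripts/check_no_direct_db_delegates.sh"],
        capture_output=True,
        text=True,
    )
    # The script exits 1 and prints the offending lines on any violation.
    assert result.returncode == 0, f"{result.stdout}\n{result.stderr}"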

View file

@ -3296,13 +3296,13 @@ async def create_or_update_script_block(
block_code_bytes = block_code if isinstance(block_code, bytes) else block_code.encode("utf-8")
try:
# Step 3: Create script block in database
script_block = await app.DATABASE.get_script_block_by_label(
script_block = await app.DATABASE.scripts.get_script_block_by_label(
organization_id=organization_id,
script_revision_id=script_revision_id,
script_block_label=block_label,
)
if not script_block:
script_block = await app.DATABASE.create_script_block(
script_block = await app.DATABASE.scripts.create_script_block(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
@ -3318,7 +3318,7 @@ async def create_or_update_script_block(
for value in [run_signature, workflow_run_id, workflow_run_block_id, input_fields, requires_agent]
):
# Update metadata when new values are provided
script_block = await app.DATABASE.update_script_block(
script_block = await app.DATABASE.scripts.update_script_block(
script_block_id=script_block.script_block_id,
organization_id=organization_id,
run_signature=run_signature,
@ -3336,13 +3336,13 @@ async def create_or_update_script_block(
# Create artifact and upload to S3
artifact_id = None
if update and script_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id,
script_block.script_file_id,
organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
asyncio.create_task(app.STORAGE.store_artifact(artifact, block_code_bytes))
else:
@ -3357,7 +3357,7 @@ async def create_or_update_script_block(
)
# Create script file record
script_file = await app.DATABASE.create_script_file(
script_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
@ -3371,7 +3371,7 @@ async def create_or_update_script_block(
)
# update script block with script file id
await app.DATABASE.update_script_block(
await app.DATABASE.scripts.update_script_block(
script_block_id=script_block.script_block_id,
organization_id=organization_id,
script_file_id=script_file.file_id,

View file

@ -174,7 +174,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
organization_id = context.organization_id if context else None
step_id = context.step_id if context else None
step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
single_click_prompt = prompt_engine.load_prompt(
template="single-click-action",
navigation_goal=intention,
@ -199,7 +199,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
"The element may not exist on the current page."
)
task_id = context.task_id if context else None
task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
if organization_id and task and step:
actions = parse_actions(
task, step.step_id, step.order, self.scraped_page, json_response.get("actions", [])
@ -244,8 +244,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
task_id = context.task_id
step_id = context.step_id
workflow_run_id = context.workflow_run_id
task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
if intention:
try:
@ -373,8 +373,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
task_id = context.task_id
step_id = context.step_id
workflow_run_id = context.workflow_run_id
task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
if intention:
try:
@ -478,8 +478,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
option_value = value or ""
context = skyvern_context.current()
if context and context.task_id and context.step_id and context.organization_id:
task = await app.DATABASE.get_task(context.task_id, organization_id=context.organization_id)
step = await app.DATABASE.get_step(context.step_id, organization_id=context.organization_id)
task = await app.DATABASE.tasks.get_task(context.task_id, organization_id=context.organization_id)
step = await app.DATABASE.tasks.get_step(context.step_id, organization_id=context.organization_id)
if intention and task and step:
try:
prompt = context.prompt if context else None
@ -610,7 +610,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
step = None
organization_id = context.organization_id if context else None
if context and context.organization_id and context.step_id:
step = await app.DATABASE.get_step(
step = await app.DATABASE.tasks.get_step(
step_id=context.step_id,
organization_id=context.organization_id,
)
@ -662,7 +662,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
if not context or not context.organization_id or not context.workflow_permanent_id:
return
block_label = self.current_label or "unknown"
await app.DATABASE.record_branch_hit(
await app.DATABASE.scripts.record_branch_hit(
organization_id=context.organization_id,
workflow_permanent_id=context.workflow_permanent_id,
block_label=block_label,
@ -734,7 +734,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
# Record an element fallback episode for the feedback loop
if context.workflow_run_id and context.workflow_permanent_id:
try:
await app.DATABASE.create_fallback_episode(
await app.DATABASE.scripts.create_fallback_episode(
organization_id=context.organization_id,
workflow_permanent_id=context.workflow_permanent_id,
workflow_run_id=context.workflow_run_id,
@ -792,7 +792,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
)
step = None
if context and context.organization_id and context.task_id and context.step_id:
step = await app.DATABASE.get_step(
step = await app.DATABASE.tasks.get_step(
step_id=context.step_id,
organization_id=context.organization_id,
)
@ -868,7 +868,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
step = None
if context.organization_id and context.task_id and context.step_id:
step = await app.DATABASE.get_step(
step = await app.DATABASE.tasks.get_step(
step_id=context.step_id,
organization_id=context.organization_id,
)
@ -944,8 +944,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
task_id = context.task_id
step_id = context.step_id
task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
if not task or not step:
LOG.warning("ai_act: missing task or step", task_id=task_id, step_id=step_id)

View file

@ -82,7 +82,7 @@ class ScriptSkyvernPage(SkyvernPage):
async def _get_or_create_browser_state(cls, browser_session_id: str | None = None) -> BrowserState:
context = skyvern_context.current()
if context and context.workflow_run_id and context.organization_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=context.workflow_run_id, organization_id=context.organization_id
)
if workflow_run:
@ -101,7 +101,7 @@ class ScriptSkyvernPage(SkyvernPage):
async def _get_browser_state(cls) -> BrowserState | None:
context = skyvern_context.current()
if context and context.workflow_run_id and context.organization_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=context.workflow_run_id, organization_id=context.organization_id
)
if workflow_run:
@ -153,7 +153,7 @@ class ScriptSkyvernPage(SkyvernPage):
organization_id = context.organization_id
download_timeout = BROWSER_DOWNLOAD_TIMEOUT
if context.task_id:
task = await app.DATABASE.get_task(context.task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(context.task_id, organization_id=organization_id)
if task and task.download_timeout:
download_timeout = task.download_timeout
await check_downloading_files_and_wait_for_download_to_complete(
@ -392,7 +392,7 @@ class ScriptSkyvernPage(SkyvernPage):
except Exception:
LOG.warning("Failed to generate action reasoning, using fallback", action_type=action_type)
await app.DATABASE.update_action_reasoning(
await app.DATABASE.workflow_params.update_action_reasoning(
organization_id=organization_id,
action_id=action_id,
reasoning=reasoning,
@ -496,7 +496,7 @@ class ScriptSkyvernPage(SkyvernPage):
created_by="script",
)
created_action = await app.DATABASE.create_action(action)
created_action = await app.DATABASE.workflow_params.create_action(action)
# Skip LLM reasoning in script mode — use static string instead.
# Build a descriptive label from the selector for the timeline.
if context and context.script_mode:
@ -521,7 +521,7 @@ class ScriptSkyvernPage(SkyvernPage):
else:
label = selector[:60]
reasoning = f"Script execution: {label}" if label else "Script execution"
await app.DATABASE.update_action_reasoning(
await app.DATABASE.workflow_params.update_action_reasoning(
organization_id=str(context.organization_id),
action_id=str(created_action.action_id),
reasoning=reasoning,
@ -586,7 +586,7 @@ class ScriptSkyvernPage(SkyvernPage):
screenshot = await browser_state.take_post_action_screenshot(scrolling_number=0)
if screenshot:
step = await app.DATABASE.get_step(
step = await app.DATABASE.tasks.get_step(
context.step_id,
organization_id=context.organization_id,
)
@ -642,7 +642,7 @@ class ScriptSkyvernPage(SkyvernPage):
html = await skyvern_frame.get_content()
if html:
step = await app.DATABASE.get_step(
step = await app.DATABASE.tasks.get_step(
context.step_id,
organization_id=context.organization_id,
)
@ -690,7 +690,7 @@ class ScriptSkyvernPage(SkyvernPage):
screenshot = await browser_state.take_fullpage_screenshot()
if screenshot:
step = await app.DATABASE.get_step(
step = await app.DATABASE.tasks.get_step(
context.step_id,
organization_id=context.organization_id,
)
@ -981,8 +981,8 @@ class ScriptSkyvernPage(SkyvernPage):
await app.AGENT_FUNCTION.auto_solve_captchas(self.page)
return None
task = await app.DATABASE.get_task(context.task_id, context.organization_id)
step = await app.DATABASE.get_step(context.step_id, context.organization_id)
task = await app.DATABASE.tasks.get_task(context.task_id, context.organization_id)
step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
if task and step:
solve_captcha_handler = ActionHandler._handled_action_types[ActionType.SOLVE_CAPTCHA]
action = SolveCaptchaAction(
@ -1014,15 +1014,15 @@ class ScriptSkyvernPage(SkyvernPage):
if context.script_mode:
print(" ⏭ Skipping complete() verification (--no-verify)")
return
task = await app.DATABASE.get_task(context.task_id, context.organization_id)
step = await app.DATABASE.get_step(context.step_id, context.organization_id)
task = await app.DATABASE.tasks.get_task(context.task_id, context.organization_id)
step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
if task and step:
# CRITICAL: Update step.output with actions_and_results BEFORE validation
# This ensures complete_verify() can access action history (including download info)
# when checking if the goal was achieved
await self._update_step_output_before_complete(context)
# Refresh step to get updated output for validation
step = await app.DATABASE.get_step(context.step_id, context.organization_id)
step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
if not step:
return
@ -1062,8 +1062,8 @@ class ScriptSkyvernPage(SkyvernPage):
msg += ": " + "; ".join(errors)
raise ScriptTerminationException(msg)
task = await app.DATABASE.get_task(context.task_id, context.organization_id)
step = await app.DATABASE.get_step(context.step_id, context.organization_id)
task = await app.DATABASE.tasks.get_task(context.task_id, context.organization_id)
step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
if task and step:
action = TerminateAction(
organization_id=context.organization_id,
@ -1116,7 +1116,7 @@ class ScriptSkyvernPage(SkyvernPage):
errors=errors,
)
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
step_id=context.step_id,
task_id=context.task_id,
organization_id=context.organization_id,

View file

@ -71,7 +71,7 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz
workflow_definition_blocks = workflow.workflow_definition.blocks
# get workflow run blocks for task execution data
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
workflow_run_blocks.sort(key=lambda x: x.created_at)
@ -93,7 +93,7 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz
if all_task_ids:
task_ids_list = list(all_task_ids)
# Single query for all tasks
tasks = await app.DATABASE.get_tasks_by_ids(task_ids=task_ids_list, organization_id=organization_id)
tasks = await app.DATABASE.tasks.get_tasks_by_ids(task_ids=task_ids_list, organization_id=organization_id)
tasks_by_id = {task.task_id: task for task in tasks}
LOG.debug(
"Batch fetched tasks for code gen",
@ -102,7 +102,9 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz
)
# Single query for all actions (returns desc order for timeline; reverse for chronological)
all_actions = await app.DATABASE.get_tasks_actions(task_ids=task_ids_list, organization_id=organization_id)
all_actions = await app.DATABASE.tasks.get_tasks_actions(
task_ids=task_ids_list, organization_id=organization_id
)
all_actions.reverse()
for action in all_actions:
if action.task_id:

View file

@ -204,7 +204,7 @@ class ForgeAgent:
LOG.info("No browser state found for workflow run, setting task url to empty string")
task_url = ""
task = await app.DATABASE.create_task(
task = await app.DATABASE.tasks.create_task(
url=task_url,
task_type=task_block.task_type,
complete_criterion=task_block.complete_criterion,
@ -244,13 +244,13 @@ class ForgeAgent:
task_retry=task_retry,
)
# Update task status to running
task = await app.DATABASE.update_task(
task = await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=task.organization_id,
status=TaskStatus.running,
)
step = await app.DATABASE.create_step(
step = await app.DATABASE.tasks.create_step(
task.task_id,
order=0,
retry_index=0,
@ -270,14 +270,14 @@ class ForgeAgent:
totp_verification_url = str(task_request.totp_verification_url) if task_request.totp_verification_url else None
# validate browser session id
if task_request.browser_session_id:
browser_session = await app.DATABASE.get_persistent_browser_session(
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
session_id=task_request.browser_session_id,
organization_id=organization_id,
)
if not browser_session:
raise BrowserSessionNotFound(browser_session_id=task_request.browser_session_id)
task = await app.DATABASE.create_task(
task = await app.DATABASE.tasks.create_task(
url=str(task_request.url),
title=task_request.title,
webhook_callback_url=webhook_callback_url,
@ -345,7 +345,7 @@ class ForgeAgent:
workflow_run: WorkflowRun | None = None
if task.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=task.workflow_run_id,
organization_id=organization.organization_id,
)
@ -381,7 +381,9 @@ class ForgeAgent:
)
return step, None, None
refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id)
refreshed_task = await app.DATABASE.tasks.get_task(
task_id=task.task_id, organization_id=organization.organization_id
)
if refreshed_task:
task = refreshed_task
@ -413,7 +415,7 @@ class ForgeAgent:
or settings.MAX_STEPS_PER_RUN
)
if max_steps_per_run and task.max_steps_per_run != max_steps_per_run:
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization.organization_id,
max_steps_per_run=max_steps_per_run,
@ -938,7 +940,7 @@ class ForgeAgent:
# Only pass new errors — update_task() appends to existing errors
if detected_errors:
new_errors = [error.model_dump() for error in detected_errors]
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=task.organization_id,
errors=new_errors,
@ -1321,7 +1323,7 @@ class ForgeAgent:
action_order=action_idx,
)
detailed_agent_step_output.actions_and_results[action_idx] = (action, [action_result])
action.action_id = (await app.DATABASE.create_action(action=action)).action_id
action.action_id = (await app.DATABASE.workflow_params.create_action(action=action)).action_id
await self.record_artifacts_after_action(task, step, browser_state, engine, action)
break
@ -1570,7 +1572,7 @@ class ForgeAgent:
):
working_page = await browser_state.must_get_working_page()
# refresh task in case the extracted information is updated previously
refreshed_task = await app.DATABASE.get_task(task.task_id, task.organization_id)
refreshed_task = await app.DATABASE.tasks.get_task(task.task_id, task.organization_id)
assert refreshed_task is not None
task = refreshed_task
extract_action = await self.create_extract_action(task, step, scraped_page)
@ -1664,7 +1666,7 @@ class ForgeAgent:
cached_tokens = first_response.usage.input_tokens_details.cached_tokens or 0
reasoning_tokens = first_response.usage.output_tokens_details.reasoning_tokens or 0
llm_cost = (3.0 / 1000000) * input_tokens + (12.0 / 1000000) * output_tokens
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
task_id=task.task_id,
step_id=step.step_id,
organization_id=task.organization_id,
@ -1777,7 +1779,7 @@ class ForgeAgent:
cached_tokens = current_response.usage.input_tokens_details.cached_tokens or 0
reasoning_tokens = current_response.usage.output_tokens_details.reasoning_tokens or 0
llm_cost = (3.0 / 1000000) * input_tokens + (12.0 / 1000000) * output_tokens
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
task_id=task.task_id,
step_id=step.step_id,
organization_id=task.organization_id,
@ -2107,7 +2109,7 @@ class ForgeAgent:
or incremental_reasoning_tokens is not None
or incremental_cached_tokens is not None
):
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
@ -2452,7 +2454,7 @@ class ForgeAgent:
)
else:
try:
await app.DATABASE.update_action_screenshot_artifact_id(
await app.DATABASE.artifacts.update_action_screenshot_artifact_id(
organization_id=action.organization_id,
action_id=action.action_id,
screenshot_artifact_id=screenshot_artifact_id,
@ -3406,7 +3408,7 @@ class ForgeAgent:
Find the last successful ScrapeAction for the task and return the extracted information.
"""
# TODO: make sure we can get extracted information with the ExtractAction change
steps = await app.DATABASE.get_task_steps(
steps = await app.DATABASE.tasks.get_task_steps(
task_id=task.task_id,
organization_id=task.organization_id,
)
@ -3439,7 +3441,7 @@ class ForgeAgent:
Find the TerminateAction for the task and return the reasoning.
# TODO (kerem): Also return meaningful exceptions when we add them [WYV-311]
"""
steps = await app.DATABASE.get_task_steps(
steps = await app.DATABASE.tasks.get_task_steps(
task_id=task.task_id,
organization_id=task.organization_id,
)
@ -3475,7 +3477,9 @@ class ForgeAgent:
"""
# refresh the task from the db to get the latest status
try:
refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
refreshed_task = await app.DATABASE.tasks.get_task(
task_id=task.task_id, organization_id=task.organization_id
)
if not refreshed_task:
LOG.error("Failed to get task from db when clean up task", task_id=task.task_id)
raise TaskNotFound(task_id=task.task_id)
@ -3579,7 +3583,7 @@ class ForgeAgent:
task_id=task.task_id,
)
return
last_step = await app.DATABASE.get_latest_step(task.task_id, organization_id=task.organization_id)
last_step = await app.DATABASE.tasks.get_latest_step(task.task_id, organization_id=task.organization_id)
task_response = await self.build_task_response(task=task, last_step=last_step)
# try to build the new TaskRunResponse for backward compatibility
@ -3622,7 +3626,7 @@ class ForgeAgent:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=task.organization_id,
webhook_failure_reason="",
@ -3635,7 +3639,7 @@ class ForgeAgent:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=task.organization_id,
webhook_failure_reason=f"Webhook failed with status code {resp.status_code}, error message: {resp.text}",
@ -3665,7 +3669,7 @@ class ForgeAgent:
downloaded_files: list[FileInfo] | None = None
# get the artifact of the screenshot and get the screenshot_url
screenshot_artifact = await app.DATABASE.get_artifact(
screenshot_artifact = await app.DATABASE.artifacts.get_artifact(
task_id=task.task_id,
step_id=last_step.step_id,
artifact_type=ArtifactType.SCREENSHOT_FINAL,
@ -3689,9 +3693,11 @@ class ForgeAgent:
LOG.warning("Timeout getting recordings", browser_session_id=task.browser_session_id)
if recording_url is None:
first_step = await app.DATABASE.get_first_step(task_id=task.task_id, organization_id=task.organization_id)
first_step = await app.DATABASE.tasks.get_first_step(
task_id=task.task_id, organization_id=task.organization_id
)
if first_step:
recording_artifact = await app.DATABASE.get_artifact(
recording_artifact = await app.DATABASE.artifacts.get_artifact(
task_id=task.task_id,
step_id=first_step.step_id,
artifact_type=ArtifactType.RECORDING,
@ -3701,7 +3707,7 @@ class ForgeAgent:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
# get the artifact of the last TASK_RESPONSE_ACTION_SCREENSHOT_COUNT screenshots and get the screenshot_url
latest_action_screenshot_artifacts = await app.DATABASE.get_latest_n_artifacts(
latest_action_screenshot_artifacts = await app.DATABASE.artifacts.get_latest_n_artifacts(
task_id=task.task_id,
organization_id=task.organization_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION],
@ -3737,7 +3743,7 @@ class ForgeAgent:
)
if need_browser_log:
browser_console_log = await app.DATABASE.get_latest_artifact(
browser_console_log = await app.DATABASE.artifacts.get_latest_artifact(
task_id=task.task_id,
artifact_types=[ArtifactType.BROWSER_CONSOLE_LOG],
organization_id=task.organization_id,
@ -3746,7 +3752,7 @@ class ForgeAgent:
browser_console_log_url = await app.ARTIFACT_MANAGER.get_share_link(browser_console_log)
# get the latest task from the db to get the latest status, extracted_information, and failure_reason
task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
task_from_db = await app.DATABASE.tasks.get_task(task_id=task.task_id, organization_id=task.organization_id)
if not task_from_db:
LOG.error("Failed to get task from db when sending task response")
raise TaskNotFound(task_id=task.task_id)
@ -3886,7 +3892,7 @@ class ForgeAgent:
await save_step_logs(step.step_id)
return await app.DATABASE.update_step(
return await app.DATABASE.tasks.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
@ -3904,7 +3910,7 @@ class ForgeAgent:
failure_category: list[dict[str, Any]] | None = None,
) -> Task:
# refresh task from db to get the latest status
task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
task_from_db = await app.DATABASE.tasks.get_task(task_id=task.task_id, organization_id=task.organization_id)
if task_from_db:
task = task_from_db
@ -3944,7 +3950,7 @@ class ForgeAgent:
await save_task_logs(task.task_id)
LOG.info("Updating task in db", task_id=task.task_id, diff=update_comparison)
return await app.DATABASE.update_task(
return await app.DATABASE.tasks.update_task(
task.task_id,
organization_id=task.organization_id,
**updates,
@ -3990,7 +3996,7 @@ class ForgeAgent:
name=f"verify_goal_{step.step_id}",
)
next_step = await app.DATABASE.create_step(
next_step = await app.DATABASE.tasks.create_step(
task_id=task.task_id,
order=step.order + 1,
retry_index=0,
@ -4297,7 +4303,7 @@ class ForgeAgent:
step_order=step.order,
step_retry=step.retry_index,
)
next_step = await app.DATABASE.create_step(
next_step = await app.DATABASE.tasks.create_step(
task_id=task.task_id,
organization_id=task.organization_id,
order=step.order,
@ -4316,7 +4322,7 @@ class ForgeAgent:
llm_errors: list[str] = []
try:
steps = await app.DATABASE.get_task_steps(
steps = await app.DATABASE.tasks.get_task_steps(
task_id=task.task_id, organization_id=organization.organization_id
)
for step_cnt, step in enumerate(steps):
@ -4434,7 +4440,7 @@ class ForgeAgent:
steps_without_actions = 0
try:
steps = await app.DATABASE.get_task_steps(
steps = await app.DATABASE.tasks.get_task_steps(
task_id=task.task_id, organization_id=organization.organization_id
)
@ -4722,7 +4728,7 @@ class ForgeAgent:
step_order=step.order,
step_retry=step.retry_index,
)
next_step = await app.DATABASE.create_step(
next_step = await app.DATABASE.tasks.create_step(
task_id=task.task_id,
order=step.order + 1,
retry_index=0,
@ -4843,7 +4849,7 @@ class ForgeAgent:
if not otp_value and (task.totp_verification_url or task.totp_identifier) and task.organization_id:
workflow_id = workflow_permanent_id = None
if task.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(task.workflow_run_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(task.workflow_run_id)
if workflow_run:
workflow_id = workflow_run.workflow_id
workflow_permanent_id = workflow_run.workflow_permanent_id
@ -4891,7 +4897,7 @@ class ForgeAgent:
@staticmethod
async def get_task_errors(task: Task) -> list[UserDefinedError]:
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
steps = await app.DATABASE.tasks.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
errors = []
for step in steps:
if step.output and step.output.errors:
@ -4907,7 +4913,7 @@ class ForgeAgent:
step_errors = detailed_step_output.extract_errors() or []
task_errors.extend([error.model_dump() for error in step_errors])
return await app.DATABASE.update_task(
return await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=task.organization_id,
errors=task_errors,
@ -4962,7 +4968,7 @@ class ForgeAgent:
"""
Run the extraction flow when a task with a data extraction goal completes during parallel verification.
"""
refreshed_task = await app.DATABASE.get_task(task.task_id, task.organization_id)
refreshed_task = await app.DATABASE.tasks.get_task(task.task_id, task.organization_id)
if refreshed_task:
task = refreshed_task

View file

@ -452,7 +452,7 @@ class AgentFunction:
if not has_valid_step_status:
reasons.append(f"invalid_step_status:{step.status}")
# can't execute if the task has another step that is running
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
steps = await app.DATABASE.tasks.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
if not has_no_running_steps:
reasons.append(f"another_step_is_running_for_task:{task.task_id}")

View file

@ -836,7 +836,7 @@ class LLMAPIHandlerFactory:
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
if step and not is_speculative_step:
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
@ -847,7 +847,7 @@ class LLMAPIHandlerFactory:
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
)
if thought:
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
@ -1302,7 +1302,7 @@ class LLMAPIHandlerFactory:
_log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens)
if step and not is_speculative_step:
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
@ -1313,7 +1313,7 @@ class LLMAPIHandlerFactory:
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
)
if thought:
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
@ -1704,7 +1704,7 @@ class LLMCaller:
call_stats = await self.get_call_stats(response)
if step and not is_speculative_step:
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
@ -1715,7 +1715,7 @@ class LLMCaller:
incremental_cached_tokens=call_stats.cached_tokens,
)
if thought:
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=thought.organization_id,
input_token_count=call_stats.input_tokens,

View file

@ -189,7 +189,7 @@ class ArtifactManager:
if not workflow_run_block_id and context:
workflow_run_block_id = context.parent_workflow_run_block_id
artifact = await app.DATABASE.create_artifact(
artifact = await app.DATABASE.artifacts.create_artifact(
artifact_id,
artifact_type,
uri,
@ -521,7 +521,7 @@ class ArtifactManager:
artifact_models = [artifact_data.artifact_model for artifact_data in request.artifacts]
# Bulk insert artifacts
artifacts = await app.DATABASE.bulk_create_artifacts(artifact_models)
artifacts = await app.DATABASE.artifacts.bulk_create_artifacts(artifact_models)
# Fire and forget upload tasks
for artifact, artifact_data in zip(artifacts, request.artifacts):
@ -823,7 +823,7 @@ class ArtifactManager:
) -> None:
if not artifact_id or not organization_id:
return None
artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(artifact_id, organization_id)
if not artifact:
return
# Fire and forget
@ -1178,12 +1178,12 @@ class ArtifactManager:
)
for artifact_type, filename, artifact_id in accumulator.member_types
]
await app.DATABASE.bulk_create_artifacts([parent_model, *member_models])
await app.DATABASE.artifacts.bulk_create_artifacts([parent_model, *member_models])
# Apply deferred action.screenshot_artifact_id updates now that artifact rows exist.
for organization_id, action_id, artifact_id in accumulator.pending_action_screenshot_updates:
try:
await app.DATABASE.update_action_screenshot_artifact_id(
await app.DATABASE.artifacts.update_action_screenshot_artifact_id(
organization_id=organization_id,
action_id=action_id,
screenshot_artifact_id=artifact_id,
@ -1267,7 +1267,7 @@ class ArtifactManager:
)
for filename, (artifact_type, _) in entries.items()
]
await app.DATABASE.bulk_create_artifacts([parent_model, *member_models])
await app.DATABASE.artifacts.bulk_create_artifacts([parent_model, *member_models])
async def wait_for_upload_aiotasks(self, primary_keys: list[str]) -> None:
try:

View file

@ -14,6 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from skyvern.config import settings
from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError # noqa: F401
from skyvern.forge.sdk.db.models import PersistentBrowserSessionModel
from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository
from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository
from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
@ -169,9 +170,6 @@ class AgentDB(BaseAlchemyDB):
# ======================================================================
# Backward-compatible delegate methods
# TODO(SKY-62): These delegates erase type information (*args: Any -> Any).
# Migrate callers to use typed repository attributes directly
# (e.g., db.tasks.get_task(...) instead of db.get_task(...)), then remove.
# ======================================================================
# -- Task delegates --
@ -529,17 +527,17 @@ class AgentDB(BaseAlchemyDB):
async def close_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
return await self.browser_sessions.close_persistent_browser_session(*args, **kwargs)
async def get_all_active_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
return await self.browser_sessions.get_all_active_persistent_browser_sessions(*args, **kwargs)
async def get_all_active_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
return await self.browser_sessions.get_all_active_persistent_browser_sessions()
async def archive_browser_session_address(self, *args: Any, **kwargs: Any) -> Any:
return await self.browser_sessions.archive_browser_session_address(*args, **kwargs)
async def get_uncompleted_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
return await self.browser_sessions.get_uncompleted_persistent_browser_sessions(*args, **kwargs)
async def get_uncompleted_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
return await self.browser_sessions.get_uncompleted_persistent_browser_sessions()
async def get_debug_session_by_browser_session_id(self, *args: Any, **kwargs: Any) -> Any:
return await self.browser_sessions.get_debug_session_by_browser_session_id(*args, **kwargs)
return await self.debug.get_debug_session_by_browser_session_id(*args, **kwargs)
# -- Schedule delegates --
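
To make the motivation behind the TODO(SKY-62) comment concrete: a minimal sketch, with stand-in types rather than the real agent_db.py signatures, of why the *args: Any delegates defeat type checking while the typed repository attributes do not.

from typing import Any

class Task: ...

class TasksRepository:
    async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
        ...  # stand-in body; the real repository queries the tasks table

class AgentDB:
    def __init__(self, tasks: TasksRepository) -> None:
        self.tasks = tasks  # typed repository attribute, visible to callers

    async def get_task(self, *args: Any, **kwargs: Any) -> Any:
        # legacy delegate: mypy sees (Any, Any) -> Any, so misspelled
        # keywords and wrong argument types only fail at runtime
        return await self.tasks.get_task(*args, **kwargs)

Under this sketch, db.get_task(task_id=123) type-checks but fails at runtime, while db.tasks.get_task(task_id=123) is rejected by mypy because int is not str.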

View file

@ -1,39 +0,0 @@
# DEPRECATED: These mixin classes have been superseded by the repository pattern
# in skyvern/forge/sdk/db/repositories/. AgentDB no longer inherits from these
# mixins — it composes repository instances instead. These files are retained
# temporarily to avoid breaking any downstream code that may import from here.
# TODO(SKY-62): Remove after 2026-Q2 once all imports are verified migrated.
from skyvern.forge.sdk.db.mixins.artifacts import ArtifactsMixin
from skyvern.forge.sdk.db.mixins.base import BaseAlchemyDB, read_retry
from skyvern.forge.sdk.db.mixins.browser_sessions import BrowserSessionsMixin
from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin
from skyvern.forge.sdk.db.mixins.debug import DebugMixin
from skyvern.forge.sdk.db.mixins.folders import FoldersMixin
from skyvern.forge.sdk.db.mixins.observer import ObserverMixin
from skyvern.forge.sdk.db.mixins.organizations import OrganizationsMixin
from skyvern.forge.sdk.db.mixins.otp import OTPMixin
from skyvern.forge.sdk.db.mixins.schedules import SchedulesMixin
from skyvern.forge.sdk.db.mixins.scripts import ScriptsMixin
from skyvern.forge.sdk.db.mixins.tasks import TasksMixin
from skyvern.forge.sdk.db.mixins.workflow_parameters import WorkflowParametersMixin
from skyvern.forge.sdk.db.mixins.workflow_runs import WorkflowRunsMixin
from skyvern.forge.sdk.db.mixins.workflows import WorkflowsMixin
__all__ = [
"ArtifactsMixin",
"BaseAlchemyDB",
"BrowserSessionsMixin",
"CredentialsMixin",
"DebugMixin",
"FoldersMixin",
"ObserverMixin",
"OrganizationsMixin",
"OTPMixin",
"SchedulesMixin",
"ScriptsMixin",
"TasksMixin",
"WorkflowParametersMixin",
"WorkflowRunsMixin",
"WorkflowsMixin",
"read_retry",
]
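
A minimal sketch of the inheritance-to-composition change this note describes, using the repository imports added to agent_db.py in this commit; the TasksRepository module path and the constructor wiring are assumptions, not code from this diff.

from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB
from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository
from skyvern.forge.sdk.db.repositories.tasks import TasksRepository  # assumed module path

# Before: class AgentDB(TasksMixin, ArtifactsMixin, ..., BaseAlchemyDB): ...
# After: AgentDB composes one typed repository instance per domain.
class AgentDB(BaseAlchemyDB):
    def __init__(self, database_string: str) -> None:  # assumed constructor signature
        super().__init__(database_string)
        self.tasks = TasksRepository(self.Session)
        self.artifacts = ArtifactsRepository(self.Session)
        # ... one attribute per remaining domain repository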

View file

@ -1,452 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import and_, delete, or_, select, update
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import ActionModel, ArtifactModel
from skyvern.forge.sdk.db.utils import convert_to_artifact
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
from skyvern.forge.sdk.schemas.runs import Run
import structlog
LOG = structlog.get_logger()
class ArtifactsMixin:
"""Database operations for artifact management."""
Session: _SessionFactory
debug_enabled: bool
async def get_run(self, run_id: str, organization_id: str | None = None) -> Run | None:
raise NotImplementedError
@db_operation("create_artifact")
async def create_artifact(
self,
artifact_id: str,
artifact_type: str,
uri: str,
organization_id: str,
step_id: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
task_v2_id: str | None = None,
run_id: str | None = None,
thought_id: str | None = None,
ai_suggestion_id: str | None = None,
) -> Artifact:
async with self.Session() as session:
new_artifact = ArtifactModel(
artifact_id=artifact_id,
artifact_type=artifact_type,
uri=uri,
task_id=task_id,
step_id=step_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_cruise_id=task_v2_id,
observer_thought_id=thought_id,
run_id=run_id,
ai_suggestion_id=ai_suggestion_id,
organization_id=organization_id,
)
session.add(new_artifact)
await session.commit()
await session.refresh(new_artifact)
return convert_to_artifact(new_artifact, self.debug_enabled)
@db_operation("bulk_create_artifacts")
async def bulk_create_artifacts(
self,
artifact_models: list[ArtifactModel],
) -> list[Artifact]:
"""
Bulk create multiple artifacts in a single database transaction.
Args:
artifact_models: List of ArtifactModel instances to insert
Returns:
List of created Artifact objects
"""
if not artifact_models:
return []
async with self.Session() as session:
session.add_all(artifact_models)
await session.commit()
# Refresh all artifacts to get their created_at and modified_at values
for artifact in artifact_models:
await session.refresh(artifact)
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifact_models]
@db_operation("get_artifacts_for_task_v2")
async def get_artifacts_for_task_v2(
self,
task_v2_id: str,
organization_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
) -> list[Artifact]:
async with self.Session() as session:
query = (
select(ArtifactModel)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
if artifact_types:
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
query = query.order_by(ArtifactModel.created_at)
if artifacts := (await session.scalars(query)).all():
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
else:
return []
@db_operation("get_artifacts_for_task_step")
async def get_artifacts_for_task_step(
self,
task_id: str,
step_id: str,
organization_id: str | None = None,
) -> list[Artifact]:
async with self.Session() as session:
if artifacts := (
await session.scalars(
select(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.order_by(ArtifactModel.created_at)
)
).all():
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
else:
return []
@db_operation("get_artifacts_for_run")
async def get_artifacts_for_run(
self,
run_id: str,
organization_id: str,
artifact_types: list[ArtifactType] | None = None,
group_by_type: bool = False,
sort_by: str = "created_at",
) -> dict[ArtifactType, list[Artifact]] | list[Artifact]:
"""Return artifacts associated with a run.
Args:
run_id: The ID of the run to get artifacts for
organization_id: The ID of the organization that owns the run
artifact_types: Optional list of artifact types to filter by
group_by_type: If True, returns a dictionary mapping artifact types to lists of artifacts.
If False, returns a flat list of artifacts. Defaults to False.
sort_by: Field to sort artifacts by. Must be one of: 'created_at', 'step_id', 'task_id'.
Defaults to 'created_at'.
Returns:
If group_by_type is True, returns a dictionary mapping artifact types to lists of artifacts.
If group_by_type is False, returns a list of artifacts sorted by the specified field.
Raises:
ValueError: If sort_by is not one of the allowed values
"""
allowed_sort_fields = {"created_at", "step_id", "task_id"}
if sort_by not in allowed_sort_fields:
raise ValueError(f"sort_by must be one of {allowed_sort_fields}")
run = await self.get_run(run_id, organization_id=organization_id)
async with self.Session() as session:
query = select(ArtifactModel).filter_by(organization_id=organization_id)
if run:
# Workflow run — filter by workflow_run_id
query = query.filter_by(workflow_run_id=run.workflow_run_id)
elif run_id.startswith("tsk_"):
# Task run — get_run only handles workflow runs,
# so fall back to filtering by task_id for task-based artifacts
query = query.filter_by(task_id=run_id)
else:
return []
if artifact_types:
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
# Apply sorting
if sort_by == "created_at":
query = query.order_by(ArtifactModel.created_at)
elif sort_by == "step_id":
query = query.order_by(ArtifactModel.step_id, ArtifactModel.created_at)
elif sort_by == "task_id":
query = query.order_by(ArtifactModel.task_id, ArtifactModel.created_at)
# Execute query and convert to Artifact objects
artifacts = [
convert_to_artifact(artifact, self.debug_enabled) for artifact in (await session.scalars(query)).all()
]
# Group artifacts by type if requested
if group_by_type:
result: dict[ArtifactType, list[Artifact]] = {}
for artifact in artifacts:
if artifact.artifact_type not in result:
result[artifact.artifact_type] = []
result[artifact.artifact_type].append(artifact)
return result
return artifacts
@db_operation("get_artifact_by_id")
async def get_artifact_by_id(
self,
artifact_id: str,
organization_id: str,
) -> Artifact | None:
async with self.Session() as session:
if artifact := (
await session.scalars(
select(ArtifactModel).filter_by(artifact_id=artifact_id).filter_by(organization_id=organization_id)
)
).first():
return convert_to_artifact(artifact, self.debug_enabled)
else:
return None
@db_operation("get_artifact_by_id_no_org")
async def get_artifact_by_id_no_org(
self,
artifact_id: str,
) -> Artifact | None:
"""Fetch an artifact by ID without an organization filter.
Only use this when the caller has already verified authorization through
an out-of-band mechanism (e.g. a valid HMAC-signed URL).
"""
async with self.Session() as session:
if artifact := (await session.scalars(select(ArtifactModel).filter_by(artifact_id=artifact_id))).first():
return convert_to_artifact(artifact, self.debug_enabled)
else:
return None
@db_operation("get_artifacts_by_ids")
async def get_artifacts_by_ids(
self,
artifact_ids: list[str],
organization_id: str,
) -> list[Artifact]:
if not artifact_ids:
return []
async with self.Session() as session:
artifacts = (
await session.scalars(
select(ArtifactModel)
.filter(ArtifactModel.artifact_id.in_(artifact_ids))
.filter_by(organization_id=organization_id)
)
).all()
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
@db_operation("get_artifacts_by_entity_id")
async def get_artifacts_by_entity_id(
self,
*,
organization_id: str | None,
artifact_type: ArtifactType | None = None,
task_id: str | None = None,
step_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
thought_id: str | None = None,
task_v2_id: str | None = None,
limit: int | None = None,
) -> list[Artifact]:
async with self.Session() as session:
# Build base query
query = select(ArtifactModel)
if artifact_type is not None:
query = query.filter_by(artifact_type=artifact_type)
if task_id is not None:
query = query.filter_by(task_id=task_id)
if step_id is not None:
query = query.filter_by(step_id=step_id)
if workflow_run_id is not None:
query = query.filter_by(workflow_run_id=workflow_run_id)
if workflow_run_block_id is not None:
query = query.filter_by(workflow_run_block_id=workflow_run_block_id)
if thought_id is not None:
query = query.filter_by(observer_thought_id=thought_id)
if task_v2_id is not None:
query = query.filter_by(observer_cruise_id=task_v2_id)
# Handle backward compatibility where old artifact rows were stored with organization_id NULL
if organization_id is not None:
query = query.filter(
or_(ArtifactModel.organization_id == organization_id, ArtifactModel.organization_id.is_(None))
)
query = query.order_by(ArtifactModel.created_at.desc())
if limit is not None:
query = query.limit(limit)
artifacts = (await session.scalars(query)).all()
LOG.debug("Artifacts fetched", count=len(artifacts))
return [convert_to_artifact(a, self.debug_enabled) for a in artifacts]
@db_operation("get_artifact_by_entity_id")
async def get_artifact_by_entity_id(
self,
*,
artifact_type: ArtifactType,
organization_id: str,
task_id: str | None = None,
step_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
thought_id: str | None = None,
task_v2_id: str | None = None,
) -> Artifact | None:
artifacts = await self.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=artifact_type,
task_id=task_id,
step_id=step_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
thought_id=thought_id,
task_v2_id=task_v2_id,
limit=1,
)
return artifacts[0] if artifacts else None
@db_operation("get_artifact")
async def get_artifact(
self,
task_id: str,
step_id: str,
artifact_type: ArtifactType,
organization_id: str | None = None,
) -> Artifact | None:
async with self.Session() as session:
artifact = (
await session.scalars(
select(ArtifactModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
.filter_by(artifact_type=artifact_type)
.order_by(ArtifactModel.created_at.desc())
)
).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
@db_operation("get_artifact_for_run")
async def get_artifact_for_run(
self,
run_id: str,
artifact_type: ArtifactType,
organization_id: str | None = None,
) -> Artifact | None:
async with self.Session() as session:
artifact = (
await session.scalars(
select(ArtifactModel)
.filter(ArtifactModel.run_id == run_id)
.filter(ArtifactModel.artifact_type == artifact_type)
.filter(ArtifactModel.organization_id == organization_id)
.order_by(ArtifactModel.created_at.desc())
)
).first()
if artifact:
return convert_to_artifact(artifact, self.debug_enabled)
return None
@db_operation("get_latest_artifact")
async def get_latest_artifact(
self,
task_id: str,
step_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
organization_id: str | None = None,
) -> Artifact | None:
artifacts = await self.get_latest_n_artifacts(
task_id=task_id,
step_id=step_id,
artifact_types=artifact_types,
organization_id=organization_id,
n=1,
)
if artifacts:
return artifacts[0]
return None
@db_operation("get_latest_n_artifacts")
async def get_latest_n_artifacts(
self,
task_id: str,
step_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
organization_id: str | None = None,
n: int = 1,
) -> list[Artifact] | None:
async with self.Session() as session:
artifact_query = select(ArtifactModel).filter_by(task_id=task_id)
if organization_id:
artifact_query = artifact_query.filter_by(organization_id=organization_id)
if step_id:
artifact_query = artifact_query.filter_by(step_id=step_id)
if artifact_types:
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
artifacts = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).fetchmany(n)
if artifacts:
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
return None
@db_operation("delete_task_artifacts")
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
async with self.Session() as session:
# delete artifacts by filtering organization_id and task_id
stmt = delete(ArtifactModel).where(
and_(
ArtifactModel.organization_id == organization_id,
ArtifactModel.task_id == task_id,
)
)
await session.execute(stmt)
await session.commit()
@db_operation("delete_task_v2_artifacts")
async def delete_task_v2_artifacts(self, task_v2_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(ArtifactModel).where(
and_(
ArtifactModel.observer_cruise_id == task_v2_id,
ArtifactModel.organization_id == organization_id,
)
)
await session.execute(stmt)
await session.commit()
@db_operation("update_action_screenshot_artifact_id")
async def update_action_screenshot_artifact_id(
self, *, organization_id: str, action_id: str, screenshot_artifact_id: str
) -> None:
async with self.Session() as session:
await session.execute(
update(ActionModel)
.where(ActionModel.action_id == action_id, ActionModel.organization_id == organization_id)
.values(screenshot_artifact_id=screenshot_artifact_id)
)
await session.commit()
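
The methods deleted here now live in ArtifactsRepository, imported in agent_db.py earlier in this commit. Since repositories/artifacts.py is not shown in this diff, the following is a hedged sketch of its likely shape, reusing the imports and the get_artifact_by_id body from the mixin above; the constructor signature is an assumption.

from sqlalchemy import select

from skyvern.forge.sdk.artifact.models import Artifact
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import ArtifactModel
from skyvern.forge.sdk.db.utils import convert_to_artifact

class ArtifactsRepository:
    """Sketch of the typed home for the artifact operations deleted above."""

    def __init__(self, session_factory, debug_enabled: bool = False) -> None:  # assumed signature
        self.Session = session_factory
        self.debug_enabled = debug_enabled

    @db_operation("get_artifact_by_id")
    async def get_artifact_by_id(self, artifact_id: str, organization_id: str) -> Artifact | None:
        # Body carried over from the mixin version, minus the mixin scaffolding.
        async with self.Session() as session:
            row = (
                await session.scalars(
                    select(ArtifactModel)
                    .filter_by(artifact_id=artifact_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            return convert_to_artifact(row, self.debug_enabled) if row else None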

View file

@ -1,12 +0,0 @@
"""Shared dependencies for mixin modules.
This module is the single import point for base classes and utilities that
sibling mixin modules need (e.g. ``BaseAlchemyDB``, ``read_retry``).
Centralising the import here means every mixin can do
``from .base import BaseAlchemyDB, read_retry`` instead of reaching up to
the parent package, keeping coupling explicit and consistent.
"""
from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB, read_retry
__all__ = ["BaseAlchemyDB", "read_retry"]

View file

@ -1,518 +0,0 @@
from __future__ import annotations
import uuid
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from sqlalchemy import asc, case, select
from skyvern.exceptions import BrowserProfileNotFound
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
BrowserProfileModel,
DebugSessionModel,
PersistentBrowserSessionModel,
WorkflowRunModel,
)
from skyvern.forge.sdk.db.utils import convert_to_workflow_run, serialize_proxy_location
from skyvern.forge.sdk.schemas.browser_profiles import BrowserProfile
from skyvern.forge.sdk.schemas.debug_sessions import DebugSession
from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
Extensions,
PersistentBrowserSession,
PersistentBrowserType,
)
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
import structlog
LOG = structlog.get_logger()
class BrowserSessionsMixin:
"""Database operations for browser profiles and persistent browser sessions.
.. deprecated::
This mixin is part of the legacy database layer. New code should use the
repository classes in ``skyvern.forge.sdk.db.repositories`` instead.
Cross-mixin migrations already completed:
- ``get_last_workflow_run_for_browser_session`` → ``WorkflowRunsRepository``
(queries workflow runs as the primary entity, browser session is just a filter).
"""
Session: _SessionFactory
@db_operation("get_last_workflow_run_for_browser_session")
async def get_last_workflow_run_for_browser_session(
self,
browser_session_id: str,
organization_id: str | None = None,
) -> WorkflowRun | None:
# Deprecated: moved to WorkflowRunsRepository.get_last_workflow_run_for_browser_session
# (skyvern/forge/sdk/db/repositories/workflow_runs.py). The primary entity is the
# workflow run, not the browser session. This copy remains for legacy mixin compatibility.
async with self.Session() as session:
# check if there's a queued run
query = select(WorkflowRunModel).filter_by(browser_session_id=browser_session_id)
if organization_id:
query = query.filter_by(organization_id=organization_id)
queue_query = query.filter_by(status=WorkflowRunStatus.queued)
queue_query = queue_query.order_by(WorkflowRunModel.modified_at.desc())
workflow_run = (await session.scalars(queue_query)).first()
if workflow_run:
return convert_to_workflow_run(workflow_run)
# check if there's a running run
running_query = query.filter_by(status=WorkflowRunStatus.running)
running_query = running_query.filter(WorkflowRunModel.started_at.isnot(None))
running_query = running_query.order_by(WorkflowRunModel.started_at.desc())
workflow_run = (await session.scalars(running_query)).first()
if workflow_run:
return convert_to_workflow_run(workflow_run)
return None
@db_operation("create_browser_profile")
async def create_browser_profile(
self,
organization_id: str,
name: str,
description: str | None = None,
) -> BrowserProfile:
async with self.Session() as session:
browser_profile = BrowserProfileModel(
organization_id=organization_id,
name=name,
description=description,
)
session.add(browser_profile)
await session.commit()
await session.refresh(browser_profile)
return BrowserProfile.model_validate(browser_profile)
@db_operation("get_browser_profile")
async def get_browser_profile(
self,
profile_id: str,
organization_id: str,
include_deleted: bool = False,
) -> BrowserProfile | None:
async with self.Session() as session:
query = (
select(BrowserProfileModel)
.filter_by(browser_profile_id=profile_id)
.filter_by(organization_id=organization_id)
)
if not include_deleted:
query = query.filter(BrowserProfileModel.deleted_at.is_(None))
browser_profile = (await session.scalars(query)).first()
if not browser_profile:
return None
return BrowserProfile.model_validate(browser_profile)
@db_operation("list_browser_profiles")
async def list_browser_profiles(
self,
organization_id: str,
include_deleted: bool = False,
) -> list[BrowserProfile]:
async with self.Session() as session:
query = select(BrowserProfileModel).filter_by(organization_id=organization_id)
if not include_deleted:
query = query.filter(BrowserProfileModel.deleted_at.is_(None))
browser_profiles = await session.scalars(query.order_by(asc(BrowserProfileModel.created_at)))
return [BrowserProfile.model_validate(profile) for profile in browser_profiles.all()]
@db_operation("delete_browser_profile")
async def delete_browser_profile(
self,
profile_id: str,
organization_id: str,
) -> None:
async with self.Session() as session:
query = (
select(BrowserProfileModel)
.filter_by(browser_profile_id=profile_id)
.filter_by(organization_id=organization_id)
.filter(BrowserProfileModel.deleted_at.is_(None))
)
browser_profile = (await session.scalars(query)).first()
if not browser_profile:
raise BrowserProfileNotFound(profile_id=profile_id, organization_id=organization_id)
browser_profile.deleted_at = datetime.utcnow()
await session.commit()
@db_operation("get_active_persistent_browser_sessions")
async def get_active_persistent_browser_sessions(
self,
organization_id: str,
active_hours: int = 24,
) -> list[PersistentBrowserSession]:
"""Get all active persistent browser sessions for an organization."""
async with self.Session() as session:
result = await session.execute(
select(PersistentBrowserSessionModel)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
.filter_by(completed_at=None)
.filter(PersistentBrowserSessionModel.created_at > datetime.utcnow() - timedelta(hours=active_hours))
)
sessions = result.scalars().all()
return [PersistentBrowserSession.model_validate(session) for session in sessions]
@db_operation("get_persistent_browser_sessions_history")
async def get_persistent_browser_sessions_history(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
lookback_hours: int = 24 * 7,
) -> list[PersistentBrowserSession]:
"""Get persistent browser sessions history for an organization."""
async with self.Session() as session:
open_first = case(
(
PersistentBrowserSessionModel.status == "running",
0, # open
),
else_=1, # not open
)
result = await session.execute(
select(PersistentBrowserSessionModel)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
.filter(
PersistentBrowserSessionModel.created_at > (datetime.utcnow() - timedelta(hours=lookback_hours))
)
.order_by(
open_first.asc(), # open sessions first
PersistentBrowserSessionModel.created_at.desc(), # then newest within each group
)
.offset((page - 1) * page_size)
.limit(page_size)
)
sessions = result.scalars().all()
return [PersistentBrowserSession.model_validate(session) for session in sessions]
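# On the stacked decorators below: read_retry() retries the lookup on transient
# read failures (per its name; defined in mixins.base), while log_errors=False
# keeps the common not-found case out of the error logs.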
@read_retry()
@db_operation("get_persistent_browser_session_by_runnable_id", log_errors=False)
async def get_persistent_browser_session_by_runnable_id(
self, runnable_id: str, organization_id: str | None = None
) -> PersistentBrowserSession | None:
"""Get a specific persistent browser session."""
async with self.Session() as session:
query = (
select(PersistentBrowserSessionModel)
.filter_by(runnable_id=runnable_id)
.filter_by(deleted_at=None)
.filter_by(completed_at=None)
)
if organization_id:
query = query.filter_by(organization_id=organization_id)
persistent_browser_session = (await session.scalars(query)).first()
if persistent_browser_session:
return PersistentBrowserSession.model_validate(persistent_browser_session)
return None
@db_operation("get_persistent_browser_session")
async def get_persistent_browser_session(
self,
session_id: str,
organization_id: str | None = None,
) -> PersistentBrowserSession | None:
"""Get a specific persistent browser session."""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if persistent_browser_session:
return PersistentBrowserSession.model_validate(persistent_browser_session)
return None
@db_operation("create_persistent_browser_session")
async def create_persistent_browser_session(
self,
organization_id: str,
runnable_type: str | None = None,
runnable_id: str | None = None,
timeout_minutes: int | None = None,
proxy_location: ProxyLocationInput = ProxyLocation.RESIDENTIAL,
extensions: list[Extensions] | None = None,
browser_type: PersistentBrowserType | None = None,
browser_profile_id: str | None = None,
) -> PersistentBrowserSession:
"""Create a new persistent browser session."""
extensions_str: list[str] | None = (
[extension.value for extension in extensions] if extensions is not None else None
)
async with self.Session() as session:
browser_session = PersistentBrowserSessionModel(
organization_id=organization_id,
runnable_type=runnable_type,
runnable_id=runnable_id,
timeout_minutes=timeout_minutes,
proxy_location=serialize_proxy_location(proxy_location),
extensions=extensions_str,
browser_type=browser_type.value if browser_type else None,
browser_profile_id=browser_profile_id,
)
session.add(browser_session)
await session.commit()
await session.refresh(browser_session)
return PersistentBrowserSession.model_validate(browser_session)
@db_operation("update_persistent_browser_session")
async def update_persistent_browser_session(
self,
browser_session_id: str,
*,
status: str | None = None,
timeout_minutes: int | None = None,
organization_id: str | None = None,
completed_at: datetime | None = None,
started_at: datetime | None = None,
) -> PersistentBrowserSession:
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=browser_session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if not persistent_browser_session:
raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")
if status:
persistent_browser_session.status = status
if timeout_minutes:
persistent_browser_session.timeout_minutes = timeout_minutes
if completed_at:
persistent_browser_session.completed_at = completed_at
if started_at:
persistent_browser_session.started_at = started_at
await session.commit()
await session.refresh(persistent_browser_session)
return PersistentBrowserSession.model_validate(persistent_browser_session)
@db_operation("set_persistent_browser_session_browser_address")
async def set_persistent_browser_session_browser_address(
self,
browser_session_id: str,
browser_address: str | None,
ip_address: str | None,
ecs_task_arn: str | None,
organization_id: str | None = None,
) -> None:
"""Set the browser address for a persistent browser session."""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=browser_session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if persistent_browser_session:
if browser_address:
persistent_browser_session.browser_address = browser_address
# once the address is set, the session is started
persistent_browser_session.started_at = datetime.utcnow()
if ip_address:
persistent_browser_session.ip_address = ip_address
if ecs_task_arn:
persistent_browser_session.ecs_task_arn = ecs_task_arn
await session.commit()
await session.refresh(persistent_browser_session)
else:
raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")
@db_operation("update_persistent_browser_session_compute_cost")
async def update_persistent_browser_session_compute_cost(
self,
session_id: str,
organization_id: str,
instance_type: str,
vcpu_millicores: int,
memory_mb: int,
duration_ms: int,
compute_cost: float,
) -> None:
"""Update the compute cost fields for a persistent browser session"""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if persistent_browser_session:
persistent_browser_session.instance_type = instance_type
persistent_browser_session.vcpu_millicores = vcpu_millicores
persistent_browser_session.memory_mb = memory_mb
persistent_browser_session.duration_ms = duration_ms
persistent_browser_session.compute_cost = compute_cost
await session.commit()
await session.refresh(persistent_browser_session)
else:
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
@db_operation("mark_persistent_browser_session_deleted")
async def mark_persistent_browser_session_deleted(self, session_id: str, organization_id: str) -> None:
"""Mark a persistent browser session as deleted."""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
)
).first()
if persistent_browser_session:
persistent_browser_session.deleted_at = datetime.utcnow()
await session.commit()
await session.refresh(persistent_browser_session)
else:
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
@db_operation("occupy_persistent_browser_session")
async def occupy_persistent_browser_session(
self, session_id: str, runnable_type: str, runnable_id: str, organization_id: str
) -> None:
"""Occupy a specific persistent browser session."""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if persistent_browser_session:
persistent_browser_session.runnable_type = runnable_type
persistent_browser_session.runnable_id = runnable_id
await session.commit()
await session.refresh(persistent_browser_session)
else:
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
@db_operation("release_persistent_browser_session")
async def release_persistent_browser_session(
self,
session_id: str,
organization_id: str,
) -> PersistentBrowserSession:
"""Release a specific persistent browser session."""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if persistent_browser_session:
persistent_browser_session.runnable_type = None
persistent_browser_session.runnable_id = None
await session.commit()
await session.refresh(persistent_browser_session)
return PersistentBrowserSession.model_validate(persistent_browser_session)
else:
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
@db_operation("close_persistent_browser_session")
async def close_persistent_browser_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession:
"""Close a specific persistent browser session."""
async with self.Session() as session:
persistent_browser_session = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if persistent_browser_session:
if persistent_browser_session.completed_at:
return PersistentBrowserSession.model_validate(persistent_browser_session)
persistent_browser_session.completed_at = datetime.utcnow()
persistent_browser_session.status = "completed"
await session.commit()
await session.refresh(persistent_browser_session)
return PersistentBrowserSession.model_validate(persistent_browser_session)
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
@db_operation("get_all_active_persistent_browser_sessions")
async def get_all_active_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
"""Get all active persistent browser sessions across all organizations."""
async with self.Session() as session:
result = await session.execute(select(PersistentBrowserSessionModel).filter_by(deleted_at=None))
return list(result.scalars().all())
@db_operation("archive_browser_session_address")
async def archive_browser_session_address(self, session_id: str, organization_id: str) -> None:
"""Suffix browser_address with a unique tag so the unique constraint
no longer blocks new sessions that reuse the same local address."""
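# e.g. a stored address like "127.0.0.1:9222" (illustrative) becomes
# "127.0.0.1:9222::closed::<uuid4 hex>" once archived.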
async with self.Session() as session:
row = (
await session.scalars(
select(PersistentBrowserSessionModel)
.filter_by(persistent_browser_session_id=session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if not row or not row.browser_address:
return
if "::closed::" in row.browser_address:
return
row.browser_address = f"{row.browser_address}::closed::{uuid.uuid4().hex}"
await session.commit()
@db_operation("get_uncompleted_persistent_browser_sessions")
async def get_uncompleted_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
"""Get all browser sessions that have not been completed or deleted."""
async with self.Session() as session:
result = await session.execute(
select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
)
return list(result.scalars().all())
@db_operation("get_debug_session_by_browser_session_id")
async def get_debug_session_by_browser_session_id(
self,
browser_session_id: str,
organization_id: str,
) -> DebugSession | None:
async with self.Session() as session:
query = (
select(DebugSessionModel)
.filter_by(browser_session_id=browser_session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
model = (await session.scalars(query)).first()
return DebugSession.model_validate(model) if model else None

View file

@ -1,218 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import select
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import CredentialModel, OrganizationBitwardenCollectionModel
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType, CredentialVaultType
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
class CredentialsMixin:
"""Database operations for credential and Bitwarden collection management."""
Session: _SessionFactory
@db_operation("create_credential")
async def create_credential(
self,
organization_id: str,
name: str,
vault_type: CredentialVaultType,
item_id: str,
credential_type: CredentialType,
username: str | None,
totp_type: str,
card_last4: str | None,
card_brand: str | None,
totp_identifier: str | None = None,
secret_label: str | None = None,
) -> Credential:
async with self.Session() as session:
credential = CredentialModel(
organization_id=organization_id,
name=name,
vault_type=vault_type,
item_id=item_id,
credential_type=credential_type,
username=username,
totp_type=totp_type,
totp_identifier=totp_identifier,
card_last4=card_last4,
card_brand=card_brand,
secret_label=secret_label,
)
session.add(credential)
await session.commit()
await session.refresh(credential)
return Credential.model_validate(credential)
@db_operation("get_credential")
async def get_credential(self, credential_id: str, organization_id: str) -> Credential | None:
async with self.Session() as session:
credential = (
await session.scalars(
select(CredentialModel)
.filter_by(credential_id=credential_id)
.filter_by(organization_id=organization_id)
.filter(CredentialModel.deleted_at.is_(None))
)
).first()
if credential:
return Credential.model_validate(credential)
return None
@db_operation("get_credentials")
async def get_credentials(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
vault_type: str | None = None,
) -> list[Credential]:
async with self.Session() as session:
query = (
select(CredentialModel)
.filter_by(organization_id=organization_id)
.filter(CredentialModel.deleted_at.is_(None))
)
if vault_type is not None:
query = query.filter(CredentialModel.vault_type == vault_type)
credentials = (
await session.scalars(
query.order_by(CredentialModel.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
)
).all()
return [Credential.model_validate(credential) for credential in credentials]
@db_operation("update_credential")
async def update_credential(
self,
credential_id: str,
organization_id: str,
name: str | None = None,
browser_profile_id: str | None = None,
tested_url: str | None = None,
user_context: str | None = None,
save_browser_session_intent: bool | None = None,
) -> Credential:
async with self.Session() as session:
credential = (
await session.scalars(
select(CredentialModel)
.filter_by(credential_id=credential_id)
.filter_by(organization_id=organization_id)
.filter(CredentialModel.deleted_at.is_(None))
)
).first()
if not credential:
raise NotFoundError(f"Credential {credential_id} not found")
if name is not None:
credential.name = name
if browser_profile_id is not None:
credential.browser_profile_id = browser_profile_id
if tested_url is not None:
credential.tested_url = tested_url
if user_context is not None:
credential.user_context = user_context
if save_browser_session_intent is not None:
credential.save_browser_session_intent = save_browser_session_intent
await session.commit()
await session.refresh(credential)
return Credential.model_validate(credential)
@db_operation("update_credential_vault_data")
async def update_credential_vault_data(
self,
credential_id: str,
organization_id: str,
item_id: str,
name: str,
credential_type: CredentialType,
username: str | None = None,
totp_type: str = "none",
totp_identifier: str | None = None,
card_last4: str | None = None,
card_brand: str | None = None,
secret_label: str | None = None,
) -> Credential:
async with self.Session() as session:
credential = (
await session.scalars(
select(CredentialModel)
.filter_by(credential_id=credential_id)
.filter_by(organization_id=organization_id)
.filter(CredentialModel.deleted_at.is_(None))
.with_for_update()
)
).first()
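# with_for_update() emits SELECT ... FOR UPDATE, so concurrent vault-data
# rewrites of the same credential row serialize on a row lock instead of
# interleaving.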
if not credential:
raise NotFoundError(f"Credential {credential_id} not found")
credential.item_id = item_id
credential.name = name
credential.credential_type = credential_type
credential.username = username
credential.totp_type = totp_type
credential.totp_identifier = totp_identifier
credential.card_last4 = card_last4
credential.card_brand = card_brand
credential.secret_label = secret_label
await session.commit()
await session.refresh(credential)
return Credential.model_validate(credential)
@db_operation("delete_credential")
async def delete_credential(self, credential_id: str, organization_id: str) -> None:
async with self.Session() as session:
credential = (
await session.scalars(
select(CredentialModel)
.filter_by(credential_id=credential_id)
.filter_by(organization_id=organization_id)
)
).first()
if not credential:
raise NotFoundError(f"Credential {credential_id} not found")
credential.deleted_at = datetime.utcnow()
await session.commit()
await session.refresh(credential)
return None
@db_operation("create_organization_bitwarden_collection")
async def create_organization_bitwarden_collection(
self,
organization_id: str,
collection_id: str,
) -> OrganizationBitwardenCollection:
async with self.Session() as session:
organization_bitwarden_collection = OrganizationBitwardenCollectionModel(
organization_id=organization_id, collection_id=collection_id
)
session.add(organization_bitwarden_collection)
await session.commit()
await session.refresh(organization_bitwarden_collection)
return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection)
@db_operation("get_organization_bitwarden_collection")
async def get_organization_bitwarden_collection(
self,
organization_id: str,
) -> OrganizationBitwardenCollection | None:
async with self.Session() as session:
organization_bitwarden_collection = (
await session.scalars(
select(OrganizationBitwardenCollectionModel)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
).first()
if organization_bitwarden_collection:
return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection)
return None

View file

@ -1,259 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import select
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import (
BlockRunModel,
DebugSessionModel,
WorkflowRunModel,
)
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession, DebugSessionRun
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
class DebugMixin:
"""Database operations for debug sessions and block runs."""
Session: _SessionFactory
@db_operation("get_debug_session")
async def get_debug_session(
self,
*,
organization_id: str,
user_id: str,
workflow_permanent_id: str,
) -> DebugSession | None:
async with self.Session() as session:
debug_session = (
await session.scalars(
select(DebugSessionModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter_by(user_id=user_id)
.filter_by(deleted_at=None)
.filter_by(status="created")
.order_by(DebugSessionModel.created_at.desc())
)
).first()
if not debug_session:
return None
return DebugSession.model_validate(debug_session)
@db_operation("get_latest_block_run")
async def get_latest_block_run(
self,
*,
organization_id: str,
user_id: str,
block_label: str,
) -> BlockRun | None:
async with self.Session() as session:
query = (
select(BlockRunModel)
.filter_by(organization_id=organization_id)
.filter_by(user_id=user_id)
.filter_by(block_label=block_label)
.order_by(BlockRunModel.created_at.desc())
)
model = (await session.scalars(query)).first()
return BlockRun.model_validate(model) if model else None
@db_operation("get_latest_completed_block_run")
async def get_latest_completed_block_run(
self,
*,
organization_id: str,
user_id: str,
block_label: str,
workflow_permanent_id: str,
) -> BlockRun | None:
async with self.Session() as session:
query = (
select(BlockRunModel)
.join(WorkflowRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.filter(BlockRunModel.organization_id == organization_id)
.filter(BlockRunModel.user_id == user_id)
.filter(BlockRunModel.block_label == block_label)
.filter(WorkflowRunModel.status == WorkflowRunStatus.completed)
.filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id)
.order_by(BlockRunModel.created_at.desc())
)
model = (await session.scalars(query)).first()
return BlockRun.model_validate(model) if model else None
@db_operation("create_block_run")
async def create_block_run(
self,
*,
organization_id: str,
user_id: str,
block_label: str,
output_parameter_id: str,
workflow_run_id: str,
) -> None:
async with self.Session() as session:
block_run = BlockRunModel(
organization_id=organization_id,
user_id=user_id,
block_label=block_label,
output_parameter_id=output_parameter_id,
workflow_run_id=workflow_run_id,
)
session.add(block_run)
await session.commit()
@db_operation("get_latest_debug_session_for_user")
async def get_latest_debug_session_for_user(
self,
*,
organization_id: str,
user_id: str,
workflow_permanent_id: str,
) -> DebugSession | None:
async with self.Session() as session:
query = (
select(DebugSessionModel)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
.filter_by(status="created")
.filter_by(user_id=user_id)
.filter_by(workflow_permanent_id=workflow_permanent_id)
.order_by(DebugSessionModel.created_at.desc())
)
model = (await session.scalars(query)).first()
return DebugSession.model_validate(model) if model else None
@db_operation("get_debug_session_by_id")
async def get_debug_session_by_id(
self,
debug_session_id: str,
organization_id: str,
) -> DebugSession | None:
async with self.Session() as session:
query = (
select(DebugSessionModel)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
.filter_by(debug_session_id=debug_session_id)
)
model = (await session.scalars(query)).first()
return DebugSession.model_validate(model) if model else None
@db_operation("get_workflow_runs_by_debug_session_id")
async def get_workflow_runs_by_debug_session_id(
self,
debug_session_id: str,
organization_id: str,
) -> list[DebugSessionRun]:
async with self.Session() as session:
query = (
select(WorkflowRunModel, BlockRunModel)
.join(BlockRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.filter(WorkflowRunModel.organization_id == organization_id)
.filter(WorkflowRunModel.debug_session_id == debug_session_id)
.order_by(WorkflowRunModel.created_at.desc())
)
results = (await session.execute(query)).all()
debug_session_runs = []
for workflow_run, block_run in results:
debug_session_runs.append(
DebugSessionRun(
ai_fallback=workflow_run.ai_fallback,
block_label=block_run.block_label,
browser_session_id=workflow_run.browser_session_id,
code_gen=workflow_run.code_gen,
debug_session_id=workflow_run.debug_session_id,
failure_reason=workflow_run.failure_reason,
output_parameter_id=block_run.output_parameter_id,
run_with=workflow_run.run_with,
script_run_id=workflow_run.script_run.get("script_run_id") if workflow_run.script_run else None,
status=workflow_run.status,
workflow_id=workflow_run.workflow_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
workflow_run_id=workflow_run.workflow_run_id,
created_at=workflow_run.created_at,
queued_at=workflow_run.queued_at,
started_at=workflow_run.started_at,
finished_at=workflow_run.finished_at,
)
)
return debug_session_runs
@db_operation("complete_debug_sessions")
async def complete_debug_sessions(
self,
*,
organization_id: str,
user_id: str | None = None,
workflow_permanent_id: str | None = None,
) -> list[DebugSession]:
async with self.Session() as session:
query = (
select(DebugSessionModel)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
.filter_by(status="created")
)
if user_id:
query = query.filter_by(user_id=user_id)
if workflow_permanent_id:
query = query.filter_by(workflow_permanent_id=workflow_permanent_id)
models = (await session.scalars(query)).all()
for model in models:
model.status = "completed"
debug_sessions = [DebugSession.model_validate(model) for model in models]
await session.commit()
return debug_sessions
@db_operation("create_debug_session")
async def create_debug_session(
self,
*,
browser_session_id: str,
organization_id: str,
user_id: str,
workflow_permanent_id: str,
vnc_streaming_supported: bool,
) -> DebugSession:
async with self.Session() as session:
debug_session = DebugSessionModel(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
user_id=user_id,
browser_session_id=browser_session_id,
vnc_streaming_supported=vnc_streaming_supported,
status="created",
)
session.add(debug_session)
await session.commit()
await session.refresh(debug_session)
return DebugSession.model_validate(debug_session)

View file

@ -1,368 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import func, or_, select, update
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import FolderModel, WorkflowModel
from skyvern.forge.sdk.db.utils import convert_to_workflow
from skyvern.forge.sdk.workflow.models.workflow import Workflow
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
class FoldersMixin:
"""Database operations for folder management."""
Session: _SessionFactory
debug_enabled: bool
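# Cross-mixin method stub: the concrete implementation is supplied by another
# mixin at runtime (same pattern as the stubs documented in ObserverMixin).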
async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
ignore_version: int | None = None,
filter_deleted: bool = True,
) -> Workflow | None:
raise NotImplementedError
@db_operation("create_folder")
async def create_folder(
self,
organization_id: str,
title: str,
description: str | None = None,
) -> FolderModel:
"""Create a new folder."""
async with self.Session() as session:
folder = FolderModel(
organization_id=organization_id,
title=title,
description=description,
)
session.add(folder)
await session.commit()
await session.refresh(folder)
return folder
@db_operation("get_folders")
async def get_folders(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
search_query: str | None = None,
) -> list[FolderModel]:
"""Get all folders for an organization with pagination and optional search."""
async with self.Session() as session:
stmt = (
select(FolderModel).filter_by(organization_id=organization_id).filter(FolderModel.deleted_at.is_(None))
)
if search_query:
search_pattern = f"%{search_query}%"
stmt = stmt.filter(
or_(
FolderModel.title.ilike(search_pattern),
FolderModel.description.ilike(search_pattern),
)
)
stmt = stmt.order_by(FolderModel.modified_at.desc())
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
result = await session.execute(stmt)
return list(result.scalars().all())
@db_operation("get_folder")
async def get_folder(
self,
folder_id: str,
organization_id: str,
) -> FolderModel | None:
"""Get a folder by ID."""
async with self.Session() as session:
stmt = (
select(FolderModel)
.filter_by(folder_id=folder_id, organization_id=organization_id)
.filter(FolderModel.deleted_at.is_(None))
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
@db_operation("update_folder")
async def update_folder(
self,
folder_id: str,
organization_id: str,
title: str | None = None,
description: str | None = None,
) -> FolderModel | None:
"""Update a folder's title or description."""
async with self.Session() as session:
stmt = (
select(FolderModel)
.filter_by(folder_id=folder_id, organization_id=organization_id)
.filter(FolderModel.deleted_at.is_(None))
)
result = await session.execute(stmt)
folder = result.scalar_one_or_none()
if not folder:
return None
if title is not None:
folder.title = title
if description is not None:
folder.description = description
folder.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(folder)
return folder
@db_operation("get_workflow_permanent_ids_in_folder")
async def get_workflow_permanent_ids_in_folder(
self,
folder_id: str,
organization_id: str,
) -> list[str]:
"""Get workflow permanent IDs (latest versions only) in a folder."""
async with self.Session() as session:
# Subquery to get the latest version for each workflow
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
# Get workflow_permanent_ids where the latest version is in this folder
stmt = (
select(WorkflowModel.workflow_permanent_id)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.where(WorkflowModel.folder_id == folder_id)
)
result = await session.execute(stmt)
return list(result.scalars().all())
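# Equivalent SQL sketch for the latest-version join above (table and bind names
# are illustrative, not the actual schema):
#   SELECT w.workflow_permanent_id
#   FROM workflows w
#   JOIN (
#       SELECT organization_id, workflow_permanent_id, MAX(version) AS max_version
#       FROM workflows
#       WHERE organization_id = :org AND deleted_at IS NULL
#       GROUP BY organization_id, workflow_permanent_id
#   ) latest
#     ON w.organization_id = latest.organization_id
#    AND w.workflow_permanent_id = latest.workflow_permanent_id
#    AND w.version = latest.max_version
#   WHERE w.folder_id = :folder_id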
@db_operation("soft_delete_folder")
async def soft_delete_folder(
self,
folder_id: str,
organization_id: str,
delete_workflows: bool = False,
) -> bool:
"""Soft delete a folder. Optionally delete all workflows in the folder."""
async with self.Session() as session:
# Check if folder exists
folder_stmt = (
select(FolderModel)
.filter_by(folder_id=folder_id, organization_id=organization_id)
.filter(FolderModel.deleted_at.is_(None))
)
folder_result = await session.execute(folder_stmt)
folder = folder_result.scalar_one_or_none()
if not folder:
return False
# If delete_workflows is True, delete all workflows in the folder
if delete_workflows:
# Get workflow permanent IDs in the folder (inline logic)
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
workflow_permanent_ids_stmt = (
select(WorkflowModel.workflow_permanent_id)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.where(WorkflowModel.folder_id == folder_id)
)
result = await session.execute(workflow_permanent_ids_stmt)
workflow_permanent_ids = list(result.scalars().all())
# Soft delete all workflows with these permanent IDs in a single bulk update
if workflow_permanent_ids:
update_workflows_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.values(deleted_at=datetime.utcnow())
)
await session.execute(update_workflows_query)
else:
# Just remove folder_id from all workflows in this folder
update_workflows_query = (
update(WorkflowModel)
.where(WorkflowModel.folder_id == folder_id)
.where(WorkflowModel.organization_id == organization_id)
.values(folder_id=None, modified_at=datetime.utcnow())
)
await session.execute(update_workflows_query)
# Soft delete the folder
folder.deleted_at = datetime.utcnow()
await session.commit()
return True
@db_operation("get_folder_workflow_count")
async def get_folder_workflow_count(
self,
folder_id: str,
organization_id: str,
) -> int:
"""Get the count of workflows (latest versions only) in a folder."""
async with self.Session() as session:
# Subquery to get the latest version for each workflow (same pattern as get_workflows_by_organization_id)
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
# Count workflows where the latest version is in this folder
stmt = (
select(func.count(WorkflowModel.workflow_permanent_id))
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.where(WorkflowModel.folder_id == folder_id)
)
result = await session.execute(stmt)
return result.scalar_one()
@db_operation("get_folder_workflow_counts_batch")
async def get_folder_workflow_counts_batch(
self,
folder_ids: list[str],
organization_id: str,
) -> dict[str, int]:
"""Get workflow counts for multiple folders in a single query."""
async with self.Session() as session:
# Subquery to get the latest version for each workflow
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
# Count workflows grouped by folder_id
stmt = (
select(
WorkflowModel.folder_id,
func.count(WorkflowModel.workflow_permanent_id).label("count"),
)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.where(WorkflowModel.folder_id.in_(folder_ids))
.group_by(WorkflowModel.folder_id)
)
result = await session.execute(stmt)
rows = result.all()
# Convert to dict; folders with no workflows will be absent from the result
return {row.folder_id: row.count for row in rows}
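# Usage sketch (hypothetical caller): folders with no workflows are simply
# absent from the returned map, so read with a default:
#   counts = await db.get_folder_workflow_counts_batch(folder_ids, organization_id)
#   n = counts.get(folder_id, 0)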
@db_operation("update_workflow_folder")
async def update_workflow_folder(
self,
workflow_permanent_id: str,
organization_id: str,
folder_id: str | None,
) -> Workflow | None:
"""Update folder assignment for the latest version of a workflow."""
# Get the latest version of the workflow
latest_workflow = await self.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
)
if not latest_workflow:
return None
async with self.Session() as session:
# Validate folder exists in-org if folder_id is provided
if folder_id:
stmt = (
select(FolderModel.folder_id)
.where(FolderModel.folder_id == folder_id)
.where(FolderModel.organization_id == organization_id)
.where(FolderModel.deleted_at.is_(None))
)
if (await session.scalar(stmt)) is None:
raise ValueError(f"Folder {folder_id} not found")
workflow_model = await session.get(WorkflowModel, latest_workflow.workflow_id)
if workflow_model:
workflow_model.folder_id = folder_id
workflow_model.modified_at = datetime.utcnow()
# Update folder's modified_at in the same transaction
if folder_id:
folder_model = await session.get(FolderModel, folder_id)
if folder_model:
folder_model.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow_model)
return convert_to_workflow(workflow_model, self.debug_enabled)
return None

View file

@ -1,574 +0,0 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any
from sqlalchemy import and_, delete, select
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
TaskV2Model,
ThoughtModel,
WorkflowRunBlockModel,
)
from skyvern.forge.sdk.db.utils import (
convert_to_task_v2,
convert_to_workflow_run_block,
serialize_proxy_location,
)
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
from skyvern.schemas.runs import ProxyLocationInput, RunEngine
from skyvern.schemas.workflows import BlockStatus, BlockType
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
class ObserverMixin:
"""Database operations for observer tasks (TaskV2), thoughts, and workflow run blocks."""
Session: _SessionFactory
debug_enabled: bool
# Cross-mixin method stubs (provided by TasksMixin at runtime)
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
raise NotImplementedError
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
raise NotImplementedError
@read_retry()
@db_operation("get_task_v2", log_errors=False)
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
async with self.Session() as session:
if task_v2 := (
await session.scalars(
select(TaskV2Model)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
).first():
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
return None
@db_operation("delete_thoughts")
async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(ThoughtModel).where(
and_(
ThoughtModel.observer_cruise_id == task_v2_id,
ThoughtModel.organization_id == organization_id,
)
)
await session.execute(stmt)
await session.commit()
@db_operation("get_task_v2_by_workflow_run_id")
async def get_task_v2_by_workflow_run_id(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> TaskV2 | None:
async with self.Session() as session:
if task_v2 := (
await session.scalars(
select(TaskV2Model)
.filter_by(organization_id=organization_id)
.filter_by(workflow_run_id=workflow_run_id)
)
).first():
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
return None
@db_operation("get_thought")
async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None:
async with self.Session() as session:
if thought := (
await session.scalars(
select(ThoughtModel)
.filter_by(observer_thought_id=thought_id)
.filter_by(organization_id=organization_id)
)
).first():
return Thought.model_validate(thought)
return None
@db_operation("get_thoughts")
async def get_thoughts(
self,
*,
task_v2_id: str,
thought_types: list[ThoughtType],
organization_id: str,
) -> list[Thought]:
async with self.Session() as session:
query = (
select(ThoughtModel)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
.order_by(ThoughtModel.created_at)
)
if thought_types:
query = query.filter(ThoughtModel.observer_thought_type.in_(thought_types))
thoughts = (await session.scalars(query)).all()
return [Thought.model_validate(thought) for thought in thoughts]
@db_operation("create_task_v2")
async def create_task_v2(
self,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
prompt: str | None = None,
url: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocationInput = None,
totp_identifier: str | None = None,
totp_verification_url: str | None = None,
webhook_callback_url: str | None = None,
extracted_information_schema: dict | list | str | None = None,
error_code_mapping: dict | None = None,
model: dict[str, Any] | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
browser_address: str | None = None,
run_with: str | None = None,
) -> TaskV2:
async with self.Session() as session:
new_task_v2 = TaskV2Model(
workflow_run_id=workflow_run_id,
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
prompt=prompt,
url=url,
proxy_location=serialize_proxy_location(proxy_location),
totp_identifier=totp_identifier,
totp_verification_url=totp_verification_url,
webhook_callback_url=webhook_callback_url,
extracted_information_schema=extracted_information_schema,
error_code_mapping=error_code_mapping,
organization_id=organization_id,
model=model,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
browser_address=browser_address,
run_with=run_with,
)
session.add(new_task_v2)
await session.commit()
await session.refresh(new_task_v2)
return convert_to_task_v2(new_task_v2, debug_enabled=self.debug_enabled)
@db_operation("create_thought")
async def create_thought(
self,
task_v2_id: str,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
workflow_run_block_id: str | None = None,
user_input: str | None = None,
observation: str | None = None,
thought: str | None = None,
answer: str | None = None,
thought_scenario: str | None = None,
thought_type: str = ThoughtType.plan,
output: dict[str, Any] | None = None,
input_token_count: int | None = None,
output_token_count: int | None = None,
reasoning_token_count: int | None = None,
cached_token_count: int | None = None,
thought_cost: float | None = None,
organization_id: str | None = None,
) -> Thought:
async with self.Session() as session:
new_thought = ThoughtModel(
observer_cruise_id=task_v2_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_block_id=workflow_run_block_id,
user_input=user_input,
observation=observation,
thought=thought,
answer=answer,
observer_thought_scenario=thought_scenario,
observer_thought_type=thought_type,
output=output,
input_token_count=input_token_count,
output_token_count=output_token_count,
reasoning_token_count=reasoning_token_count,
cached_token_count=cached_token_count,
thought_cost=thought_cost,
organization_id=organization_id,
)
session.add(new_thought)
await session.commit()
await session.refresh(new_thought)
return Thought.model_validate(new_thought)
@db_operation("update_thought")
async def update_thought(
self,
thought_id: str,
workflow_run_block_id: str | None = None,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
observation: str | None = None,
thought: str | None = None,
answer: str | None = None,
output: dict[str, Any] | None = None,
input_token_count: int | None = None,
output_token_count: int | None = None,
reasoning_token_count: int | None = None,
cached_token_count: int | None = None,
thought_cost: float | None = None,
organization_id: str | None = None,
) -> Thought:
async with self.Session() as session:
thought_obj = (
await session.scalars(
select(ThoughtModel)
.filter_by(observer_thought_id=thought_id)
.filter_by(organization_id=organization_id)
)
).first()
if thought_obj:
if workflow_run_block_id:
thought_obj.workflow_run_block_id = workflow_run_block_id
if workflow_run_id:
thought_obj.workflow_run_id = workflow_run_id
if workflow_id:
thought_obj.workflow_id = workflow_id
if workflow_permanent_id:
thought_obj.workflow_permanent_id = workflow_permanent_id
if observation:
thought_obj.observation = observation
if thought:
thought_obj.thought = thought
if answer:
thought_obj.answer = answer
if output:
thought_obj.output = output
if input_token_count:
thought_obj.input_token_count = input_token_count
if output_token_count:
thought_obj.output_token_count = output_token_count
if reasoning_token_count:
thought_obj.reasoning_token_count = reasoning_token_count
if cached_token_count:
thought_obj.cached_token_count = cached_token_count
if thought_cost:
thought_obj.thought_cost = thought_cost
await session.commit()
await session.refresh(thought_obj)
return Thought.model_validate(thought_obj)
raise NotFoundError(f"Thought {thought_id}")
@db_operation("update_task_v2")
async def update_task_v2(
self,
task_v2_id: str,
status: TaskV2Status | None = None,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
url: str | None = None,
prompt: str | None = None,
summary: str | None = None,
output: dict[str, Any] | None = None,
organization_id: str | None = None,
webhook_failure_reason: str | None = None,
failure_category: list[dict[str, Any]] | None = None,
) -> TaskV2:
async with self.Session() as session:
task_v2 = (
await session.scalars(
select(TaskV2Model)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
).first()
if task_v2:
if status:
task_v2.status = status
if status == TaskV2Status.queued and task_v2.queued_at is None:
task_v2.queued_at = datetime.utcnow()
if status == TaskV2Status.running and task_v2.started_at is None:
task_v2.started_at = datetime.utcnow()
if status.is_final() and task_v2.finished_at is None:
task_v2.finished_at = datetime.utcnow()
if workflow_run_id:
task_v2.workflow_run_id = workflow_run_id
if workflow_id:
task_v2.workflow_id = workflow_id
if workflow_permanent_id:
task_v2.workflow_permanent_id = workflow_permanent_id
if url:
task_v2.url = url
if prompt:
task_v2.prompt = prompt
if summary:
task_v2.summary = summary
if output:
task_v2.output = output
if webhook_failure_reason is not None:
task_v2.webhook_failure_reason = webhook_failure_reason
if failure_category is not None:
task_v2.failure_category = failure_category
await session.commit()
await session.refresh(task_v2)
return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
raise NotFoundError(f"TaskV2 {task_v2_id} not found")
@db_operation("create_workflow_run_block")
async def create_workflow_run_block(
self,
workflow_run_id: str,
parent_workflow_run_block_id: str | None = None,
organization_id: str | None = None,
task_id: str | None = None,
label: str | None = None,
block_type: BlockType | None = None,
status: BlockStatus = BlockStatus.running,
output: dict | list | str | None = None,
continue_on_failure: bool = False,
engine: RunEngine | None = None,
current_value: str | None = None,
current_index: int | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
new_workflow_run_block = WorkflowRunBlockModel(
workflow_run_id=workflow_run_id,
parent_workflow_run_block_id=parent_workflow_run_block_id,
organization_id=organization_id,
task_id=task_id,
label=label,
block_type=block_type,
status=status,
output=output,
continue_on_failure=continue_on_failure,
engine=engine,
current_value=current_value,
current_index=current_index,
)
session.add(new_workflow_run_block)
await session.commit()
await session.refresh(new_workflow_run_block)
task = None
if task_id:
task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(new_workflow_run_block, task=task)
@db_operation("delete_workflow_run_blocks")
async def delete_workflow_run_blocks(self, workflow_run_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(WorkflowRunBlockModel).where(
and_(
WorkflowRunBlockModel.workflow_run_id == workflow_run_id,
WorkflowRunBlockModel.organization_id == organization_id,
)
)
await session.execute(stmt)
await session.commit()
@db_operation("update_workflow_run_block")
async def update_workflow_run_block(
self,
workflow_run_block_id: str,
organization_id: str | None = None,
status: BlockStatus | None = None,
output: dict | list | str | None = None,
failure_reason: str | None = None,
task_id: str | None = None,
loop_values: list | None = None,
current_value: str | None = None,
current_index: int | None = None,
recipients: list[str] | None = None,
attachments: list[str] | None = None,
subject: str | None = None,
body: str | None = None,
prompt: str | None = None,
wait_sec: int | None = None,
description: str | None = None,
block_workflow_run_id: str | None = None,
engine: str | None = None,
# HTTP request block parameters
http_request_method: str | None = None,
http_request_url: str | None = None,
http_request_headers: dict[str, str] | None = None,
http_request_body: dict[str, Any] | None = None,
http_request_parameters: dict[str, Any] | None = None,
http_request_timeout: int | None = None,
http_request_follow_redirects: bool | None = None,
ai_fallback_triggered: bool | None = None,
# block-level error codes (e.g. ["FILE_PARSER_ERROR"])
error_codes: list[str] | None = None,
# human interaction block
instructions: str | None = None,
positive_descriptor: str | None = None,
negative_descriptor: str | None = None,
# conditional block
executed_branch_id: str | None = None,
executed_branch_expression: str | None = None,
executed_branch_result: bool | None = None,
executed_branch_next_block: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
await session.scalars(
select(WorkflowRunBlockModel)
.filter_by(workflow_run_block_id=workflow_run_block_id)
.filter_by(organization_id=organization_id)
)
).first()
if workflow_run_block:
if status:
workflow_run_block.status = status
if output:
workflow_run_block.output = output
if task_id:
workflow_run_block.task_id = task_id
if failure_reason:
workflow_run_block.failure_reason = failure_reason
# Use `is not None` instead of truthiness checks so that falsy
# values like current_index=0, empty loop_values=[], or
# current_value="" are correctly persisted. Without this,
# the first loop iteration (index 0) loses its metadata.
if loop_values is not None:
workflow_run_block.loop_values = loop_values
if current_value is not None:
workflow_run_block.current_value = current_value
if current_index is not None:
workflow_run_block.current_index = current_index
if recipients:
workflow_run_block.recipients = recipients
if attachments:
workflow_run_block.attachments = attachments
if subject:
workflow_run_block.subject = subject
if body:
workflow_run_block.body = body
if prompt:
workflow_run_block.prompt = prompt
if wait_sec:
workflow_run_block.wait_sec = wait_sec
if description:
workflow_run_block.description = description
if block_workflow_run_id:
workflow_run_block.block_workflow_run_id = block_workflow_run_id
if engine:
workflow_run_block.engine = engine
# HTTP request block fields
if http_request_method:
workflow_run_block.http_request_method = http_request_method
if http_request_url:
workflow_run_block.http_request_url = http_request_url
if http_request_headers:
workflow_run_block.http_request_headers = http_request_headers
if http_request_body:
workflow_run_block.http_request_body = http_request_body
if http_request_parameters:
workflow_run_block.http_request_parameters = http_request_parameters
if http_request_timeout:
workflow_run_block.http_request_timeout = http_request_timeout
if http_request_follow_redirects is not None:
workflow_run_block.http_request_follow_redirects = http_request_follow_redirects
if ai_fallback_triggered is not None:
workflow_run_block.script_run = {"ai_fallback_triggered": ai_fallback_triggered}
if error_codes is not None:
workflow_run_block.error_codes = error_codes
# human interaction block fields
if instructions:
workflow_run_block.instructions = instructions
if positive_descriptor:
workflow_run_block.positive_descriptor = positive_descriptor
if negative_descriptor:
workflow_run_block.negative_descriptor = negative_descriptor
# conditional block fields
if executed_branch_id:
workflow_run_block.executed_branch_id = executed_branch_id
if executed_branch_expression is not None:
workflow_run_block.executed_branch_expression = executed_branch_expression
if executed_branch_result is not None:
workflow_run_block.executed_branch_result = executed_branch_result
if executed_branch_next_block is not None:
workflow_run_block.executed_branch_next_block = executed_branch_next_block
await session.commit()
await session.refresh(workflow_run_block)
else:
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
task = None
task_id = workflow_run_block.task_id
if task_id:
task = await self.get_task(task_id, organization_id=workflow_run_block.organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task)
@db_operation("get_workflow_run_block")
async def get_workflow_run_block(
self,
workflow_run_block_id: str,
organization_id: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
await session.scalars(
select(WorkflowRunBlockModel)
.filter_by(workflow_run_block_id=workflow_run_block_id)
.filter_by(organization_id=organization_id)
)
).first()
if workflow_run_block:
task = None
task_id = workflow_run_block.task_id
if task_id:
task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
@db_operation("get_workflow_run_block_by_task_id")
async def get_workflow_run_block_by_task_id(
self,
task_id: str,
organization_id: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
await session.scalars(
select(WorkflowRunBlockModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first()
if workflow_run_block:
task = None
task_id = workflow_run_block.task_id
if task_id:
task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock not found by {task_id}")
@db_operation("get_workflow_run_blocks")
async def get_workflow_run_blocks(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRunBlock]:
async with self.Session() as session:
workflow_run_blocks = (
await session.scalars(
select(WorkflowRunBlockModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(organization_id=organization_id)
.order_by(WorkflowRunBlockModel.created_at.desc())
)
).all()
tasks = await self.get_tasks_by_workflow_run_id(workflow_run_id)
tasks_dict = {task.task_id: task for task in tasks}
return [
convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id))
for workflow_run_block in workflow_run_blocks
]

View file

@ -1,380 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Literal, overload
from sqlalchemy import select, update
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
OrganizationAuthTokenModel,
OrganizationModel,
TaskModel,
WorkflowRunModel,
)
from skyvern.forge.sdk.db.utils import (
convert_to_organization,
convert_to_organization_auth_token,
)
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.schemas.organizations import (
AzureClientSecretCredential,
AzureOrganizationAuthToken,
BitwardenCredential,
BitwardenOrganizationAuthToken,
Organization,
OrganizationAuthToken,
)
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
class OrganizationsMixin:
    """Database operations for organization and auth-token management."""

    Session: _SessionFactory
@read_retry()
@db_operation("get_active_verification_requests", log_errors=False)
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
"""Return active 2FA verification requests for an organization.
Queries both tasks and workflow runs where waiting_for_verification_code=True.
Used to provide initial state when a WebSocket notification client connects.
"""
results: list[dict] = []
async with self.Session() as session:
# Tasks waiting for verification (exclude finalized tasks)
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
task_rows = (
await session.scalars(
select(TaskModel)
.filter_by(organization_id=organization_id)
.filter_by(waiting_for_verification_code=True)
.filter_by(workflow_run_id=None)
.filter(TaskModel.status.not_in(finalized_task_statuses))
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
)
).all()
for t in task_rows:
results.append(
{
"task_id": t.task_id,
"workflow_run_id": None,
"verification_code_identifier": t.verification_code_identifier,
"verification_code_polling_started_at": (
t.verification_code_polling_started_at.isoformat()
if t.verification_code_polling_started_at
else None
),
}
)
# Workflow runs waiting for verification (exclude finalized runs)
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
wr_rows = (
await session.scalars(
select(WorkflowRunModel)
.filter_by(organization_id=organization_id)
.filter_by(waiting_for_verification_code=True)
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
)
).all()
for wr in wr_rows:
results.append(
{
"task_id": None,
"workflow_run_id": wr.workflow_run_id,
"verification_code_identifier": wr.verification_code_identifier,
"verification_code_polling_started_at": (
wr.verification_code_polling_started_at.isoformat()
if wr.verification_code_polling_started_at
else None
),
}
)
return results
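    # Usage sketch (hedged): a notification WebSocket handler could seed its
    # initial client state from this query. The handler and variable names
    # below are illustrative, not part of this module:
    #
    #     pending = await db.get_active_verification_requests(organization_id)
    #     for item in pending:
    #         await websocket.send_json({"type": "verification_pending", **item})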
@db_operation("get_all_organizations")
async def get_all_organizations(self) -> list[Organization]:
async with self.Session() as session:
organizations = (await session.scalars(select(OrganizationModel))).all()
return [convert_to_organization(organization) for organization in organizations]
@db_operation("get_organization")
async def get_organization(self, organization_id: str) -> Organization | None:
async with self.Session() as session:
if organization := (
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
).first():
return convert_to_organization(organization)
else:
return None
@db_operation("get_organization_by_domain")
async def get_organization_by_domain(self, domain: str) -> Organization | None:
async with self.Session() as session:
if organization := (await session.scalars(select(OrganizationModel).filter_by(domain=domain))).first():
return convert_to_organization(organization)
return None
@db_operation("create_organization")
async def create_organization(
self,
organization_name: str,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
max_retries_per_step: int | None = None,
domain: str | None = None,
organization_id: str | None = None,
) -> Organization:
async with self.Session() as session:
org = OrganizationModel(
organization_id=organization_id,
organization_name=organization_name,
webhook_callback_url=webhook_callback_url,
max_steps_per_run=max_steps_per_run,
max_retries_per_step=max_retries_per_step,
domain=domain,
)
session.add(org)
await session.commit()
await session.refresh(org)
return convert_to_organization(org)
@db_operation("update_organization")
async def update_organization(
self,
organization_id: str,
organization_name: str | None = None,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
max_retries_per_step: int | None = None,
) -> Organization:
async with self.Session() as session:
organization = (
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
).first()
if not organization:
raise NotFoundError
if organization_name:
organization.organization_name = organization_name
if webhook_callback_url:
organization.webhook_callback_url = webhook_callback_url
if max_steps_per_run:
organization.max_steps_per_run = max_steps_per_run
if max_retries_per_step:
organization.max_retries_per_step = max_retries_per_step
            await session.commit()
            await session.refresh(organization)
            return convert_to_organization(organization)
@overload
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal["api", "onepassword_service_account", "custom_credential_service"],
) -> OrganizationAuthToken | None: ...
@overload
async def get_valid_org_auth_token( # type: ignore
self,
organization_id: str,
token_type: Literal["azure_client_secret_credential"],
) -> AzureOrganizationAuthToken | None: ...
@overload
async def get_valid_org_auth_token( # type: ignore
self,
organization_id: str,
token_type: Literal["bitwarden_credential"],
) -> BitwardenOrganizationAuthToken | None: ...
@db_operation("get_valid_org_auth_token")
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[
"api",
"onepassword_service_account",
"azure_client_secret_credential",
"bitwarden_credential",
"custom_credential_service",
],
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken | None:
async with self.Session() as session:
if token := (
await session.scalars(
select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).first():
return await convert_to_organization_auth_token(token, token_type)
else:
return None
@db_operation("get_valid_org_auth_tokens")
async def get_valid_org_auth_tokens(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
) -> list[OrganizationAuthToken]:
async with self.Session() as session:
tokens = (
await session.scalars(
select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).all()
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
@db_operation("validate_org_auth_token")
async def validate_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
valid: bool | None = True,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken | None:
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
async with self.Session() as session:
query = (
select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
)
if encrypted_token:
query = query.filter_by(encrypted_token=encrypted_token)
else:
query = query.filter_by(token=token)
if valid is not None:
query = query.filter_by(valid=valid)
if token_obj := (await session.scalars(query)).first():
return await convert_to_organization_auth_token(token_obj, token_type)
else:
return None
@db_operation("create_org_auth_token")
async def create_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str | AzureClientSecretCredential | BitwardenCredential,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
if not isinstance(token, AzureClientSecretCredential):
raise TypeError("Expected AzureClientSecretCredential for this token_type")
plaintext_token = token.model_dump_json()
elif token_type is OrganizationAuthTokenType.bitwarden_credential:
if not isinstance(token, BitwardenCredential):
raise TypeError("Expected BitwardenCredential for this token_type")
plaintext_token = token.model_dump_json()
else:
if not isinstance(token, str):
raise TypeError("Expected str token for this token_type")
plaintext_token = token
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
plaintext_token = ""
async with self.Session() as session:
auth_token = OrganizationAuthTokenModel(
organization_id=organization_id,
token_type=token_type,
token=plaintext_token,
encrypted_token=encrypted_token,
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
)
session.add(auth_token)
await session.commit()
await session.refresh(auth_token)
return await convert_to_organization_auth_token(auth_token, token_type)
@db_operation("invalidate_org_auth_tokens")
async def invalidate_org_auth_tokens(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
) -> None:
"""Invalidate all existing tokens of a specific type for an organization."""
async with self.Session() as session:
await session.execute(
update(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.values(valid=False)
)
await session.commit()
@db_operation("replace_org_auth_token")
async def replace_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str | AzureClientSecretCredential | BitwardenCredential,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
"""Atomically invalidate existing tokens and create a new one in a single transaction."""
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
if not isinstance(token, AzureClientSecretCredential):
raise TypeError("Expected AzureClientSecretCredential for this token_type")
plaintext_token = token.model_dump_json()
elif token_type is OrganizationAuthTokenType.bitwarden_credential:
if not isinstance(token, BitwardenCredential):
raise TypeError("Expected BitwardenCredential for this token_type")
plaintext_token = token.model_dump_json()
else:
if not isinstance(token, str):
raise TypeError("Expected str token for this token_type")
plaintext_token = token
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
plaintext_token = ""
async with self.Session() as session:
# Invalidate existing tokens
await session.execute(
update(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(valid=True)
.values(valid=False)
)
# Create new token
auth_token = OrganizationAuthTokenModel(
organization_id=organization_id,
token_type=token_type,
token=plaintext_token,
encrypted_token=encrypted_token,
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
)
session.add(auth_token)
await session.commit()
await session.refresh(auth_token)
return await convert_to_organization_auth_token(auth_token, token_type)
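    # Rotation sketch (hedged): replace_org_auth_token is preferred over a
    # separate invalidate_org_auth_tokens + create_org_auth_token pair, which
    # would briefly leave the org with no valid token. Names are illustrative:
    #
    #     new_token = await db.replace_org_auth_token(
    #         organization_id=org_id,
    #         token_type=OrganizationAuthTokenType.api,
    #         token=generate_api_token(),  # hypothetical helper
    #     )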

View file

@ -1,155 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from sqlalchemy import and_, asc, select
from skyvern.config import settings
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import TOTPCodeModel
from skyvern.forge.sdk.schemas.totp_codes import OTPType, TOTPCode
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
class OTPMixin:
"""Database operations for OTP/TOTP management."""
Session: _SessionFactory
@db_operation("get_otp_codes")
async def get_otp_codes(
self,
organization_id: str,
totp_identifier: str,
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
otp_type: OTPType | None = None,
workflow_run_id: str | None = None,
limit: int | None = None,
) -> list[TOTPCode]:
"""
1. filter by:
- organization_id
- totp_identifier
- workflow_run_id (optional)
2. make sure created_at is within the valid lifespan
3. sort by task_id/workflow_id/workflow_run_id nullslast and created_at desc
4. apply an optional limit at the DB layer
"""
all_null = and_(
TOTPCodeModel.task_id.is_(None),
TOTPCodeModel.workflow_id.is_(None),
TOTPCodeModel.workflow_run_id.is_(None),
)
async with self.Session() as session:
query = (
select(TOTPCodeModel)
.filter_by(organization_id=organization_id)
.filter_by(totp_identifier=totp_identifier)
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
)
if otp_type:
query = query.filter(TOTPCodeModel.otp_type == otp_type)
if workflow_run_id is not None:
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
query = query.order_by(asc(all_null), TOTPCodeModel.created_at.desc())
if limit is not None:
query = query.limit(limit)
totp_codes = (await session.scalars(query)).all()
return [TOTPCode.model_validate(code) for code in totp_codes]
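    # Usage sketch (hedged): fetch the freshest matching code for a 2FA prompt,
    # scoped to the current workflow run. Identifiers are illustrative:
    #
    #     codes = await db.get_otp_codes(
    #         organization_id="o_123",
    #         totp_identifier="user@example.com",
    #         workflow_run_id="wr_456",
    #         limit=1,
    #     )
    #     code = codes[0].code if codes else None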
@db_operation("get_otp_codes_by_run")
async def get_otp_codes_by_run(
self,
organization_id: str,
task_id: str | None = None,
workflow_run_id: str | None = None,
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
limit: int = 1,
) -> list[TOTPCode]:
"""Get OTP codes matching a specific task or workflow run (no totp_identifier required).
Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
The user submits codes manually via the UI, and this method finds them by run context.
"""
if not workflow_run_id and not task_id:
return []
async with self.Session() as session:
query = (
select(TOTPCodeModel)
.filter_by(organization_id=organization_id)
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
)
if workflow_run_id:
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
elif task_id:
query = query.filter(TOTPCodeModel.task_id == task_id)
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
results = (await session.scalars(query)).all()
return [TOTPCode.model_validate(r) for r in results]
@db_operation("get_recent_otp_codes")
async def get_recent_otp_codes(
self,
organization_id: str,
limit: int = 50,
valid_lifespan_minutes: int | None = None,
otp_type: OTPType | None = None,
workflow_run_id: str | None = None,
totp_identifier: str | None = None,
) -> list[TOTPCode]:
"""
Return recent otp codes for an organization ordered by newest first with optional
workflow_run_id filtering.
"""
async with self.Session() as session:
query = select(TOTPCodeModel).filter_by(organization_id=organization_id)
if valid_lifespan_minutes is not None:
query = query.filter(
TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes)
)
if otp_type:
query = query.filter(TOTPCodeModel.otp_type == otp_type)
if workflow_run_id is not None:
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
if totp_identifier:
query = query.filter(TOTPCodeModel.totp_identifier == totp_identifier)
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
totp_codes = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_codes]
@db_operation("create_otp_code")
async def create_otp_code(
self,
organization_id: str,
totp_identifier: str,
content: str,
code: str,
otp_type: OTPType,
task_id: str | None = None,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
source: str | None = None,
expired_at: datetime | None = None,
) -> TOTPCode:
async with self.Session() as session:
new_totp_code = TOTPCodeModel(
organization_id=organization_id,
totp_identifier=totp_identifier,
content=content,
code=code,
task_id=task_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
source=source,
expired_at=expired_at,
otp_type=otp_type,
)
session.add(new_totp_code)
await session.commit()
await session.refresh(new_totp_code)
return TOTPCode.model_validate(new_totp_code)

View file

@ -1,654 +0,0 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Any
import structlog
from sqlalchemy import func, or_, select, text, update
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError
from skyvern.forge.sdk.db.models import (
WorkflowModel,
WorkflowRunModel,
WorkflowScheduleModel,
)
from skyvern.forge.sdk.db.utils import convert_to_workflow_schedule
from skyvern.forge.sdk.schemas.workflow_schedules import OrganizationScheduleItem, WorkflowSchedule
from skyvern.forge.sdk.workflow.schedules import compute_next_run
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
from skyvern.forge.sdk.db._sentinels import _UNSET
LOG = structlog.get_logger()
class SchedulesMixin:
"""Database operations for workflow schedules.
.. deprecated::
This mixin is part of the legacy database layer. New code should use the
repository classes in ``skyvern.forge.sdk.db.repositories`` instead.
    Cross-mixin migrations already completed:
    - ``soft_delete_workflow_and_schedules_by_permanent_id`` -> ``WorkflowsRepository``
      (operates on workflows as the primary entity; schedules are a side-effect).
    """
Session: _SessionFactory
engine: AsyncEngine
debug_enabled: bool
_sqlite_schedule_lock: asyncio.Lock | None
@db_operation("create_workflow_schedule")
async def create_workflow_schedule(
self,
organization_id: str,
workflow_permanent_id: str,
cron_expression: str,
timezone: str,
enabled: bool,
parameters: dict[str, Any] | None = None,
temporal_schedule_id: str | None = None,
name: str | None = None,
description: str | None = None,
) -> WorkflowSchedule:
async with self.Session() as session:
workflow_schedule = WorkflowScheduleModel(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
cron_expression=cron_expression,
timezone=timezone,
enabled=enabled,
parameters=parameters,
temporal_schedule_id=temporal_schedule_id,
name=name,
description=description,
)
session.add(workflow_schedule)
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("create_workflow_schedule_with_limit")
async def create_workflow_schedule_with_limit(
self,
organization_id: str,
workflow_permanent_id: str,
max_schedules: int | None,
cron_expression: str,
timezone: str,
enabled: bool,
parameters: dict[str, Any] | None = None,
name: str | None = None,
description: str | None = None,
) -> tuple[WorkflowSchedule, int]:
"""Create a schedule atomically with limit enforcement.
On PostgreSQL, uses an advisory lock to serialize concurrent creates for
the same workflow, preventing TOCTOU races on the schedule count.
On SQLite, uses an asyncio.Lock (set on AgentDB.__init__) since SQLite
is single-writer and has no advisory lock support.
Returns (created_schedule, count_before_insert).
Raises ScheduleLimitExceededError if count >= max_schedules.
"""
# SQLite: serialize via Python lock (no advisory locks available).
# The lock is held across the count-check + insert to prevent TOCTOU.
sqlite_lock = getattr(self, "_sqlite_schedule_lock", None)
if sqlite_lock is not None:
async with sqlite_lock:
return await self._create_schedule_with_limit_inner(
organization_id,
workflow_permanent_id,
max_schedules,
cron_expression,
timezone,
enabled,
parameters,
name,
description,
use_advisory_lock=False,
)
return await self._create_schedule_with_limit_inner(
organization_id,
workflow_permanent_id,
max_schedules,
cron_expression,
timezone,
enabled,
parameters,
name,
description,
use_advisory_lock=True,
)
# Intentionally not decorated with @db_operation — errors are caught by the
# outer create_workflow_schedule_with_limit which owns the operation name.
async def _create_schedule_with_limit_inner(
self,
organization_id: str,
workflow_permanent_id: str,
max_schedules: int | None,
cron_expression: str,
timezone: str,
enabled: bool,
parameters: dict[str, Any] | None,
name: str | None,
description: str | None,
*,
use_advisory_lock: bool,
) -> tuple[WorkflowSchedule, int]:
async with self.Session() as session:
if use_advisory_lock:
lock_key = f"schedule:{organization_id}:{workflow_permanent_id}"
await session.execute(
text("SELECT pg_advisory_xact_lock(hashtext(:key))"),
{"key": lock_key},
)
count = (
await session.execute(
select(func.count()).where(
WorkflowScheduleModel.organization_id == organization_id,
WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id,
WorkflowScheduleModel.deleted_at.is_(None),
)
)
).scalar_one()
if max_schedules is not None and count >= max_schedules:
raise ScheduleLimitExceededError(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
current_count=count,
max_allowed=max_schedules,
)
workflow_schedule = WorkflowScheduleModel(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
cron_expression=cron_expression,
timezone=timezone,
enabled=enabled,
parameters=parameters,
name=name,
description=description,
)
session.add(workflow_schedule)
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled), count
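    # Caller sketch (hedged; the endpoint shape and ``plan`` object are
    # hypothetical): enforce a per-plan schedule cap at create time and
    # surface the limit to the API layer:
    #
    #     try:
    #         schedule, prior_count = await db.create_workflow_schedule_with_limit(
    #             organization_id=org_id,
    #             workflow_permanent_id=wpid,
    #             max_schedules=plan.max_schedules,  # None disables the cap
    #             cron_expression="0 9 * * 1",
    #             timezone="America/Chicago",
    #             enabled=True,
    #         )
    #     except ScheduleLimitExceededError:
    #         raise HTTPException(status_code=422, detail="schedule limit reached")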
@db_operation("set_temporal_schedule_id")
async def set_temporal_schedule_id(
self,
workflow_schedule_id: str,
organization_id: str,
temporal_schedule_id: str,
) -> WorkflowSchedule | None:
async with self.Session() as session:
workflow_schedule = (
await session.scalars(
select(WorkflowScheduleModel).filter_by(
workflow_schedule_id=workflow_schedule_id,
organization_id=organization_id,
deleted_at=None,
)
)
).first()
if not workflow_schedule:
return None
workflow_schedule.temporal_schedule_id = temporal_schedule_id
workflow_schedule.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("update_workflow_schedule")
async def update_workflow_schedule(
self,
workflow_schedule_id: str,
organization_id: str,
cron_expression: str,
timezone: str,
enabled: bool,
parameters: dict[str, Any] | None = None,
temporal_schedule_id: str | None | object = _UNSET,
name: str | None | object = _UNSET,
description: str | None | object = _UNSET,
) -> WorkflowSchedule | None:
async with self.Session() as session:
workflow_schedule = (
await session.scalars(
select(WorkflowScheduleModel).filter_by(
workflow_schedule_id=workflow_schedule_id,
organization_id=organization_id,
deleted_at=None,
)
)
).first()
if not workflow_schedule:
return None
workflow_schedule.cron_expression = cron_expression
workflow_schedule.timezone = timezone
workflow_schedule.enabled = enabled
workflow_schedule.parameters = parameters
if temporal_schedule_id is not _UNSET:
workflow_schedule.temporal_schedule_id = temporal_schedule_id
if name is not _UNSET:
workflow_schedule.name = name
if description is not _UNSET:
workflow_schedule.description = description
workflow_schedule.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("get_workflow_schedule_by_id")
async def get_workflow_schedule_by_id(
self,
workflow_schedule_id: str,
organization_id: str,
) -> WorkflowSchedule | None:
async with self.Session() as session:
workflow_schedule = (
await session.scalars(
select(WorkflowScheduleModel).filter_by(
workflow_schedule_id=workflow_schedule_id,
organization_id=organization_id,
deleted_at=None,
)
)
).first()
if not workflow_schedule:
return None
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("get_workflow_schedules")
async def get_workflow_schedules(
self,
workflow_permanent_id: str,
organization_id: str,
) -> list[WorkflowSchedule]:
async with self.Session() as session:
rows = (
await session.scalars(
select(WorkflowScheduleModel).filter_by(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
deleted_at=None,
)
)
).all()
return [convert_to_workflow_schedule(r, self.debug_enabled) for r in rows]
@db_operation("get_all_enabled_schedules")
async def get_all_enabled_schedules(
self,
organization_id: str | None = None,
) -> list[WorkflowSchedule]:
"""Fetch all enabled, non-deleted schedules, optionally filtered by org."""
async with self.Session() as session:
stmt = select(WorkflowScheduleModel).where(
WorkflowScheduleModel.enabled.is_(True),
WorkflowScheduleModel.deleted_at.is_(None),
)
if organization_id:
stmt = stmt.where(WorkflowScheduleModel.organization_id == organization_id)
rows = (await session.scalars(stmt)).all()
return [convert_to_workflow_schedule(r, self.debug_enabled) for r in rows]
@db_operation("has_schedule_fired_since")
async def has_schedule_fired_since(
self,
workflow_schedule_id: str,
since: datetime,
) -> bool:
"""Check if a workflow_run exists for the given schedule since a timestamp."""
from sqlalchemy import exists as sa_exists
async with self.Session() as session:
row = (
await session.execute(
select(
sa_exists().where(
WorkflowRunModel.workflow_schedule_id == workflow_schedule_id,
WorkflowRunModel.created_at >= since,
)
)
)
).scalar()
return bool(row)
@db_operation("update_workflow_schedule_enabled")
async def update_workflow_schedule_enabled(
self,
workflow_schedule_id: str,
organization_id: str,
enabled: bool,
) -> WorkflowSchedule | None:
async with self.Session() as session:
workflow_schedule = (
await session.scalars(
select(WorkflowScheduleModel).filter_by(
workflow_schedule_id=workflow_schedule_id,
organization_id=organization_id,
deleted_at=None,
)
)
).first()
if not workflow_schedule:
return None
workflow_schedule.enabled = enabled
workflow_schedule.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("delete_workflow_schedule")
async def delete_workflow_schedule(
self,
workflow_schedule_id: str,
organization_id: str,
) -> WorkflowSchedule | None:
async with self.Session() as session:
workflow_schedule = (
await session.scalars(
select(WorkflowScheduleModel).filter_by(
workflow_schedule_id=workflow_schedule_id,
organization_id=organization_id,
deleted_at=None,
)
)
).first()
if not workflow_schedule:
return None
workflow_schedule.deleted_at = datetime.utcnow()
workflow_schedule.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("restore_workflow_schedule")
async def restore_workflow_schedule(
self,
workflow_schedule_id: str,
organization_id: str,
) -> WorkflowSchedule | None:
async with self.Session() as session:
workflow_schedule = (
await session.scalars(
select(WorkflowScheduleModel)
.filter_by(
workflow_schedule_id=workflow_schedule_id,
organization_id=organization_id,
)
.filter(WorkflowScheduleModel.deleted_at.isnot(None))
)
).first()
if not workflow_schedule:
return None
workflow_schedule.deleted_at = None
workflow_schedule.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow_schedule)
return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)
@db_operation("count_workflow_schedules")
async def count_workflow_schedules(
self,
organization_id: str,
workflow_permanent_id: str,
) -> int:
async with self.Session() as session:
result = await session.execute(
select(func.count()).where(
WorkflowScheduleModel.organization_id == organization_id,
WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id,
WorkflowScheduleModel.deleted_at.is_(None),
)
)
return result.scalar_one()
@db_operation("list_organization_schedules")
async def list_organization_schedules(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
enabled_filter: bool | None = None,
search: str | None = None,
) -> tuple[list[OrganizationScheduleItem], int]:
"""
List all schedules for an organization, joined with workflow titles.
Returns (schedules, total_count).
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
db_page = page - 1
async with self.Session() as session:
# Subquery to get the latest version title per workflow_permanent_id
latest_version_sq = (
select(
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(WorkflowModel.workflow_permanent_id)
.subquery()
)
workflow_title_sq = (
select(
WorkflowModel.workflow_permanent_id,
WorkflowModel.title,
)
.join(
latest_version_sq,
(WorkflowModel.workflow_permanent_id == latest_version_sq.c.workflow_permanent_id)
& (WorkflowModel.version == latest_version_sq.c.max_version),
)
.subquery()
)
# Base query: schedules joined with workflow titles
base_filter = (
select(WorkflowScheduleModel, workflow_title_sq.c.title.label("workflow_title"))
.outerjoin(
workflow_title_sq,
WorkflowScheduleModel.workflow_permanent_id == workflow_title_sq.c.workflow_permanent_id,
)
.where(WorkflowScheduleModel.organization_id == organization_id)
.where(WorkflowScheduleModel.deleted_at.is_(None))
)
if enabled_filter is not None:
base_filter = base_filter.where(WorkflowScheduleModel.enabled == enabled_filter)
if search:
base_filter = base_filter.where(
or_(
workflow_title_sq.c.title.icontains(search, autoescape=True),
WorkflowScheduleModel.name.icontains(search, autoescape=True),
)
)
# Count query
count_query = select(func.count()).select_from(base_filter.subquery())
total_count = (await session.execute(count_query)).scalar_one()
# Data query with pagination
data_query = (
base_filter.order_by(WorkflowScheduleModel.created_at.desc())
.limit(page_size)
.offset(db_page * page_size)
)
rows = (await session.execute(data_query)).all()
# Materialize row data while session is open
raw_schedules = []
for row in rows:
schedule_model = row[0]
raw_schedules.append(
(
schedule_model.workflow_schedule_id,
schedule_model.organization_id,
schedule_model.workflow_permanent_id,
row[1] or "Untitled Workflow",
schedule_model.cron_expression,
schedule_model.timezone,
schedule_model.enabled,
schedule_model.parameters,
schedule_model.name,
schedule_model.description,
schedule_model.created_at,
schedule_model.modified_at,
)
)
# Compute next_run outside session scope (pure CPU, no DB needed)
schedules: list[OrganizationScheduleItem] = []
for (
ws_id,
org_id,
wpid,
title,
cron_expr,
tz,
enabled,
params,
name,
description,
created,
modified,
) in raw_schedules:
next_run = None
if enabled:
try:
next_run = compute_next_run(cron_expr, tz)
except Exception:
LOG.warning(
"Failed to compute next_run for schedule",
workflow_schedule_id=ws_id,
exc_info=True,
)
schedules.append(
OrganizationScheduleItem(
workflow_schedule_id=ws_id,
organization_id=org_id,
workflow_permanent_id=wpid,
workflow_title=title,
cron_expression=cron_expr,
timezone=tz,
enabled=enabled,
parameters=params,
name=name,
description=description,
next_run=next_run,
created_at=created,
modified_at=modified,
)
)
return schedules, total_count
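    # Usage sketch (hedged): page through enabled schedules matching a search
    # term; the caller-side pagination math is illustrative:
    #
    #     items, total = await db.list_organization_schedules(
    #         organization_id=org_id, page=1, page_size=25,
    #         enabled_filter=True, search="nightly",
    #     )
    #     has_more = 1 * 25 < total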
@db_operation("soft_delete_workflow_and_schedules_by_permanent_id")
async def soft_delete_workflow_and_schedules_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> list[str]:
"""Soft-delete a workflow and its active schedules in a single DB transaction.
.. deprecated::
Moved to ``WorkflowsRepository.soft_delete_workflow_and_schedules_by_permanent_id``
(skyvern/forge/sdk/db/repositories/workflows.py). The primary entity is the
workflow, not the schedule, so it belongs in the workflows repository.
This copy remains for backward compatibility with the legacy mixin layer.
"""
async with self.Session() as session:
select_query = (
select(WorkflowScheduleModel.workflow_schedule_id)
.where(WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowScheduleModel.deleted_at.is_(None))
)
if organization_id is not None:
select_query = select_query.where(WorkflowScheduleModel.organization_id == organization_id)
result = await session.execute(select_query)
schedule_ids = list(result.scalars().all())
deleted_at = datetime.utcnow()
if schedule_ids:
update_schedules_query = (
update(WorkflowScheduleModel)
.where(WorkflowScheduleModel.workflow_schedule_id.in_(schedule_ids))
.values(deleted_at=deleted_at)
)
await session.execute(update_schedules_query)
update_workflow_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowModel.deleted_at.is_(None))
)
if organization_id is not None:
update_workflow_query = update_workflow_query.filter_by(organization_id=organization_id)
await session.execute(update_workflow_query.values(deleted_at=deleted_at))
await session.commit()
return schedule_ids
@db_operation("soft_delete_orphaned_schedules")
async def soft_delete_orphaned_schedules(self, limit: int = 500) -> list[tuple[str, str]]:
"""Soft-delete orphaned schedules and return their identities.
Uses a single UPDATE ... RETURNING statement so orphan detection and
soft-deletion happen atomically in one DB round-trip.
"""
async with self.Session() as session:
active_workflow_exists = (
select(WorkflowModel.workflow_permanent_id)
.where(WorkflowModel.workflow_permanent_id == WorkflowScheduleModel.workflow_permanent_id)
.where(WorkflowModel.deleted_at.is_(None))
.correlate(WorkflowScheduleModel)
.exists()
)
orphaned_schedules = (
select(
WorkflowScheduleModel.workflow_schedule_id.label("workflow_schedule_id"),
WorkflowScheduleModel.workflow_permanent_id.label("workflow_permanent_id"),
)
.where(WorkflowScheduleModel.deleted_at.is_(None))
.where(~active_workflow_exists)
.limit(limit)
.cte("orphaned_schedules")
)
update_query = (
update(WorkflowScheduleModel)
.where(
WorkflowScheduleModel.workflow_schedule_id.in_(select(orphaned_schedules.c.workflow_schedule_id))
)
.where(WorkflowScheduleModel.deleted_at.is_(None))
.values(deleted_at=datetime.utcnow())
.returning(
WorkflowScheduleModel.workflow_schedule_id,
WorkflowScheduleModel.workflow_permanent_id,
)
)
result = await session.execute(update_query)
await session.commit()
return [(row[0], row[1]) for row in result.all()]
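    # Equivalent SQL (sketch) for the atomic orphan sweep above. Table names
    # assume the conventional pluralized mapping for these models:
    #
    #     WITH orphaned_schedules AS (
    #         SELECT ws.workflow_schedule_id, ws.workflow_permanent_id
    #         FROM workflow_schedules ws
    #         WHERE ws.deleted_at IS NULL
    #           AND NOT EXISTS (
    #               SELECT 1 FROM workflows w
    #               WHERE w.workflow_permanent_id = ws.workflow_permanent_id
    #                 AND w.deleted_at IS NULL)
    #         LIMIT :limit
    #     )
    #     UPDATE workflow_schedules SET deleted_at = NOW()
    #     WHERE workflow_schedule_id IN
    #           (SELECT workflow_schedule_id FROM orphaned_schedules)
    #       AND deleted_at IS NULL
    #     RETURNING workflow_schedule_id, workflow_permanent_id;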

File diff suppressed because it is too large

View file

@ -1,935 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Sequence
import structlog
from sqlalchemy import and_, delete, func, select, update
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
ActionModel,
StepModel,
TaskModel,
TaskRunModel,
WorkflowRunModel,
)
from skyvern.forge.sdk.db.utils import convert_to_step, convert_to_task, hydrate_action, serialize_proxy_location
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
from skyvern.schemas.runs import ProxyLocationInput, RunStatus, RunType
from skyvern.schemas.steps import AgentStepOutput
from skyvern.webeye.actions.actions import Action
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
LOG = structlog.get_logger()
class TasksMixin:
Session: _SessionFactory
debug_enabled: bool
@db_operation("create_task")
async def create_task(
self,
url: str,
title: str | None,
navigation_goal: str | None,
data_extraction_goal: str | None,
navigation_payload: dict[str, Any] | list | str | None,
status: str = "created",
complete_criterion: str | None = None,
terminate_criterion: str | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocationInput = None,
extracted_information_schema: dict[str, Any] | list | str | None = None,
workflow_run_id: str | None = None,
order: int | None = None,
retry: int | None = None,
max_steps_per_run: int | None = None,
error_code_mapping: dict[str, str] | None = None,
task_type: str = TaskType.general,
application: str | None = None,
include_action_history_in_verification: bool | None = None,
model: dict[str, Any] | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
browser_session_id: str | None = None,
browser_address: str | None = None,
download_timeout: float | None = None,
) -> Task:
# Sanitize text fields to remove NUL bytes and control characters
# that PostgreSQL cannot store in text columns
def _sanitize(v: str | None) -> str | None:
return sanitize_postgres_text(v) if isinstance(v, str) else v
navigation_goal = _sanitize(navigation_goal)
data_extraction_goal = _sanitize(data_extraction_goal)
title = _sanitize(title)
url = sanitize_postgres_text(url)
complete_criterion = _sanitize(complete_criterion)
terminate_criterion = _sanitize(terminate_criterion)
async with self.Session() as session:
new_task = TaskModel(
status=status,
task_type=task_type,
url=url,
title=title,
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
navigation_goal=navigation_goal,
complete_criterion=complete_criterion,
terminate_criterion=terminate_criterion,
data_extraction_goal=data_extraction_goal,
navigation_payload=navigation_payload,
organization_id=organization_id,
proxy_location=serialize_proxy_location(proxy_location),
extracted_information_schema=extracted_information_schema,
workflow_run_id=workflow_run_id,
order=order,
retry=retry,
max_steps_per_run=max_steps_per_run,
error_code_mapping=error_code_mapping,
application=application,
include_action_history_in_verification=include_action_history_in_verification,
model=model,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
browser_session_id=browser_session_id,
browser_address=browser_address,
download_timeout=download_timeout,
)
session.add(new_task)
await session.commit()
await session.refresh(new_task)
return convert_to_task(new_task, self.debug_enabled)
@db_operation("create_step")
async def create_step(
self,
task_id: str,
order: int,
retry_index: int,
organization_id: str | None = None,
status: StepStatus = StepStatus.created,
created_by: str | None = None,
) -> Step:
async with self.Session() as session:
new_step = StepModel(
task_id=task_id,
order=order,
retry_index=retry_index,
status=status,
organization_id=organization_id,
created_by=created_by,
)
session.add(new_step)
await session.commit()
await session.refresh(new_step)
return convert_to_step(new_step, debug_enabled=self.debug_enabled)
@read_retry()
@db_operation("get_task", log_errors=False)
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
"""Get a task by its id"""
async with self.Session() as session:
query = select(TaskModel).filter_by(task_id=task_id)
if organization_id is not None:
query = query.filter_by(organization_id=organization_id)
if task_obj := (await session.scalars(query)).first():
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info(
"Task not found",
task_id=task_id,
organization_id=organization_id,
)
return None
@db_operation("get_tasks_by_ids")
async def get_tasks_by_ids(
self,
task_ids: list[str],
organization_id: str,
) -> list[Task]:
async with self.Session() as session:
tasks = (
await session.scalars(
select(TaskModel).filter(TaskModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id)
)
).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
@db_operation("get_step")
async def get_step(self, step_id: str, organization_id: str | None = None) -> Step | None:
async with self.Session() as session:
if step := (
await session.scalars(
select(StepModel).filter_by(step_id=step_id).filter_by(organization_id=organization_id)
)
).first():
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
return None
@db_operation("get_task_steps")
async def get_task_steps(self, task_id: str, organization_id: str) -> list[Step]:
async with self.Session() as session:
if steps := (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
)
).all():
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
else:
return []
@db_operation("get_steps_by_task_ids")
async def get_steps_by_task_ids(self, task_ids: list[str], organization_id: str | None = None) -> list[Step]:
async with self.Session() as session:
steps = (
await session.scalars(
select(StepModel).filter(StepModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id)
)
).all()
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
@db_operation("get_total_unique_step_order_count_by_task_ids")
async def get_total_unique_step_order_count_by_task_ids(
self,
*,
task_ids: list[str],
organization_id: str,
) -> int:
"""
Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids
Basically translate this sql query into a SQLAlchemy query: select count(distinct(s.task_id, s.order)) from steps s
where s.task_id in task_ids
"""
async with self.Session() as session:
subq = (
select(StepModel.task_id, StepModel.order)
.where(StepModel.task_id.in_(task_ids))
.where(StepModel.organization_id == organization_id)
.distinct()
.subquery()
)
query = select(func.count()).select_from(subq)
            return (await session.execute(query)).scalar_one()
@db_operation("get_task_step_models")
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
async with self.Session() as session:
return (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order)
.order_by(StepModel.retry_index)
)
).all()
@db_operation("get_task_step_count")
async def get_task_step_count(self, task_id: str, organization_id: str | None = None) -> int:
async with self.Session() as session:
result = await session.scalar(
select(func.count(StepModel.step_id))
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
)
return result or 0
@db_operation("get_task_actions")
async def get_task_actions(self, task_id: str, organization_id: str | None = None) -> list[Action]:
async with self.Session() as session:
query = (
select(ActionModel)
.filter(ActionModel.organization_id == organization_id)
.filter(ActionModel.task_id == task_id)
.order_by(ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]
@db_operation("get_task_actions_hydrated")
async def get_task_actions_hydrated(self, task_id: str, organization_id: str | None = None) -> list[Action]:
async with self.Session() as session:
query = (
select(ActionModel)
.filter(ActionModel.organization_id == organization_id)
.filter(ActionModel.task_id == task_id)
.order_by(ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [hydrate_action(action) for action in actions]
@db_operation("get_tasks_actions")
async def get_tasks_actions(self, task_ids: list[str], organization_id: str | None = None) -> list[Action]:
async with self.Session() as session:
query = (
select(ActionModel)
.filter(ActionModel.organization_id == organization_id)
.filter(ActionModel.task_id.in_(task_ids))
.order_by(ActionModel.created_at.desc())
)
actions = (await session.scalars(query)).all()
return [hydrate_action(action) for action in actions]
@db_operation("get_action_count_for_step")
async def get_action_count_for_step(self, step_id: str, task_id: str, organization_id: str) -> int:
"""Get count of actions for a step. Uses composite index for efficiency."""
async with self.Session() as session:
query = (
select(func.count())
.select_from(ActionModel)
.where(ActionModel.organization_id == organization_id)
.where(ActionModel.task_id == task_id)
.where(ActionModel.step_id == step_id)
)
result = await session.scalar(query)
return result or 0
@db_operation("get_first_step")
async def get_first_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
async with self.Session() as session:
if step := (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.order_by(StepModel.order.asc())
.order_by(StepModel.retry_index.asc())
)
).first():
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
LOG.info(
"Latest step not found",
task_id=task_id,
organization_id=organization_id,
)
return None
@db_operation("get_latest_step")
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
async with self.Session() as session:
if step := (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(organization_id=organization_id)
.filter(StepModel.status != StepStatus.canceled)
.order_by(StepModel.order.desc())
.order_by(StepModel.retry_index.desc())
)
).first():
return convert_to_step(step, debug_enabled=self.debug_enabled)
else:
LOG.info(
"Latest step not found",
task_id=task_id,
organization_id=organization_id,
)
return None
@db_operation("update_step")
async def update_step(
self,
task_id: str,
step_id: str,
status: StepStatus | None = None,
output: AgentStepOutput | None = None,
is_last: bool | None = None,
retry_index: int | None = None,
organization_id: str | None = None,
incremental_cost: float | None = None,
incremental_input_tokens: int | None = None,
incremental_output_tokens: int | None = None,
incremental_reasoning_tokens: int | None = None,
incremental_cached_tokens: int | None = None,
created_by: str | None = None,
) -> Step:
async with self.Session() as session:
if step := (
await session.scalars(
select(StepModel)
.filter_by(task_id=task_id)
.filter_by(step_id=step_id)
.filter_by(organization_id=organization_id)
)
).first():
if status is not None:
step.status = status
if status.is_terminal() and step.finished_at is None:
step.finished_at = datetime.utcnow()
if output is not None:
step.output = output.model_dump(exclude_none=True)
if is_last is not None:
step.is_last = is_last
if retry_index is not None:
step.retry_index = retry_index
if incremental_cost is not None:
step.step_cost = incremental_cost + float(step.step_cost or 0)
if incremental_input_tokens is not None:
step.input_token_count = incremental_input_tokens + (step.input_token_count or 0)
if incremental_output_tokens is not None:
step.output_token_count = incremental_output_tokens + (step.output_token_count or 0)
if incremental_reasoning_tokens is not None:
step.reasoning_token_count = incremental_reasoning_tokens + (step.reasoning_token_count or 0)
if incremental_cached_tokens is not None:
step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0)
if created_by is not None:
step.created_by = created_by
await session.commit()
updated_step = await self.get_step(step_id, organization_id)
if not updated_step:
raise NotFoundError("Step not found")
return updated_step
else:
raise NotFoundError("Step not found")
@db_operation("clear_task_failure_reason")
async def clear_task_failure_reason(self, organization_id: str, task_id: str) -> Task:
async with self.Session() as session:
if task := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
task.failure_reason = None
await session.commit()
await session.refresh(task)
return convert_to_task(task, debug_enabled=self.debug_enabled)
else:
raise NotFoundError("Task not found")
@db_operation("update_task")
async def update_task(
self,
task_id: str,
status: TaskStatus | None = None,
extracted_information: dict[str, Any] | list | str | None = None,
webhook_failure_reason: str | None = None,
failure_reason: str | None = None,
errors: list[dict[str, Any]] | None = None,
max_steps_per_run: int | None = None,
organization_id: str | None = None,
failure_category: list[dict[str, Any]] | None = None,
) -> Task:
if (
status is None
and extracted_information is None
and failure_reason is None
and errors is None
and max_steps_per_run is None
and webhook_failure_reason is None
and failure_category is None
):
            raise ValueError(
                "At least one of status, extracted_information, webhook_failure_reason, failure_reason, "
                "errors, max_steps_per_run, or failure_category must be provided to update the task"
            )
async with self.Session() as session:
if task := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
if status is not None:
task.status = status
if status == TaskStatus.queued and task.queued_at is None:
task.queued_at = datetime.utcnow()
if status == TaskStatus.running and task.started_at is None:
task.started_at = datetime.utcnow()
if status.is_final() and task.finished_at is None:
task.finished_at = datetime.utcnow()
if extracted_information is not None:
task.extracted_information = extracted_information
if failure_reason is not None:
task.failure_reason = failure_reason
if errors is not None:
task.errors = (task.errors or []) + errors
if max_steps_per_run is not None:
task.max_steps_per_run = max_steps_per_run
if webhook_failure_reason is not None:
task.webhook_failure_reason = webhook_failure_reason
if failure_category is not None:
task.failure_category = failure_category
await session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:
raise NotFoundError("Task not found")
return updated_task
else:
raise NotFoundError("Task not found")
@db_operation("update_task_2fa_state")
async def update_task_2fa_state(
self,
task_id: str,
organization_id: str,
waiting_for_verification_code: bool,
verification_code_identifier: str | None = None,
verification_code_polling_started_at: datetime | None = None,
) -> Task:
"""Update task 2FA verification code waiting state."""
async with self.Session() as session:
if task := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
task.waiting_for_verification_code = waiting_for_verification_code
if verification_code_identifier is not None:
task.verification_code_identifier = verification_code_identifier
if verification_code_polling_started_at is not None:
task.verification_code_polling_started_at = verification_code_polling_started_at
if not waiting_for_verification_code:
# Clear identifiers when no longer waiting
task.verification_code_identifier = None
task.verification_code_polling_started_at = None
await session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:
raise NotFoundError("Task not found")
return updated_task
else:
raise NotFoundError("Task not found")
@db_operation("bulk_update_tasks")
async def bulk_update_tasks(
self,
task_ids: list[str],
status: TaskStatus | None = None,
failure_reason: str | None = None,
) -> None:
"""Bulk update tasks by their IDs.
Args:
task_ids: List of task IDs to update
status: Optional status to set for all tasks
failure_reason: Optional failure reason to set for all tasks
"""
if not task_ids:
return
async with self.Session() as session:
update_values = {}
if status:
update_values["status"] = status.value
if failure_reason:
update_values["failure_reason"] = failure_reason
if update_values:
update_stmt = update(TaskModel).where(TaskModel.task_id.in_(task_ids)).values(**update_values)
await session.execute(update_stmt)
await session.commit()
@db_operation("get_tasks")
async def get_tasks(
self,
page: int = 1,
page_size: int = 10,
task_status: list[TaskStatus] | None = None,
workflow_run_id: str | None = None,
organization_id: str | None = None,
only_standalone_tasks: bool = False,
application: str | None = None,
order_by_column: OrderBy = OrderBy.created_at,
order: SortDirection = SortDirection.desc,
) -> list[Task]:
"""
Get all tasks.
:param page: Starts at 1
:param page_size:
:param task_status:
:param workflow_run_id:
:param only_standalone_tasks:
:param order_by_column:
:param order:
:return:
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
async with self.Session() as session:
db_page = page - 1 # offset logic is 0 based
query = (
select(TaskModel, WorkflowRunModel.workflow_permanent_id)
.join(WorkflowRunModel, TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id, isouter=True)
.filter(TaskModel.organization_id == organization_id)
)
if task_status:
query = query.filter(TaskModel.status.in_(task_status))
if workflow_run_id:
query = query.filter(TaskModel.workflow_run_id == workflow_run_id)
if only_standalone_tasks:
query = query.filter(TaskModel.workflow_run_id.is_(None))
if application:
query = query.filter(TaskModel.application == application)
order_by_col = getattr(TaskModel, order_by_column)
query = (
query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc())
.limit(page_size)
.offset(db_page * page_size)
)
results = (await session.execute(query)).all()
return [
convert_to_task(task, debug_enabled=self.debug_enabled, workflow_permanent_id=workflow_permanent_id)
for task, workflow_permanent_id in results
]
@db_operation("get_tasks_count")
async def get_tasks_count(
self,
organization_id: str,
task_status: list[TaskStatus] | None = None,
workflow_run_id: str | None = None,
only_standalone_tasks: bool = False,
application: str | None = None,
) -> int:
async with self.Session() as session:
count_query = (
select(func.count()).select_from(TaskModel).filter(TaskModel.organization_id == organization_id)
)
if task_status:
count_query = count_query.filter(TaskModel.status.in_(task_status))
if workflow_run_id:
count_query = count_query.filter(TaskModel.workflow_run_id == workflow_run_id)
if only_standalone_tasks:
count_query = count_query.filter(TaskModel.workflow_run_id.is_(None))
if application:
count_query = count_query.filter(TaskModel.application == application)
return (await session.execute(count_query)).scalar_one()
@db_operation("get_running_tasks_info_globally")
async def get_running_tasks_info_globally(
self,
stale_threshold_hours: int = 24,
) -> tuple[int, int]:
"""
Get information about running tasks across all organizations.
Used by cleanup service to determine if cleanup should be skipped.
Args:
stale_threshold_hours: Tasks not updated for this many hours are considered stale.
Returns:
Tuple of (active_task_count, stale_task_count).
Active tasks are those updated within the threshold.
Stale tasks are those not updated within the threshold but still in running status.
"""
async with self.Session() as session:
running_statuses = [TaskStatus.created, TaskStatus.queued, TaskStatus.running]
stale_cutoff = datetime.utcnow() - timedelta(hours=stale_threshold_hours)
# Count active tasks (recently updated)
active_query = (
select(func.count())
.select_from(TaskModel)
.filter(TaskModel.status.in_(running_statuses))
.filter(TaskModel.modified_at >= stale_cutoff)
)
active_count = (await session.execute(active_query)).scalar_one()
# Count stale tasks (not updated for a long time)
stale_query = (
select(func.count())
.select_from(TaskModel)
.filter(TaskModel.status.in_(running_statuses))
.filter(TaskModel.modified_at < stale_cutoff)
)
stale_count = (await session.execute(stale_query)).scalar_one()
return (active_count, stale_count)
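    # Usage sketch (hedged): a cleanup service could skip work while live tasks
    # are in flight but log long-stale rows. The surrounding service is
    # hypothetical:
    #
    #     active, stale = await db.get_running_tasks_info_globally(stale_threshold_hours=24)
    #     if active > 0:
    #         return  # real activity; skip cleanup this cycle
    #     if stale > 0:
    #         LOG.warning("Stale running tasks detected", count=stale)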
@db_operation("get_latest_task_by_workflow_id")
async def get_latest_task_by_workflow_id(
self,
organization_id: str,
workflow_id: str,
before: datetime | None = None,
) -> Task | None:
async with self.Session() as session:
query = select(TaskModel).filter_by(organization_id=organization_id).filter_by(workflow_id=workflow_id)
if before:
query = query.filter(TaskModel.created_at < before)
task = (await session.scalars(query.order_by(TaskModel.created_at.desc()))).first()
if task:
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
@db_operation("get_last_task_for_workflow_run")
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
async with self.Session() as session:
if task := (
await session.scalars(
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at.desc())
)
).first():
return convert_to_task(task, debug_enabled=self.debug_enabled)
return None
@db_operation("get_tasks_by_workflow_run_id")
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
async with self.Session() as session:
tasks = (
await session.scalars(
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
)
).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
@db_operation("delete_task_steps")
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
async with self.Session() as session:
# delete steps for this task, scoped by organization_id and task_id
stmt = delete(StepModel).where(
and_(
StepModel.organization_id == organization_id,
StepModel.task_id == task_id,
)
)
await session.execute(stmt)
await session.commit()
@db_operation("get_previous_actions_for_task")
async def get_previous_actions_for_task(self, task_id: str) -> list[Action]:
async with self.Session() as session:
query = (
select(ActionModel)
.filter_by(task_id=task_id)
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]
@db_operation("delete_task_actions")
async def delete_task_actions(self, organization_id: str, task_id: str) -> None:
async with self.Session() as session:
# delete actions by filtering organization_id and task_id
stmt = delete(ActionModel).where(
and_(
ActionModel.organization_id == organization_id,
ActionModel.task_id == task_id,
)
)
await session.execute(stmt)
await session.commit()
async def sync_task_run_status(
self,
organization_id: str,
run_id: str,
status: str,
started_at: datetime | None = None,
finished_at: datetime | None = None,
) -> None:
"""Best-effort write-through: propagate status from source table to task_runs.
Does NOT raise if the task_runs row is missing (race at creation time).
"""
try:
async with self.Session() as session:
vals: dict[str, Any] = {"status": status}
if started_at is not None:
vals["started_at"] = started_at
if finished_at is not None:
vals["finished_at"] = finished_at
stmt = (
update(TaskRunModel)
.where(TaskRunModel.run_id == run_id)
.where(TaskRunModel.organization_id == organization_id)
.values(**vals)
)
await session.execute(stmt)
await session.commit()
except Exception:
LOG.warning(
"Best-effort task_run status sync failed",
run_id=run_id,
organization_id=organization_id,
status=status,
exc_info=True,
)
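For example, a caller that has just finalized a run in the source table can mirror the status into task_runs without letting a sync failure bubble up (identifiers illustrative):
# Best-effort mirror; a missing task_runs row or DB error is only logged.
await db.sync_task_run_status(
    organization_id="o_123",  # hypothetical
    run_id="tsk_abc",         # hypothetical; mirrors the source run's id
    status="completed",
    finished_at=datetime.utcnow(),
)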
@db_operation("create_task_run")
async def create_task_run(
self,
task_run_type: RunType,
organization_id: str,
run_id: str,
title: str | None = None,
url: str | None = None,
url_hash: str | None = None,
status: RunStatus | None = None,
workflow_permanent_id: str | None = None,
parent_workflow_run_id: str | None = None,
debug_session_id: str | None = None,
# script_run, started_at, finished_at are intentionally omitted here —
# they are set via update_task_run() after the run starts/finishes (PRs 2-5).
) -> Run:
searchable_text = " ".join(filter(None, [title, url]))
async with self.Session() as session:
task_run = TaskRunModel(
task_run_type=task_run_type,
organization_id=organization_id,
run_id=run_id,
title=title,
url=url,
url_hash=url_hash,
status=status,
workflow_permanent_id=workflow_permanent_id,
parent_workflow_run_id=parent_workflow_run_id,
debug_session_id=debug_session_id,
searchable_text=searchable_text or None,
)
session.add(task_run)
await session.commit()
await session.refresh(task_run)
return Run.model_validate(task_run)
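Per the inline comment above, the intended lifecycle is create-then-update. A sketch, assuming `RunType.task_run` and `RunStatus.created` are valid enum members (only `RunType.workflow_run` appears verbatim in this diff):
run = await db.create_task_run(
    task_run_type=RunType.task_run,  # assumed member
    organization_id="o_123",         # hypothetical
    run_id="tsk_abc",
    title="Checkout flow",
    url="https://example.com",
    status=RunStatus.created,        # assumed member
)
# Later, once the run actually starts (per the "PRs 2-5" note above):
await db.update_task_run(
    organization_id="o_123",
    run_id=run.run_id,
    status="running",
    started_at=datetime.utcnow(),
)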
@db_operation("update_task_run")
async def update_task_run(
self,
organization_id: str,
run_id: str,
title: str | None = None,
url: str | None = None,
url_hash: str | None = None,
status: str | None = None,
started_at: datetime | None = None,
finished_at: datetime | None = None,
) -> None:
async with self.Session() as session:
task_run = (
await session.scalars(
select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
)
).first()
if not task_run:
raise NotFoundError(f"TaskRun {run_id} not found")
if title is not None:
task_run.title = title
if url is not None:
task_run.url = url
if url_hash is not None:
task_run.url_hash = url_hash
if status is not None:
task_run.status = status
if started_at is not None:
task_run.started_at = started_at
if finished_at is not None:
task_run.finished_at = finished_at
# Recompute searchable_text when title or url changes
if title is not None or url is not None:
task_run.searchable_text = " ".join(filter(None, [task_run.title, task_run.url])) or None
await session.commit()
@db_operation("update_job_run_compute_cost")
async def update_job_run_compute_cost(
self,
organization_id: str,
run_id: str,
instance_type: str | None = None,
vcpu_millicores: int | None = None,
memory_mb: int | None = None,
duration_ms: int | None = None,
compute_cost: float | None = None,
) -> None:
"""Update compute cost metrics for a job run."""
async with self.Session() as session:
task_run = (
await session.scalars(
select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
)
).first()
if not task_run:
LOG.warning(
"TaskRun not found for compute cost update",
run_id=run_id,
organization_id=organization_id,
)
return
if instance_type is not None:
task_run.instance_type = instance_type
if vcpu_millicores is not None:
task_run.vcpu_millicores = vcpu_millicores
if memory_mb is not None:
task_run.memory_mb = memory_mb
if duration_ms is not None:
task_run.duration_ms = duration_ms
if compute_cost is not None:
task_run.compute_cost = compute_cost
await session.commit()
@db_operation("cache_task_run")
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
async with self.Session() as session:
task_run = (
await session.scalars(
select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
)
).first()
if task_run:
task_run.cached = True
await session.commit()
await session.refresh(task_run)
return Run.model_validate(task_run)
raise NotFoundError(f"Run {run_id} not found")
@db_operation("get_cached_task_run")
async def get_cached_task_run(
self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
) -> Run | None:
async with self.Session() as session:
query = select(TaskRunModel)
if task_run_type:
query = query.filter_by(task_run_type=task_run_type)
if url_hash:
query = query.filter_by(url_hash=url_hash)
if organization_id:
query = query.filter_by(organization_id=organization_id)
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
task_run = (await session.scalars(query)).first()
return Run.model_validate(task_run) if task_run else None
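A hedged sketch of the cache round trip these two methods imply; `hash_url` is a hypothetical helper:
cached = await db.get_cached_task_run(
    task_run_type=RunType.task_run,             # assumed member
    url_hash=hash_url("https://example.com"),   # hypothetical helper
    organization_id="o_123",
)
if cached is None:
    ...  # execute normally, then opt the finished run into the cache:
    # await db.cache_task_run(run_id=new_run.run_id, organization_id="o_123")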
@db_operation("get_run")
async def get_run(
self,
run_id: str,
organization_id: str | None = None,
) -> Run | None:
async with self.Session() as session:
query = select(TaskRunModel).filter_by(run_id=run_id)
if organization_id:
query = query.filter_by(organization_id=organization_id)
task_run = (await session.scalars(query)).first()
return Run.model_validate(task_run) if task_run else None

View file

@ -1,566 +0,0 @@
from __future__ import annotations
import json
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any
import structlog
from sqlalchemy import select
from skyvern.config import settings
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
ActionModel,
AISuggestionModel,
AWSSecretParameterModel,
AzureVaultCredentialParameterModel,
Base,
BitwardenCreditCardDataParameterModel,
BitwardenLoginCredentialParameterModel,
BitwardenSensitiveInformationParameterModel,
CredentialParameterModel,
OnePasswordCredentialParameterModel,
OutputParameterModel,
TaskGenerationModel,
TaskModel,
WorkflowCopilotChatMessageModel,
WorkflowCopilotChatModel,
WorkflowParameterModel,
)
from skyvern.forge.sdk.db.utils import (
convert_to_aws_secret_parameter,
convert_to_output_parameter,
convert_to_workflow_copilot_chat_message,
convert_to_workflow_parameter,
hydrate_action,
)
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotChat,
WorkflowCopilotChatMessage,
WorkflowCopilotChatSender,
)
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
AWSSecretParameter,
AzureVaultCredentialParameter,
BitwardenCreditCardDataParameter,
BitwardenLoginCredentialParameter,
BitwardenSensitiveInformationParameter,
ContextParameter,
CredentialParameter,
OnePasswordCredentialParameter,
OutputParameter,
WorkflowParameter,
WorkflowParameterType,
)
from skyvern.webeye.actions.actions import Action
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
from skyvern.forge.sdk.db._sentinels import _UNSET
LOG = structlog.get_logger()
class WorkflowParametersMixin:
    """Database operations for workflow parameters, copilot chat, task generation, actions, and runs."""

    Session: _SessionFactory
    debug_enabled: bool
@db_operation("create_workflow_parameter")
async def create_workflow_parameter(
self,
workflow_id: str,
workflow_parameter_type: WorkflowParameterType,
key: str,
default_value: Any,
description: str | None = None,
) -> WorkflowParameter:
async with self.Session() as session:
if default_value is None:
pass
elif workflow_parameter_type == WorkflowParameterType.JSON:
default_value = json.dumps(default_value)
else:
default_value = str(default_value)
workflow_parameter = WorkflowParameterModel(
workflow_id=workflow_id,
workflow_parameter_type=workflow_parameter_type,
key=key,
default_value=default_value,
description=description,
)
session.add(workflow_parameter)
await session.commit()
await session.refresh(workflow_parameter)
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
@db_operation("create_aws_secret_parameter")
async def create_aws_secret_parameter(
self,
workflow_id: str,
key: str,
aws_key: str,
description: str | None = None,
) -> AWSSecretParameter:
async with self.Session() as session:
aws_secret_parameter = AWSSecretParameterModel(
workflow_id=workflow_id,
key=key,
aws_key=aws_key,
description=description,
)
session.add(aws_secret_parameter)
await session.commit()
await session.refresh(aws_secret_parameter)
return convert_to_aws_secret_parameter(aws_secret_parameter)
@db_operation("create_output_parameter")
async def create_output_parameter(
self,
workflow_id: str,
key: str,
description: str | None = None,
) -> OutputParameter:
async with self.Session() as session:
output_parameter = OutputParameterModel(
key=key,
description=description,
workflow_id=workflow_id,
)
session.add(output_parameter)
await session.commit()
await session.refresh(output_parameter)
return convert_to_output_parameter(output_parameter)
@staticmethod
def _convert_parameter_to_model(parameter: PARAMETER_TYPE) -> Base:
"""Convert a parameter object to its corresponding SQLAlchemy model."""
if isinstance(parameter, WorkflowParameter):
if parameter.default_value is None:
default_value = None
elif parameter.workflow_parameter_type == WorkflowParameterType.JSON:
default_value = json.dumps(parameter.default_value)
else:
default_value = str(parameter.default_value)
return WorkflowParameterModel(
workflow_parameter_id=parameter.workflow_parameter_id,
workflow_parameter_type=parameter.workflow_parameter_type.value,
key=parameter.key,
description=parameter.description,
workflow_id=parameter.workflow_id,
default_value=default_value,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, OutputParameter):
return OutputParameterModel(
output_parameter_id=parameter.output_parameter_id,
key=parameter.key,
description=parameter.description,
workflow_id=parameter.workflow_id,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, AWSSecretParameter):
return AWSSecretParameterModel(
aws_secret_parameter_id=parameter.aws_secret_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
aws_key=parameter.aws_key,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, BitwardenLoginCredentialParameter):
return BitwardenLoginCredentialParameterModel(
bitwarden_login_credential_parameter_id=parameter.bitwarden_login_credential_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
bitwarden_collection_id=parameter.bitwarden_collection_id,
bitwarden_item_id=parameter.bitwarden_item_id,
url_parameter_key=parameter.url_parameter_key,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, BitwardenSensitiveInformationParameter):
return BitwardenSensitiveInformationParameterModel(
bitwarden_sensitive_information_parameter_id=parameter.bitwarden_sensitive_information_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
bitwarden_collection_id=parameter.bitwarden_collection_id,
bitwarden_identity_key=parameter.bitwarden_identity_key,
bitwarden_identity_fields=parameter.bitwarden_identity_fields,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, BitwardenCreditCardDataParameter):
return BitwardenCreditCardDataParameterModel(
bitwarden_credit_card_data_parameter_id=parameter.bitwarden_credit_card_data_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
bitwarden_collection_id=parameter.bitwarden_collection_id,
bitwarden_item_id=parameter.bitwarden_item_id,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, CredentialParameter):
return CredentialParameterModel(
credential_parameter_id=parameter.credential_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
credential_id=parameter.credential_id,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, OnePasswordCredentialParameter):
return OnePasswordCredentialParameterModel(
onepassword_credential_parameter_id=parameter.onepassword_credential_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
vault_id=parameter.vault_id,
item_id=parameter.item_id,
deleted_at=parameter.deleted_at,
)
elif isinstance(parameter, AzureVaultCredentialParameter):
return AzureVaultCredentialParameterModel(
azure_vault_credential_parameter_id=parameter.azure_vault_credential_parameter_id,
workflow_id=parameter.workflow_id,
key=parameter.key,
description=parameter.description,
vault_name=parameter.vault_name,
username_key=parameter.username_key,
password_key=parameter.password_key,
totp_secret_key=parameter.totp_secret_key,
deleted_at=parameter.deleted_at,
)
else:
raise ValueError(f"Unsupported workflow definition parameter type: {type(parameter).__name__}")
@db_operation("save_workflow_definition_parameters")
async def save_workflow_definition_parameters(self, parameters: list[PARAMETER_TYPE]) -> None:
"""Save multiple workflow definition parameters in a single transaction."""
# ContextParameter is not persisted
parameters_to_save = [p for p in parameters if not isinstance(p, ContextParameter)]
if not parameters_to_save:
return
async with self.Session() as session:
for parameter in parameters_to_save:
model = self._convert_parameter_to_model(parameter)
session.add(model)
await session.commit()
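For instance, ContextParameter instances in a mixed definition list are dropped before the single-transaction insert (the parameter objects here are placeholders):
params: list[PARAMETER_TYPE] = [wf_param, output_param, context_param]  # hypothetical objects
await db.save_workflow_definition_parameters(params)
# Only wf_param and output_param are persisted; context_param is skipped.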
@db_operation("get_workflow_output_parameters")
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
async with self.Session() as session:
output_parameters = (
await session.scalars(select(OutputParameterModel).filter_by(workflow_id=workflow_id))
).all()
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
@db_operation("get_workflow_output_parameters_by_ids")
async def get_workflow_output_parameters_by_ids(self, output_parameter_ids: list[str]) -> list[OutputParameter]:
async with self.Session() as session:
output_parameters = (
await session.scalars(
select(OutputParameterModel).filter(
OutputParameterModel.output_parameter_id.in_(output_parameter_ids)
)
)
).all()
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
@db_operation("get_workflow_parameters")
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
async with self.Session() as session:
workflow_parameters = (
await session.scalars(select(WorkflowParameterModel).filter_by(workflow_id=workflow_id))
).all()
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
@db_operation("get_workflow_parameter")
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
async with self.Session() as session:
if workflow_parameter := (
await session.scalars(
select(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id)
)
).first():
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
return None
@db_operation("create_task_generation")
async def create_task_generation(
self,
organization_id: str,
user_prompt: str,
user_prompt_hash: str,
url: str | None = None,
navigation_goal: str | None = None,
navigation_payload: dict[str, Any] | None = None,
data_extraction_goal: str | None = None,
extracted_information_schema: dict[str, Any] | None = None,
suggested_title: str | None = None,
llm: str | None = None,
llm_prompt: str | None = None,
llm_response: str | None = None,
source_task_generation_id: str | None = None,
) -> TaskGeneration:
async with self.Session() as session:
new_task_generation = TaskGenerationModel(
organization_id=organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=url,
navigation_goal=navigation_goal,
navigation_payload=navigation_payload,
data_extraction_goal=data_extraction_goal,
extracted_information_schema=extracted_information_schema,
llm=llm,
llm_prompt=llm_prompt,
llm_response=llm_response,
suggested_title=suggested_title,
source_task_generation_id=source_task_generation_id,
)
session.add(new_task_generation)
await session.commit()
await session.refresh(new_task_generation)
return TaskGeneration.model_validate(new_task_generation)
@db_operation("create_ai_suggestion")
async def create_ai_suggestion(
self,
organization_id: str,
ai_suggestion_type: str,
) -> AISuggestion:
async with self.Session() as session:
new_ai_suggestion = AISuggestionModel(
organization_id=organization_id,
ai_suggestion_type=ai_suggestion_type,
)
session.add(new_ai_suggestion)
await session.commit()
await session.refresh(new_ai_suggestion)
return AISuggestion.model_validate(new_ai_suggestion)
@db_operation("create_workflow_copilot_chat")
async def create_workflow_copilot_chat(
self,
organization_id: str,
workflow_permanent_id: str,
) -> WorkflowCopilotChat:
async with self.Session() as session:
new_chat = WorkflowCopilotChatModel(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
)
session.add(new_chat)
await session.commit()
await session.refresh(new_chat)
return WorkflowCopilotChat.model_validate(new_chat)
@db_operation("update_workflow_copilot_chat")
async def update_workflow_copilot_chat(
self,
organization_id: str,
workflow_copilot_chat_id: str,
proposed_workflow: dict | None | object = _UNSET,
auto_accept: bool | None = None,
) -> WorkflowCopilotChat | None:
async with self.Session() as session:
chat = (
await session.scalars(
select(WorkflowCopilotChatModel)
.where(WorkflowCopilotChatModel.organization_id == organization_id)
.where(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
)
).first()
if not chat:
return None
if proposed_workflow is not _UNSET:
chat.proposed_workflow = proposed_workflow
if auto_accept is not None:
chat.auto_accept = auto_accept
await session.commit()
await session.refresh(chat)
return WorkflowCopilotChat.model_validate(chat)
@db_operation("create_workflow_copilot_chat_message")
async def create_workflow_copilot_chat_message(
self,
organization_id: str,
workflow_copilot_chat_id: str,
sender: WorkflowCopilotChatSender,
content: str,
global_llm_context: str | None = None,
) -> WorkflowCopilotChatMessage:
async with self.Session() as session:
new_message = WorkflowCopilotChatMessageModel(
workflow_copilot_chat_id=workflow_copilot_chat_id,
organization_id=organization_id,
sender=sender,
content=content,
global_llm_context=global_llm_context,
)
session.add(new_message)
await session.commit()
await session.refresh(new_message)
return convert_to_workflow_copilot_chat_message(new_message, self.debug_enabled)
@db_operation("get_workflow_copilot_chat_messages")
async def get_workflow_copilot_chat_messages(
self,
workflow_copilot_chat_id: str,
) -> list[WorkflowCopilotChatMessage]:
async with self.Session() as session:
query = (
select(WorkflowCopilotChatMessageModel)
.filter(WorkflowCopilotChatMessageModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
.order_by(WorkflowCopilotChatMessageModel.workflow_copilot_chat_message_id.asc())
)
messages = (await session.scalars(query)).all()
return [convert_to_workflow_copilot_chat_message(message, self.debug_enabled) for message in messages]
@db_operation("get_workflow_copilot_chat_by_id")
async def get_workflow_copilot_chat_by_id(
self,
organization_id: str,
workflow_copilot_chat_id: str,
) -> WorkflowCopilotChat | None:
async with self.Session() as session:
query = (
select(WorkflowCopilotChatModel)
.filter(WorkflowCopilotChatModel.organization_id == organization_id)
.filter(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
.order_by(WorkflowCopilotChatModel.created_at.desc())
.limit(1)
)
chat = (await session.scalars(query)).first()
if not chat:
return None
return WorkflowCopilotChat.model_validate(chat)
@db_operation("get_latest_workflow_copilot_chat")
async def get_latest_workflow_copilot_chat(
self,
organization_id: str,
workflow_permanent_id: str,
) -> WorkflowCopilotChat | None:
async with self.Session() as session:
query = (
select(WorkflowCopilotChatModel)
.filter(WorkflowCopilotChatModel.organization_id == organization_id)
.filter(WorkflowCopilotChatModel.workflow_permanent_id == workflow_permanent_id)
.order_by(WorkflowCopilotChatModel.created_at.desc())
.limit(1)
)
chat = (await session.scalars(query)).first()
if not chat:
return None
return WorkflowCopilotChat.model_validate(chat)
@db_operation("get_task_generation_by_prompt_hash")
async def get_task_generation_by_prompt_hash(
self,
user_prompt_hash: str,
query_window_hours: int = settings.PROMPT_CACHE_WINDOW_HOURS,
) -> TaskGeneration | None:
before_time = datetime.utcnow() - timedelta(hours=query_window_hours)
async with self.Session() as session:
query = (
select(TaskGenerationModel)
.filter_by(user_prompt_hash=user_prompt_hash)
.filter(TaskGenerationModel.llm.is_not(None))
.filter(TaskGenerationModel.created_at > before_time)
)
task_generation = (await session.scalars(query)).first()
if not task_generation:
return None
return TaskGeneration.model_validate(task_generation)
@db_operation("create_action")
async def create_action(self, action: Action) -> Action:
async with self.Session() as session:
new_action = ActionModel(
action_type=action.action_type,
source_action_id=action.source_action_id,
organization_id=action.organization_id,
workflow_run_id=action.workflow_run_id,
task_id=action.task_id,
step_id=action.step_id,
step_order=action.step_order,
action_order=action.action_order,
status=action.status,
reasoning=action.reasoning,
intention=action.intention,
response=action.response,
element_id=action.element_id,
skyvern_element_hash=action.skyvern_element_hash,
skyvern_element_data=action.skyvern_element_data,
screenshot_artifact_id=action.screenshot_artifact_id,
action_json=action.model_dump(),
confidence_float=action.confidence_float,
created_by=action.created_by,
)
session.add(new_action)
await session.commit()
await session.refresh(new_action)
return hydrate_action(new_action)
@db_operation("update_action_reasoning")
async def update_action_reasoning(
self,
organization_id: str,
action_id: str,
reasoning: str,
) -> Action:
async with self.Session() as session:
action = (
await session.scalars(
select(ActionModel).filter_by(action_id=action_id).filter_by(organization_id=organization_id)
)
).first()
if action:
action.reasoning = reasoning
await session.commit()
await session.refresh(action)
return Action.model_validate(action)
raise NotFoundError(f"Action {action_id}")
@db_operation("retrieve_action_plan")
async def retrieve_action_plan(self, task: Task) -> list[Action]:
async with self.Session() as session:
subquery = (
select(TaskModel.task_id)
.filter(TaskModel.url == task.url)
.filter(TaskModel.navigation_goal == task.navigation_goal)
.filter(TaskModel.status == TaskStatus.completed)
.order_by(TaskModel.created_at.desc())
.limit(1)
.subquery()
)
query = (
select(ActionModel)
.filter(ActionModel.task_id == subquery.c.task_id)
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
)
actions = (await session.scalars(query)).all()
return [Action.model_validate(action) for action in actions]

View file

@ -1,974 +0,0 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any
import structlog
from sqlalchemy import Text, and_, cast, exists, func, literal, literal_column, or_, select, update
from sqlalchemy.dialects.postgresql import JSONB
from skyvern.exceptions import WorkflowParameterNotFound, WorkflowRunNotFound
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
TaskModel,
TaskRunModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunBlockModel,
WorkflowRunModel,
WorkflowRunOutputParameterModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.db.utils import (
convert_to_task,
convert_to_workflow_run,
convert_to_workflow_run_output_parameter,
convert_to_workflow_run_parameter,
serialize_proxy_location,
)
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter
from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRun,
WorkflowRunOutputParameter,
WorkflowRunParameter,
WorkflowRunStatus,
WorkflowRunTriggerType,
)
from skyvern.schemas.runs import ProxyLocationInput, RunType
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
from skyvern.forge.sdk.db._sentinels import _UNSET
LOG = structlog.get_logger()
class WorkflowRunsMixin:
"""Database operations for workflow runs."""
Session: _SessionFactory
engine: AsyncEngine
debug_enabled: bool
@db_operation("get_running_workflow_runs_info_globally")
async def get_running_workflow_runs_info_globally(
self,
stale_threshold_hours: int = 24,
) -> tuple[int, int]:
"""
Get information about running workflow runs across all organizations.
Used by the cleanup service to determine if cleanup should be skipped.
Args:
stale_threshold_hours: Workflow runs not updated for this many hours are considered stale.
Returns:
Tuple of (active_workflow_count, stale_workflow_count).
Active workflows are those updated within the threshold.
Stale workflows are those not updated within the threshold but still in running status.
"""
async with self.Session() as session:
running_statuses = [
WorkflowRunStatus.created,
WorkflowRunStatus.queued,
WorkflowRunStatus.running,
WorkflowRunStatus.paused,
]
stale_cutoff = datetime.utcnow() - timedelta(hours=stale_threshold_hours)
# Count active workflow runs (recently updated)
active_query = (
select(func.count())
.select_from(WorkflowRunModel)
.filter(WorkflowRunModel.status.in_(running_statuses))
.filter(WorkflowRunModel.modified_at >= stale_cutoff)
)
active_count = (await session.execute(active_query)).scalar_one()
# Count stale workflow runs (not updated for a long time)
stale_query = (
select(func.count())
.select_from(WorkflowRunModel)
.filter(WorkflowRunModel.status.in_(running_statuses))
.filter(WorkflowRunModel.modified_at < stale_cutoff)
)
stale_count = (await session.execute(stale_query)).scalar_one()
return (active_count, stale_count)
@db_operation("create_workflow_run")
async def create_workflow_run(
self,
workflow_permanent_id: str,
workflow_id: str,
organization_id: str,
browser_session_id: str | None = None,
browser_profile_id: str | None = None,
proxy_location: ProxyLocationInput = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
parent_workflow_run_id: str | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
browser_address: str | None = None,
sequential_key: str | None = None,
run_with: str | None = None,
debug_session_id: str | None = None,
ai_fallback: bool | None = None,
code_gen: bool | None = None,
workflow_run_id: str | None = None,
trigger_type: WorkflowRunTriggerType | None = None,
workflow_schedule_id: str | None = None,
) -> WorkflowRun:
async with self.Session() as session:
kwargs: dict[str, Any] = {}
if workflow_run_id is not None:
kwargs["workflow_run_id"] = workflow_run_id
workflow_run = WorkflowRunModel(
workflow_permanent_id=workflow_permanent_id,
workflow_id=workflow_id,
organization_id=organization_id,
browser_session_id=browser_session_id,
browser_profile_id=browser_profile_id,
proxy_location=serialize_proxy_location(proxy_location),
status="created",
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
parent_workflow_run_id=parent_workflow_run_id,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
browser_address=browser_address,
sequential_key=sequential_key,
run_with=run_with,
debug_session_id=debug_session_id,
ai_fallback=ai_fallback,
code_gen=code_gen,
trigger_type=trigger_type.value if trigger_type else None,
workflow_schedule_id=workflow_schedule_id,
**kwargs,
)
session.add(workflow_run)
await session.commit()
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
@db_operation("update_workflow_run")
async def update_workflow_run(
self,
workflow_run_id: str,
status: WorkflowRunStatus | None = None,
failure_reason: str | None = None,
webhook_failure_reason: str | None = None,
ai_fallback_triggered: bool | None = None,
job_id: str | None = None,
run_with: str | None = None,
sequential_key: str | None = None,
ai_fallback: bool | None = None,
depends_on_workflow_run_id: str | None = None,
browser_session_id: str | None = None,
waiting_for_verification_code: bool | None = None,
verification_code_identifier: str | None = None,
verification_code_polling_started_at: datetime | None = None,
browser_profile_id: str | None | object = _UNSET,
browser_address: str | None = None,
extra_http_headers: dict[str, str] | None = None,
failure_category: list[dict[str, Any]] | None = None,
) -> WorkflowRun:
async with self.Session() as session:
workflow_run = (
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
).first()
if workflow_run:
if status:
workflow_run.status = status
if status and status == WorkflowRunStatus.queued and workflow_run.queued_at is None:
workflow_run.queued_at = datetime.utcnow()
if status and status == WorkflowRunStatus.running and workflow_run.started_at is None:
workflow_run.started_at = datetime.utcnow()
if status and status.is_final() and workflow_run.finished_at is None:
workflow_run.finished_at = datetime.utcnow()
if failure_reason:
workflow_run.failure_reason = failure_reason
if webhook_failure_reason is not None:
workflow_run.webhook_failure_reason = webhook_failure_reason
if ai_fallback_triggered is not None:
workflow_run.script_run = {"ai_fallback_triggered": ai_fallback_triggered}
if job_id:
workflow_run.job_id = job_id
if run_with:
workflow_run.run_with = run_with
if sequential_key:
workflow_run.sequential_key = sequential_key
if ai_fallback is not None:
workflow_run.ai_fallback = ai_fallback
if depends_on_workflow_run_id:
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
if browser_session_id:
workflow_run.browser_session_id = browser_session_id
if browser_address:
workflow_run.browser_address = browser_address
if extra_http_headers:
workflow_run.extra_http_headers = extra_http_headers
# 2FA verification code waiting state updates
if waiting_for_verification_code is not None:
workflow_run.waiting_for_verification_code = waiting_for_verification_code
if verification_code_identifier is not None:
workflow_run.verification_code_identifier = verification_code_identifier
if verification_code_polling_started_at is not None:
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
if waiting_for_verification_code is not None and not waiting_for_verification_code:
# Clear related fields when waiting is set to False
workflow_run.verification_code_identifier = None
workflow_run.verification_code_polling_started_at = None
if browser_profile_id is not _UNSET:
workflow_run.browser_profile_id = browser_profile_id
if failure_category is not None:
workflow_run.failure_category = failure_category
await session.commit()
await save_workflow_run_logs(workflow_run_id)
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
else:
raise WorkflowRunNotFound(workflow_run_id)
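Note the `is None` guards above: queued_at, started_at, and finished_at are stamped only on the first transition into each state, so a retried or repeated status update leaves the original timestamps untouched.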
@db_operation("bulk_update_workflow_runs")
async def bulk_update_workflow_runs(
self,
workflow_run_ids: list[str],
status: WorkflowRunStatus | None = None,
failure_reason: str | None = None,
) -> None:
"""Bulk update workflow runs by their IDs.
Args:
workflow_run_ids: List of workflow run IDs to update
status: Optional status to set for all workflow runs
failure_reason: Optional failure reason to set for all workflow runs
"""
if not workflow_run_ids:
return
async with self.Session() as session:
update_values = {}
if status:
update_values["status"] = status.value
if failure_reason:
update_values["failure_reason"] = failure_reason
if update_values:
update_stmt = (
update(WorkflowRunModel)
.where(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
.values(**update_values)
)
await session.execute(update_stmt)
await session.commit()
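A usage sketch for failing a batch of runs in one UPDATE; `WorkflowRunStatus.failed` is assumed to exist (only created/queued/running/paused appear in this file):
await db.bulk_update_workflow_runs(
    workflow_run_ids=stale_run_ids,           # e.g. collected by a cleanup sweep
    status=WorkflowRunStatus.failed,          # assumed terminal status
    failure_reason="Timed out waiting for a runner",
)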
@db_operation("clear_workflow_run_failure_reason")
async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun:
async with self.Session() as session:
workflow_run = (
await session.scalars(
select(WorkflowRunModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(organization_id=organization_id)
)
).first()
if workflow_run:
workflow_run.failure_reason = None
await session.commit()
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run)
else:
raise NotFoundError("Workflow run not found")
@db_operation("get_all_runs")
async def get_all_runs(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
status: list[WorkflowRunStatus] | None = None,
include_debugger_runs: bool = False,
search_key: str | None = None,
) -> list[WorkflowRun | Task]:
async with self.Session() as session:
# temporary limit to 10 pages
if page > 10:
return []
limit = page * page_size
workflow_run_query = (
select(WorkflowRunModel, WorkflowModel.title)
.join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
.filter(WorkflowRunModel.organization_id == organization_id)
.filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
)
if not include_debugger_runs:
workflow_run_query = workflow_run_query.filter(WorkflowRunModel.debug_session_id.is_(None))
if search_key:
key_like = f"%{search_key}%"
# Match workflow_run_id directly
id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like)
# Match parameter key or description (only for non-deleted parameter definitions)
param_key_desc_exists = exists(
select(1)
.select_from(WorkflowRunParameterModel)
.join(
WorkflowParameterModel,
WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id,
)
.where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(WorkflowParameterModel.deleted_at.is_(None))
.where(
or_(
WorkflowParameterModel.key.ilike(key_like),
WorkflowParameterModel.description.ilike(key_like),
)
)
)
# Match run parameter value directly (searches all values regardless of parameter definition status)
param_value_exists = exists(
select(1)
.select_from(WorkflowRunParameterModel)
.where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(WorkflowRunParameterModel.value.ilike(key_like))
)
# Match extra HTTP headers (cast JSON to text for search, skip NULLs)
extra_headers_match = and_(
WorkflowRunModel.extra_http_headers.isnot(None),
func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like),
)
workflow_run_query = workflow_run_query.where(
or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match)
)
if status:
workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status))
workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit)
workflow_run_query_result = (await session.execute(workflow_run_query)).all()
workflow_runs = [
convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
for run, title in workflow_run_query_result
]
task_query = (
select(TaskModel)
.filter(TaskModel.organization_id == organization_id)
.filter(TaskModel.workflow_run_id.is_(None))
)
if status:
task_query = task_query.filter(TaskModel.status.in_(status))
task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
task_query_result = (await session.scalars(task_query)).all()
tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]
runs = workflow_runs + tasks
runs.sort(key=lambda x: x.created_at, reverse=True)
lower = (page - 1) * page_size
upper = page * page_size
return runs[lower:upper]
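Why the over-fetch: workflow runs and standalone tasks live in separate tables, so the method pulls the top page * page_size rows from each source, merges them, sorts by created_at, and slices. Worked example: page=3, page_size=10 fetches up to 30 rows per source, sorts the combined 60, and returns elements [20:30]; the hard 10-page cap above keeps that over-fetch bounded.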
@read_retry()
async def get_all_runs_v2(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
status: list[str] | None = None,
search_key: str | None = None,
) -> list[dict[str, Any]]:
async with self.Session() as session:
effective_status = func.coalesce(WorkflowRunModel.status, TaskRunModel.status)
query = (
select(
TaskRunModel.task_run_id.label("task_run_id"),
TaskRunModel.run_id.label("run_id"),
TaskRunModel.task_run_type.label("task_run_type"),
effective_status.label("status"),
TaskRunModel.title.label("title"),
TaskRunModel.started_at.label("started_at"),
TaskRunModel.finished_at.label("finished_at"),
TaskRunModel.created_at.label("created_at"),
TaskRunModel.workflow_permanent_id.label("workflow_permanent_id"),
TaskRunModel.script_run.label("script_run"),
TaskRunModel.searchable_text.label("searchable_text"),
)
.select_from(TaskRunModel)
.outerjoin(
WorkflowRunModel,
and_(
TaskRunModel.task_run_type == RunType.workflow_run,
WorkflowRunModel.workflow_run_id == TaskRunModel.run_id,
WorkflowRunModel.organization_id == TaskRunModel.organization_id,
),
)
.filter(TaskRunModel.organization_id == organization_id)
.filter(TaskRunModel.status.isnot(None))
.filter(TaskRunModel.parent_workflow_run_id.is_(None))
.filter(TaskRunModel.debug_session_id.is_(None))
)
if status:
query = query.filter(effective_status.in_(status))
if search_key:
query = query.filter(TaskRunModel.searchable_text.icontains(search_key, autoescape=True))
offset = (page - 1) * page_size
query = query.order_by(TaskRunModel.created_at.desc()).offset(offset).limit(page_size)
result = await session.execute(query)
return [dict(row) for row in result.mappings().all()]
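Note on the coalesce above: for workflow runs the joined WorkflowRunModel.status wins, presumably because the source table is fresher than the denormalized task_runs copy given that sync_task_run_status is only best-effort; for other run types the outer join yields NULL and TaskRunModel.status is used.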
@read_retry()
@db_operation("get_workflow_run", log_errors=False)
async def get_workflow_run(
self,
workflow_run_id: str,
organization_id: str | None = None,
job_id: str | None = None,
status: WorkflowRunStatus | None = None,
) -> WorkflowRun | None:
async with self.Session() as session:
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
if organization_id:
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
if job_id:
get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
if status:
get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
if workflow_run := (await session.scalars(get_workflow_run_query)).first():
return convert_to_workflow_run(workflow_run)
return None
@db_operation("get_last_queued_workflow_run")
async def get_last_queued_workflow_run(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
sequential_key: str | None = None,
) -> WorkflowRun | None:
async with self.Session() as session:
query = select(WorkflowRunModel).filter_by(workflow_permanent_id=workflow_permanent_id)
query = query.filter(WorkflowRunModel.browser_session_id.is_(None))
if organization_id:
query = query.filter_by(organization_id=organization_id)
query = query.filter_by(status=WorkflowRunStatus.queued)
if sequential_key:
query = query.filter_by(sequential_key=sequential_key)
query = query.order_by(WorkflowRunModel.modified_at.desc())
workflow_run = (await session.scalars(query)).first()
return convert_to_workflow_run(workflow_run) if workflow_run else None
@db_operation("get_workflow_runs_by_ids")
async def get_workflow_runs_by_ids(
self,
workflow_run_ids: list[str],
workflow_permanent_id: str | None = None,
organization_id: str | None = None,
) -> list[WorkflowRun]:
async with self.Session() as session:
query = select(WorkflowRunModel).filter(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
if workflow_permanent_id:
query = query.filter_by(workflow_permanent_id=workflow_permanent_id)
if organization_id:
query = query.filter_by(organization_id=organization_id)
workflow_runs = (await session.scalars(query)).all()
return [convert_to_workflow_run(workflow_run) for workflow_run in workflow_runs]
@db_operation("get_last_running_workflow_run")
async def get_last_running_workflow_run(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
sequential_key: str | None = None,
) -> WorkflowRun | None:
async with self.Session() as session:
query = select(WorkflowRunModel).filter_by(workflow_permanent_id=workflow_permanent_id)
query = query.filter(WorkflowRunModel.browser_session_id.is_(None))
if organization_id:
query = query.filter_by(organization_id=organization_id)
query = query.filter_by(status=WorkflowRunStatus.running)
if sequential_key:
query = query.filter_by(sequential_key=sequential_key)
query = query.filter(
    WorkflowRunModel.started_at.isnot(None)
)  # filter out workflow runs that do not have a started_at timestamp
query = query.order_by(WorkflowRunModel.started_at.desc())
workflow_run = (await session.scalars(query)).first()
return convert_to_workflow_run(workflow_run) if workflow_run else None
@db_operation("get_workflows_depending_on")
async def get_workflows_depending_on(
self,
workflow_run_id: str,
) -> list[WorkflowRun]:
"""
Get all workflow runs that depend on the given workflow_run_id.
Used to find workflows that should be signaled when a workflow completes,
for sequential workflow dependency handling.
Args:
workflow_run_id: The workflow_run_id to find dependents for
Returns:
List of WorkflowRun objects that have depends_on_workflow_run_id set to workflow_run_id
"""
async with self.Session() as session:
query = select(WorkflowRunModel).filter_by(depends_on_workflow_run_id=workflow_run_id)
workflow_runs = (await session.scalars(query)).all()
return [convert_to_workflow_run(workflow_run) for workflow_run in workflow_runs]
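A hedged sketch of the signaling pattern the docstring describes; `signal_workflow_run` is a hypothetical helper:
dependents = await db.get_workflows_depending_on(finished_run.workflow_run_id)
for dependent in dependents:
    # Hypothetical: wake each dependent run now that its dependency is final.
    await signal_workflow_run(dependent.workflow_run_id)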
@staticmethod
def _apply_search_key_filter(query, search_key: str | None): # type: ignore[no-untyped-def]
if not search_key:
return query
key_like = f"%{search_key}%"
# Match workflow_run_id directly
id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like)
# Match parameter key or description (only for non-deleted parameter definitions)
# Use EXISTS to avoid duplicate rows and to keep pagination correct
param_key_desc_exists = exists(
select(1)
.select_from(WorkflowRunParameterModel)
.join(
WorkflowParameterModel,
WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id,
)
.where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(WorkflowParameterModel.deleted_at.is_(None))
.where(
or_(
WorkflowParameterModel.key.ilike(key_like),
WorkflowParameterModel.description.ilike(key_like),
)
)
)
# Match run parameter value directly (searches all values regardless of parameter definition status)
param_value_exists = exists(
select(1)
.select_from(WorkflowRunParameterModel)
.where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(WorkflowRunParameterModel.value.ilike(key_like))
)
# Match extra HTTP headers (cast JSON to text for search, skip NULLs)
extra_headers_match = and_(
WorkflowRunModel.extra_http_headers.isnot(None),
func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like),
)
return query.where(or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match))
def _apply_error_code_filter(self, query, error_code: str | None): # type: ignore[no-untyped-def]
if not error_code:
return query
dialect_name = self.engine.dialect.name
if dialect_name == "sqlite":
# Task errors: array of objects like [{"error_code": "timeout", ...}]
# Use json_each to iterate + json_extract to match the error_code field
error_code_in_tasks = exists(
select(1)
.select_from(TaskModel)
.where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(
exists(
select(1)
.select_from(func.json_each(TaskModel.errors))
.where(func.json_extract(literal_column("json_each.value"), "$.error_code") == error_code)
)
)
)
# Block errors: flat array of strings like ["timeout", "network_error"]
error_code_in_blocks = exists(
select(1)
.select_from(WorkflowRunBlockModel)
.where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(
exists(
select(1)
.select_from(func.json_each(WorkflowRunBlockModel.error_codes))
.where(literal_column("json_each.value") == error_code)
)
)
)
else:
# PostgreSQL: native JSONB containment
error_code_in_tasks = exists(
select(1)
.select_from(TaskModel)
.where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(cast(TaskModel.errors, JSONB).contains(literal([{"error_code": error_code}], type_=JSONB)))
)
error_code_in_blocks = exists(
select(1)
.select_from(WorkflowRunBlockModel)
.where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
.where(cast(WorkflowRunBlockModel.error_codes, JSONB).contains(literal([error_code], type_=JSONB)))
)
return query.where(or_(error_code_in_tasks, error_code_in_blocks))
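On PostgreSQL the two `contains` calls compile to the JSONB `@>` containment operator, so the branches are roughly `errors @> '[{"error_code": "<code>"}]'` for task errors and `error_codes @> '["<code>"]'` for block errors; the SQLite path emulates the same semantics with json_each/json_extract.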
@db_operation("get_workflow_runs")
async def get_workflow_runs(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
status: list[WorkflowRunStatus] | None = None,
ordering: tuple[str, str] | None = None,
search_key: str | None = None,
error_code: str | None = None,
) -> list[WorkflowRun]:
async with self.Session() as session:
db_page = page - 1 # offset logic is 0 based
query = (
select(WorkflowRunModel, WorkflowModel.title)
.join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
.filter(WorkflowRunModel.organization_id == organization_id)
.filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
)
query = self._apply_search_key_filter(query, search_key)
query = self._apply_error_code_filter(query, error_code)
if status:
query = query.filter(WorkflowRunModel.status.in_(status))
allowed_ordering_fields = {
"created_at": WorkflowRunModel.created_at,
"status": WorkflowRunModel.status,
}
field, direction = ("created_at", "desc")
if ordering and isinstance(ordering, tuple) and len(ordering) == 2:
req_field, req_direction = ordering
if req_field in allowed_ordering_fields and req_direction in ("asc", "desc"):
field, direction = req_field, req_direction
order_column = allowed_ordering_fields[field]
if direction == "asc":
query = query.order_by(order_column.asc())
else:
query = query.order_by(order_column.desc())
query = query.limit(page_size).offset(db_page * page_size)
workflow_runs = (await session.execute(query)).all()
return [
convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
for run, title in workflow_runs
]
@db_operation("get_workflow_runs_count")
async def get_workflow_runs_count(
self,
organization_id: str,
status: list[WorkflowRunStatus] | None = None,
) -> int:
async with self.Session() as session:
count_query = (
select(func.count())
.select_from(WorkflowRunModel)
.filter(WorkflowRunModel.organization_id == organization_id)
)
if status:
count_query = count_query.filter(WorkflowRunModel.status.in_(status))
return (await session.execute(count_query)).scalar_one()
@db_operation("get_workflow_run_block_errors")
async def get_workflow_run_block_errors(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> list[tuple[list[str], str | None]]:
"""Return (error_codes, failure_reason) tuples for blocks with non-null error_codes."""
async with self.Session() as session:
query = select(WorkflowRunBlockModel.error_codes, WorkflowRunBlockModel.failure_reason).filter_by(
workflow_run_id=workflow_run_id
)
if organization_id is not None:
query = query.filter_by(organization_id=organization_id)
query = query.where(WorkflowRunBlockModel.error_codes.isnot(None))
rows = (await session.execute(query)).all()
return [(row.error_codes, row.failure_reason) for row in rows]
@db_operation("get_workflow_runs_for_workflow_permanent_id")
async def get_workflow_runs_for_workflow_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str,
page: int = 1,
page_size: int = 10,
status: list[WorkflowRunStatus] | None = None,
search_key: str | None = None,
error_code: str | None = None,
) -> list[WorkflowRun]:
"""
Get runs for a workflow. An optional `search_key` matches against the run ID, parameter keys/descriptions/values,
or extra HTTP headers; an optional `error_code` matches task or block error codes.
"""
async with self.Session() as session:
db_page = page - 1 # offset logic is 0 based
query = (
select(WorkflowRunModel, WorkflowModel.title)
.join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
.filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id)
.filter(WorkflowRunModel.organization_id == organization_id)
)
query = self._apply_search_key_filter(query, search_key)
query = self._apply_error_code_filter(query, error_code)
if status:
query = query.filter(WorkflowRunModel.status.in_(status))
query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
workflow_runs_and_titles_tuples = (await session.execute(query)).all()
workflow_runs = [
convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
for run, title in workflow_runs_and_titles_tuples
]
return workflow_runs
@db_operation("get_workflow_runs_by_parent_workflow_run_id")
async def get_workflow_runs_by_parent_workflow_run_id(
self,
parent_workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRun]:
async with self.Session() as session:
query = select(WorkflowRunModel).filter(WorkflowRunModel.parent_workflow_run_id == parent_workflow_run_id)
if organization_id is not None:
query = query.filter(WorkflowRunModel.organization_id == organization_id)
workflow_runs = (await session.scalars(query)).all()
return [convert_to_workflow_run(run) for run in workflow_runs]
@db_operation("get_workflow_run_output_parameters")
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
async with self.Session() as session:
workflow_run_output_parameters = (
await session.scalars(
select(WorkflowRunOutputParameterModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(WorkflowRunOutputParameterModel.created_at)
)
).all()
return [
convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
for parameter in workflow_run_output_parameters
]
@db_operation("get_workflow_run_output_parameter_by_id")
async def get_workflow_run_output_parameter_by_id(
self, workflow_run_id: str, output_parameter_id: str
) -> WorkflowRunOutputParameter | None:
async with self.Session() as session:
parameter = (
await session.scalars(
select(WorkflowRunOutputParameterModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(output_parameter_id=output_parameter_id)
.order_by(WorkflowRunOutputParameterModel.created_at)
)
).first()
if parameter:
return convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
return None
@db_operation("create_or_update_workflow_run_output_parameter")
async def create_or_update_workflow_run_output_parameter(
self,
workflow_run_id: str,
output_parameter_id: str,
value: dict[str, Any] | list | str | None,
) -> WorkflowRunOutputParameter:
async with self.Session() as session:
# check if the workflow run output parameter already exists
# if it does, update the value
if workflow_run_output_parameter := (
await session.scalars(
select(WorkflowRunOutputParameterModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(output_parameter_id=output_parameter_id)
)
).first():
LOG.info(
"Updating existing workflow run output parameter",
workflow_run_id=workflow_run_output_parameter.workflow_run_id,
output_parameter_id=workflow_run_output_parameter.output_parameter_id,
)
workflow_run_output_parameter.value = value
await session.commit()
await session.refresh(workflow_run_output_parameter)
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
# if it does not exist, create a new one
workflow_run_output_parameter = WorkflowRunOutputParameterModel(
workflow_run_id=workflow_run_id,
output_parameter_id=output_parameter_id,
value=value,
)
session.add(workflow_run_output_parameter)
await session.commit()
await session.refresh(workflow_run_output_parameter)
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
@db_operation("update_workflow_run_output_parameter")
async def update_workflow_run_output_parameter(
self,
workflow_run_id: str,
output_parameter_id: str,
value: dict[str, Any] | list | str | None,
) -> WorkflowRunOutputParameter:
async with self.Session() as session:
workflow_run_output_parameter = (
await session.scalars(
select(WorkflowRunOutputParameterModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(output_parameter_id=output_parameter_id)
)
).first()
if not workflow_run_output_parameter:
raise NotFoundError(
f"WorkflowRunOutputParameter not found for {workflow_run_id} and {output_parameter_id}"
)
workflow_run_output_parameter.value = value
await session.commit()
await session.refresh(workflow_run_output_parameter)
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
@db_operation("create_workflow_run_parameter")
async def create_workflow_run_parameter(
self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
) -> WorkflowRunParameter:
workflow_parameter_id = workflow_parameter.workflow_parameter_id
async with self.Session() as session:
workflow_run_parameter = WorkflowRunParameterModel(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter_id,
value=value,
)
session.add(workflow_run_parameter)
await session.flush()
converted = convert_to_workflow_run_parameter(
workflow_run_parameter, workflow_parameter, self.debug_enabled
)
await session.commit()
return converted
@db_operation("create_workflow_run_parameters")
async def create_workflow_run_parameters(
self,
workflow_run_id: str,
workflow_parameter_values: list[tuple[WorkflowParameter, Any]],
) -> list[WorkflowRunParameter]:
if not workflow_parameter_values:
return []
workflow_run_parameters = [
WorkflowRunParameterModel(
workflow_run_id=workflow_run_id,
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
value=value,
)
for workflow_parameter, value in workflow_parameter_values
]
async with self.Session() as session:
session.add_all(workflow_run_parameters)
await session.flush()
converted = [
convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
for workflow_run_parameter, (workflow_parameter, _) in zip(
workflow_run_parameters, workflow_parameter_values, strict=True
)
]
await session.commit()
return converted
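The `zip(..., strict=True)` above is the pairing guarantee: converted results come back in input order, and a length mismatch raises instead of silently mis-pairing. A hedged sketch, with hypothetical parameter objects and IDs:

async def seed_run_params(db, param_a, param_b) -> None:
    run_params = await db.create_workflow_run_parameters(
        workflow_run_id="wr_123",
        workflow_parameter_values=[(param_a, "alice"), (param_b, 7)],
    )
    # Assumes WorkflowRunParameter exposes workflow_parameter_id.
    assert [p.workflow_parameter_id for p in run_params] == [
        param_a.workflow_parameter_id,
        param_b.workflow_parameter_id,
    ]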
@db_operation("get_workflow_run_parameters")
async def get_workflow_run_parameters(
self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
async with self.Session() as session:
workflow_run_parameters = (
await session.scalars(select(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id))
).all()
results = []
for workflow_run_parameter in workflow_run_parameters:
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id) # type: ignore[attr-defined]
if not workflow_parameter:
raise WorkflowParameterNotFound(workflow_parameter_id=workflow_run_parameter.workflow_parameter_id)
results.append(
(
workflow_parameter,
convert_to_workflow_run_parameter(
workflow_run_parameter,
workflow_parameter,
self.debug_enabled,
),
)
)
return results
@db_operation("_get_last_workflow_run_by_filter")
async def _get_last_workflow_run_by_filter(
self,
organization_id: str | None = None,
**filters: str,
) -> WorkflowRun | None:
"""Get the last queued or running workflow run matching the given column filters.
Used for browser_session_id and browser_address sequential execution.
"""
async with self.Session() as session:
query = select(WorkflowRunModel).filter_by(**filters)
if organization_id:
query = query.filter_by(organization_id=organization_id)
# check if there's a queued run
queue_query = query.filter_by(status=WorkflowRunStatus.queued)
queue_query = queue_query.order_by(WorkflowRunModel.modified_at.desc())
workflow_run = (await session.scalars(queue_query)).first()
if workflow_run:
return convert_to_workflow_run(workflow_run)
# check if there's a running run
running_query = query.filter_by(status=WorkflowRunStatus.running)
running_query = running_query.filter(WorkflowRunModel.started_at.isnot(None))
running_query = running_query.order_by(WorkflowRunModel.started_at.desc())
workflow_run = (await session.scalars(running_query)).first()
if workflow_run:
return convert_to_workflow_run(workflow_run)
return None
@db_operation("get_last_workflow_run_for_browser_address")
async def get_last_workflow_run_for_browser_address(
self,
browser_address: str,
organization_id: str | None = None,
) -> WorkflowRun | None:
return await self._get_last_workflow_run_by_filter(
organization_id=organization_id,
browser_address=browser_address,
)
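The helper's docstring says the same filter also backs `browser_session_id` sequencing; a matching wrapper would presumably look like the sketch below. The method name is assumed and does not appear in this hunk:

@db_operation("get_last_workflow_run_for_browser_session_id")
async def get_last_workflow_run_for_browser_session_id(
    self,
    browser_session_id: str,
    organization_id: str | None = None,
) -> WorkflowRun | None:
    # Hypothetical mirror of get_last_workflow_run_for_browser_address.
    return await self._get_last_workflow_run_by_filter(
        organization_id=organization_id,
        browser_session_id=browser_session_id,
    )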

View file

@ -1,695 +0,0 @@
# TODO: Standardize soft-delete filtering — some methods use exclude_deleted(), others use
# inline .filter(WorkflowModel.deleted_at.is_(None)). Migrate all to exclude_deleted()
# where possible (some join/filter-order cases require inline).
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Any
import structlog
from sqlalchemy import exists, func, or_, select, update
from skyvern.constants import DEFAULT_SCRIPT_RUN_ID
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db._soft_delete import exclude_deleted
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
AWSSecretParameterModel,
AzureVaultCredentialParameterModel,
BitwardenCreditCardDataParameterModel,
BitwardenLoginCredentialParameterModel,
BitwardenSensitiveInformationParameterModel,
CredentialParameterModel,
FolderModel,
OnePasswordCredentialParameterModel,
OutputParameterModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
WorkflowTemplateModel,
)
from skyvern.forge.sdk.db.utils import convert_to_workflow, serialize_proxy_location
from skyvern.forge.sdk.workflow.models.workflow import Workflow
from skyvern.schemas.runs import ProxyLocationInput
from skyvern.schemas.workflows import WorkflowStatus
if TYPE_CHECKING:
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
LOG = structlog.get_logger()
class WorkflowsMixin:
"""Database operations for workflow management."""
Session: _SessionFactory
debug_enabled: bool
@db_operation("create_workflow")
async def create_workflow(
self,
title: str,
workflow_definition: dict[str, Any],
organization_id: str | None = None,
description: str | None = None,
proxy_location: ProxyLocationInput = None,
webhook_callback_url: str | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
persist_browser_session: bool = False,
model: dict[str, Any] | None = None,
workflow_permanent_id: str | None = None,
version: int | None = None,
is_saved_task: bool = False,
status: WorkflowStatus = WorkflowStatus.published,
run_with: str | None = None,
ai_fallback: bool = True,
cache_key: str | None = None,
adaptive_caching: bool = False,
code_version: int | None = None,
generate_script_on_terminal: bool = False,
run_sequentially: bool = False,
sequential_key: str | None = None,
folder_id: str | None = None,
) -> Workflow:
async with self.Session() as session:
workflow = WorkflowModel(
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition,
proxy_location=serialize_proxy_location(proxy_location),
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
extra_http_headers=extra_http_headers,
persist_browser_session=persist_browser_session,
model=model,
is_saved_task=is_saved_task,
status=status,
run_with=run_with,
ai_fallback=ai_fallback,
cache_key=cache_key or DEFAULT_SCRIPT_RUN_ID,
adaptive_caching=adaptive_caching,
code_version=code_version,
generate_script_on_terminal=generate_script_on_terminal,
run_sequentially=run_sequentially,
sequential_key=sequential_key,
folder_id=folder_id,
)
if workflow_permanent_id:
workflow.workflow_permanent_id = workflow_permanent_id
if version:
workflow.version = version
session.add(workflow)
# Update folder's modified_at if folder_id is provided
if folder_id:
# Validate folder exists and belongs to the same organization
folder_stmt = (
select(FolderModel)
.where(FolderModel.folder_id == folder_id)
.where(FolderModel.organization_id == organization_id)
.where(FolderModel.deleted_at.is_(None))
)
folder_model = await session.scalar(folder_stmt)
if not folder_model:
raise ValueError(
f"Folder {folder_id} not found or does not belong to organization {organization_id}"
)
folder_model.modified_at = datetime.utcnow()
await session.commit()
await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled)
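Because the folder check above runs in the same session before the commit, an invalid `folder_id` aborts the entire insert. A hedged caller sketch with hypothetical IDs:

async def create_in_folder(db) -> None:
    try:
        await db.create_workflow(
            title="Invoice sync",
            workflow_definition={"blocks": []},
            organization_id="org_1",
            folder_id="fld_missing",  # deleted or owned by another org
        )
    except ValueError:
        pass  # no workflow row was committed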
@db_operation("soft_delete_workflow_by_id")
async def soft_delete_workflow_by_id(self, workflow_id: str, organization_id: str) -> None:
async with self.Session() as session:
# soft delete the workflow by setting the deleted_at field to the current time
update_deleted_at_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_id == workflow_id)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.values(deleted_at=datetime.utcnow())
)
await session.execute(update_deleted_at_query)
await session.commit()
@db_operation("get_workflow")
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
async with self.Session() as session:
get_workflow_query = exclude_deleted(
select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
is_template = (
await self.is_workflow_template(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
)
if organization_id
else False
)
return convert_to_workflow(
workflow,
self.debug_enabled,
is_template=is_template,
)
return None
@db_operation("get_workflow_by_permanent_id")
async def get_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
version: int | None = None,
ignore_version: int | None = None,
filter_deleted: bool = True,
) -> Workflow | None:
get_workflow_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
if filter_deleted:
get_workflow_query = exclude_deleted(get_workflow_query, WorkflowModel)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if version:
get_workflow_query = get_workflow_query.filter_by(version=version)
if ignore_version:
get_workflow_query = get_workflow_query.filter(WorkflowModel.version != ignore_version)
get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
async with self.Session() as session:
if workflow := (await session.scalars(get_workflow_query)).first():
is_template = (
await self.is_workflow_template(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
)
if organization_id
else False
)
return convert_to_workflow(
workflow,
self.debug_enabled,
is_template=is_template,
)
return None
@db_operation("get_workflow_for_workflow_run")
async def get_workflow_for_workflow_run(
self,
workflow_run_id: str,
organization_id: str | None = None,
filter_deleted: bool = True,
) -> Workflow | None:
get_workflow_query = select(WorkflowModel)
if filter_deleted:
# Can't use exclude_deleted() here — it appends a WHERE clause, but the
# deleted_at filter must precede the JOIN to avoid filtering on the wrong table.
get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))
get_workflow_query = get_workflow_query.join(
WorkflowRunModel,
WorkflowRunModel.workflow_id == WorkflowModel.workflow_id,
)
if organization_id:
get_workflow_query = get_workflow_query.filter(WorkflowRunModel.organization_id == organization_id)
get_workflow_query = get_workflow_query.filter(WorkflowRunModel.workflow_run_id == workflow_run_id)
async with self.Session() as session:
if workflow := (await session.scalars(get_workflow_query)).first():
is_template = (
await self.is_workflow_template(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
)
if organization_id
else False
)
return convert_to_workflow(
workflow,
self.debug_enabled,
is_template=is_template,
)
return None
@db_operation("get_workflow_versions_by_permanent_id")
async def get_workflow_versions_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
filter_deleted: bool = True,
) -> list[Workflow]:
"""
Get all versions of a workflow by its permanent ID, ordered by version descending (newest first).
"""
get_workflows_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
if filter_deleted:
get_workflows_query = get_workflows_query.filter(WorkflowModel.deleted_at.is_(None))
if organization_id:
get_workflows_query = get_workflows_query.filter_by(organization_id=organization_id)
get_workflows_query = get_workflows_query.order_by(WorkflowModel.version.desc())
async with self.Session() as session:
workflows = (await session.scalars(get_workflows_query)).all()
template_permanent_ids: set[str] = set()
if workflows and organization_id:
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
return [
convert_to_workflow(
workflow,
self.debug_enabled,
is_template=workflow.workflow_permanent_id in template_permanent_ids,
)
for workflow in workflows
]
@db_operation("get_workflows_by_permanent_ids")
async def get_workflows_by_permanent_ids(
self,
workflow_permanent_ids: list[str],
organization_id: str | None = None,
page: int = 1,
page_size: int = 10,
title: str = "",
statuses: list[WorkflowStatus] | None = None,
) -> list[Workflow]:
"""
Get the latest version of each workflow matching the given permanent IDs.
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
db_page = page - 1
async with self.Session() as session:
subquery = (
select(
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
.where(WorkflowModel.deleted_at.is_(None))
.group_by(
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
main_query = select(WorkflowModel).join(
subquery,
(WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
if organization_id:
main_query = main_query.where(WorkflowModel.organization_id == organization_id)
if title:
main_query = main_query.where(WorkflowModel.title.ilike(f"%{title}%"))
if statuses:
main_query = main_query.where(WorkflowModel.status.in_(statuses))
main_query = (
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
)
workflows = (await session.scalars(main_query)).all()
# Map template status by permanent_id so API responses surface is_template
template_permanent_ids: set[str] = set()
if workflows and organization_id:
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
return [
convert_to_workflow(
workflow,
self.debug_enabled,
is_template=workflow.workflow_permanent_id in template_permanent_ids,
)
for workflow in workflows
]
@db_operation("get_workflows_by_organization_id")
async def get_workflows_by_organization_id(
self,
organization_id: str,
page: int = 1,
page_size: int = 10,
only_saved_tasks: bool = False,
only_workflows: bool = False,
only_templates: bool = False,
search_key: str | None = None,
folder_id: str | None = None,
statuses: list[WorkflowStatus] | None = None,
) -> list[Workflow]:
"""
Get all workflows with the latest version for the organization.
Search semantics:
- If `search_key` is provided, its value is used as a unified search term for
`workflows.title`, `folders.title`, and workflow parameter metadata (key, description, and default_value).
- If `search_key` is not provided, no search filtering is applied.
- Parameter metadata search excludes soft-deleted parameter rows across parameter tables.
"""
if page < 1:
raise ValueError(f"Page must be greater than 0, got {page}")
db_page = page - 1
async with self.Session() as session:
subquery = (
select(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
func.max(WorkflowModel.version).label("max_version"),
)
.where(WorkflowModel.organization_id == organization_id)
.where(WorkflowModel.deleted_at.is_(None))
.group_by(
WorkflowModel.organization_id,
WorkflowModel.workflow_permanent_id,
)
.subquery()
)
main_query = (
select(WorkflowModel)
.join(
subquery,
(WorkflowModel.organization_id == subquery.c.organization_id)
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
& (WorkflowModel.version == subquery.c.max_version),
)
.outerjoin(
FolderModel,
(WorkflowModel.folder_id == FolderModel.folder_id)
& (FolderModel.organization_id == WorkflowModel.organization_id),
)
)
if only_saved_tasks:
main_query = main_query.where(WorkflowModel.is_saved_task.is_(True))
elif only_workflows:
main_query = main_query.where(WorkflowModel.is_saved_task.is_(False))
if only_templates:
# Filter by workflow_templates table (templates at permanent_id level)
template_subquery = select(WorkflowTemplateModel.workflow_permanent_id).where(
WorkflowTemplateModel.organization_id == organization_id,
WorkflowTemplateModel.deleted_at.is_(None),
)
main_query = main_query.where(WorkflowModel.workflow_permanent_id.in_(template_subquery))
if statuses:
main_query = main_query.where(WorkflowModel.status.in_(statuses))
if folder_id:
main_query = main_query.where(WorkflowModel.folder_id == folder_id)
if search_key:
search_like = f"%{search_key}%"
title_like = WorkflowModel.title.ilike(search_like)
folder_title_like = FolderModel.title.ilike(search_like)
parameter_filters = [
# WorkflowParameterModel
exists(
select(1)
.select_from(WorkflowParameterModel)
.where(WorkflowParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(WorkflowParameterModel.deleted_at.is_(None))
.where(
or_(
WorkflowParameterModel.key.ilike(search_like),
WorkflowParameterModel.description.ilike(search_like),
WorkflowParameterModel.default_value.ilike(search_like),
)
)
),
# OutputParameterModel
exists(
select(1)
.select_from(OutputParameterModel)
.where(OutputParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(OutputParameterModel.deleted_at.is_(None))
.where(
or_(
OutputParameterModel.key.ilike(search_like),
OutputParameterModel.description.ilike(search_like),
)
)
),
# AWSSecretParameterModel
exists(
select(1)
.select_from(AWSSecretParameterModel)
.where(AWSSecretParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(AWSSecretParameterModel.deleted_at.is_(None))
.where(
or_(
AWSSecretParameterModel.key.ilike(search_like),
AWSSecretParameterModel.description.ilike(search_like),
)
)
),
# BitwardenLoginCredentialParameterModel
exists(
select(1)
.select_from(BitwardenLoginCredentialParameterModel)
.where(BitwardenLoginCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(BitwardenLoginCredentialParameterModel.deleted_at.is_(None))
.where(
or_(
BitwardenLoginCredentialParameterModel.key.ilike(search_like),
BitwardenLoginCredentialParameterModel.description.ilike(search_like),
)
)
),
# BitwardenSensitiveInformationParameterModel
exists(
select(1)
.select_from(BitwardenSensitiveInformationParameterModel)
.where(BitwardenSensitiveInformationParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(BitwardenSensitiveInformationParameterModel.deleted_at.is_(None))
.where(
or_(
BitwardenSensitiveInformationParameterModel.key.ilike(search_like),
BitwardenSensitiveInformationParameterModel.description.ilike(search_like),
)
)
),
# BitwardenCreditCardDataParameterModel
exists(
select(1)
.select_from(BitwardenCreditCardDataParameterModel)
.where(BitwardenCreditCardDataParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(BitwardenCreditCardDataParameterModel.deleted_at.is_(None))
.where(
or_(
BitwardenCreditCardDataParameterModel.key.ilike(search_like),
BitwardenCreditCardDataParameterModel.description.ilike(search_like),
)
)
),
# OnePasswordCredentialParameterModel
exists(
select(1)
.select_from(OnePasswordCredentialParameterModel)
.where(OnePasswordCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(OnePasswordCredentialParameterModel.deleted_at.is_(None))
.where(
or_(
OnePasswordCredentialParameterModel.key.ilike(search_like),
OnePasswordCredentialParameterModel.description.ilike(search_like),
)
)
),
# AzureVaultCredentialParameterModel
exists(
select(1)
.select_from(AzureVaultCredentialParameterModel)
.where(AzureVaultCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(AzureVaultCredentialParameterModel.deleted_at.is_(None))
.where(
or_(
AzureVaultCredentialParameterModel.key.ilike(search_like),
AzureVaultCredentialParameterModel.description.ilike(search_like),
)
)
),
# CredentialParameterModel
exists(
select(1)
.select_from(CredentialParameterModel)
.where(CredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
.where(CredentialParameterModel.deleted_at.is_(None))
.where(
or_(
CredentialParameterModel.key.ilike(search_like),
CredentialParameterModel.description.ilike(search_like),
)
)
),
]
main_query = main_query.where(or_(title_like, folder_title_like, or_(*parameter_filters)))
main_query = (
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
)
workflows = (await session.scalars(main_query)).all()
template_permanent_ids: set[str] = set()
if workflows and organization_id:
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
return [
convert_to_workflow(
workflow,
self.debug_enabled,
is_template=workflow.workflow_permanent_id in template_permanent_ids,
)
for workflow in workflows
]
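A sketch of the unified search documented above, with hypothetical values: a single `search_key` term can match on workflow title, folder title, or parameter metadata in one call.

async def find_invoice_workflows(db) -> list[Workflow]:
    # "invoice" may hit a workflow titled "Invoice sync", a folder titled
    # "Invoices", or a workflow parameter keyed "invoice_number".
    return await db.get_workflows_by_organization_id(
        organization_id="org_1",
        search_key="invoice",
        statuses=[WorkflowStatus.published],
    )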
@db_operation("update_workflow")
async def update_workflow(
self,
workflow_id: str,
organization_id: str | None = None,
title: str | None = None,
description: str | None = None,
workflow_definition: dict[str, Any] | None = None,
version: int | None = None,
run_with: str | None = None,
cache_key: str | None = None,
status: str | None = None,
import_error: str | None = None,
) -> Workflow:
async with self.Session() as session:
get_workflow_query = exclude_deleted(
select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel
)
if organization_id:
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
if title is not None:
workflow.title = title
if description is not None:
workflow.description = description
if workflow_definition is not None:
workflow.workflow_definition = workflow_definition
if version is not None:
workflow.version = version
if run_with is not None:
workflow.run_with = run_with
if cache_key is not None:
workflow.cache_key = cache_key
if status is not None:
workflow.status = status
if import_error is not None:
workflow.import_error = import_error
await session.commit()
await session.refresh(workflow)
is_template = (
await self.is_workflow_template(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
)
if organization_id
else False
)
return convert_to_workflow(
workflow,
self.debug_enabled,
is_template=is_template,
)
else:
raise NotFoundError("Workflow not found")
@db_operation("soft_delete_workflow_by_permanent_id")
async def soft_delete_workflow_by_permanent_id(
self,
workflow_permanent_id: str,
organization_id: str | None = None,
) -> None:
async with self.Session() as session:
# soft delete the workflow by setting the deleted_at field
update_deleted_at_query = (
update(WorkflowModel)
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowModel.deleted_at.is_(None))
)
if organization_id:
update_deleted_at_query = update_deleted_at_query.filter_by(organization_id=organization_id)
update_deleted_at_query = update_deleted_at_query.values(deleted_at=datetime.utcnow())
await session.execute(update_deleted_at_query)
await session.commit()
@db_operation("add_workflow_template")
async def add_workflow_template(
self,
workflow_permanent_id: str,
organization_id: str,
) -> None:
"""Add a workflow to the templates table."""
async with self.Session() as session:
existing = (
await session.scalars(
select(WorkflowTemplateModel)
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowTemplateModel.organization_id == organization_id)
)
).first()
if existing:
if existing.deleted_at is not None:
existing.deleted_at = None
await session.commit()
return
template = WorkflowTemplateModel(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
)
session.add(template)
await session.commit()
@db_operation("remove_workflow_template")
async def remove_workflow_template(
self,
workflow_permanent_id: str,
organization_id: str,
) -> None:
"""Soft delete a workflow from the templates table."""
async with self.Session() as session:
update_deleted_at_query = (
update(WorkflowTemplateModel)
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowTemplateModel.organization_id == organization_id)
.where(WorkflowTemplateModel.deleted_at.is_(None))
.values(deleted_at=datetime.utcnow())
)
await session.execute(update_deleted_at_query)
await session.commit()
@db_operation("get_org_template_permanent_ids")
async def get_org_template_permanent_ids(
self,
organization_id: str,
) -> set[str]:
"""Get all workflow_permanent_ids that are templates for an organization."""
async with self.Session() as session:
result = await session.scalars(
select(WorkflowTemplateModel.workflow_permanent_id)
.where(WorkflowTemplateModel.organization_id == organization_id)
.where(WorkflowTemplateModel.deleted_at.is_(None))
)
return set(result.all())
@db_operation("is_workflow_template")
async def is_workflow_template(
self,
workflow_permanent_id: str,
organization_id: str,
) -> bool:
"""Check if a workflow is marked as a template."""
async with self.Session() as session:
result = (
await session.scalars(
select(WorkflowTemplateModel)
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
.where(WorkflowTemplateModel.organization_id == organization_id)
.where(WorkflowTemplateModel.deleted_at.is_(None))
)
).first()
return result is not None
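The list methods above call `get_org_template_permanent_ids()` once per page and test set membership per row, trading one extra query for avoiding a per-workflow `is_workflow_template()` round trip. A minimal sketch of that pattern:

async def annotate_templates(db, workflows) -> list[bool]:
    template_ids = await db.get_org_template_permanent_ids("org_1")  # org ID hypothetical
    return [w.workflow_permanent_id in template_ids for w in workflows]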

View file

@ -29,7 +29,9 @@ async def await_browser_session(
try:
async with asyncio.timeout(timeout):
while True:
persistent_browser_session = await db.get_persistent_browser_session(session_id, organization_id)
persistent_browser_session = await db.browser_sessions.get_persistent_browser_session(
session_id, organization_id
)
if persistent_browser_session is None:
raise Exception(f"Persistent browser session not found for {session_id}")

View file

@ -13,12 +13,10 @@ from skyvern.forge.sdk.db.base_repository import BaseRepository
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
BrowserProfileModel,
DebugSessionModel,
PersistentBrowserSessionModel,
)
from skyvern.forge.sdk.db.utils import serialize_proxy_location
from skyvern.forge.sdk.schemas.browser_profiles import BrowserProfile
from skyvern.forge.sdk.schemas.debug_sessions import DebugSession
from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
Extensions,
PersistentBrowserSession,
@ -457,19 +455,3 @@ class BrowserSessionsRepository(BaseRepository):
select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
)
return result.scalars().all()
@db_operation("get_debug_session_by_browser_session_id")
async def get_debug_session_by_browser_session_id(
self,
browser_session_id: str,
organization_id: str,
) -> DebugSession | None:
async with self.Session() as session:
query = (
select(DebugSessionModel)
.filter_by(browser_session_id=browser_session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
)
model = (await session.scalars(query)).first()
return DebugSession.model_validate(model) if model else None

View file

@ -111,27 +111,20 @@ class DebugRepository(BaseRepository):
await session.commit()
@db_operation("get_latest_debug_session_for_user")
async def get_latest_debug_session_for_user(
@db_operation("get_debug_session_by_browser_session_id")
async def get_debug_session_by_browser_session_id(
self,
*,
browser_session_id: str,
organization_id: str,
user_id: str,
workflow_permanent_id: str,
) -> DebugSession | None:
async with self.Session() as session:
query = (
select(DebugSessionModel)
.filter_by(browser_session_id=browser_session_id)
.filter_by(organization_id=organization_id)
.filter_by(deleted_at=None)
.filter_by(status="created")
.filter_by(user_id=user_id)
.filter_by(workflow_permanent_id=workflow_permanent_id)
.order_by(DebugSessionModel.created_at.desc())
)
model = (await session.scalars(query)).first()
return DebugSession.model_validate(model) if model else None
@db_operation("get_debug_session_by_id")

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime, timedelta, timezone
import structlog
from sqlalchemy import and_, delete, distinct, func, select, update
from sqlalchemy import and_, delete, distinct, func, or_, select, update
from sqlalchemy.dialects.postgresql import insert
from skyvern.forge.sdk.db._error_handling import db_operation
@ -699,12 +699,18 @@ class ScriptsRepository(BaseRepository):
page_size: int = 50,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> tuple[list[WorkflowRunModel], int, dict[str, int]]:
"""Get workflow runs associated with a script, with total count and status counts.
) -> tuple[list[WorkflowRunModel], int, dict[str, int], float | None]:
"""Get workflow runs associated with a script, with total count, status counts,
and average AI fallbacks per run.
Returns (runs, total_count, status_counts) where runs is limited by page_size,
total_count is derived from the status_counts GROUP BY, and status_counts is a
GROUP BY aggregation of statuses across all runs.
Includes actual script runs (run_with='code', 'code_v2', or NULL), excluding
explicit agent runs. run_with is NULL when the workflow ran in auto mode and
should_run_script() resolved to code via fallback (e.g. code_version >= 1).
Returns (runs, total_count, status_counts, avg_fallbacks_per_run) where runs
is limited by page_size, total_count is derived from the status_counts GROUP BY,
status_counts is a GROUP BY aggregation of statuses across all runs, and
avg_fallbacks_per_run is the average number of fallback episodes per run.
If created_after/created_before are provided, filters by the workflow_script
entry's created_at (not the run's created_at), scoping to the version that
@ -730,10 +736,19 @@ class ScriptsRepository(BaseRepository):
WorkflowScriptModel.created_at < created_before,
)
# Base filter for workflow runs
# Base filter for workflow runs - only include actual script runs.
# run_with may be NULL when the workflow ran in auto mode and
# should_run_script() resolved to code mode via fallback (e.g.
# code_version >= 1 or adaptive_caching). NULL is therefore
# treated as a code run here; explicit "agent" runs are excluded
# by not appearing in the workflow_scripts join above.
base_filters = [
WorkflowRunModel.workflow_run_id.in_(run_ids_subquery),
WorkflowRunModel.organization_id == organization_id,
or_(
WorkflowRunModel.run_with.in_(["code", "code_v2"]),
WorkflowRunModel.run_with.is_(None),
),
]
# Count statuses via GROUP BY (also gives us total_count)
@ -744,7 +759,7 @@ class ScriptsRepository(BaseRepository):
total_count = sum(status_counts.values())
if total_count == 0:
return [], 0, {}
return [], 0, {}, None
# Get the actual workflow runs (paginated)
runs_query = (
@ -755,7 +770,27 @@ class ScriptsRepository(BaseRepository):
)
runs = list((await session.scalars(runs_query)).all())
return runs, total_count, status_counts
# Compute average AI fallbacks per run over the last 20 runs.
max_fallback_sample = 20
recent_run_ids = (
select(WorkflowRunModel.workflow_run_id)
.filter(*base_filters)
.order_by(WorkflowRunModel.created_at.desc())
.limit(max_fallback_sample)
)
total_fallbacks_result = await session.execute(
select(func.count())
.select_from(ScriptFallbackEpisodeModel)
.filter(
ScriptFallbackEpisodeModel.workflow_run_id.in_(recent_run_ids),
ScriptFallbackEpisodeModel.organization_id == organization_id,
)
)
total_fallbacks = total_fallbacks_result.scalar() or 0
sample_size = min(total_count, max_fallback_sample)
avg_fallbacks_per_run = round(total_fallbacks / sample_size, 2)
return runs, total_count, status_counts, avg_fallbacks_per_run
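A worked example of the sampling above, with hypothetical counts; the average is taken over at most the 20 most recent qualifying runs:

max_fallback_sample = 20
sample_size = min(57, max_fallback_sample)           # 57 qualifying runs -> 20
avg_fallbacks_per_run = round(9 / sample_size, 2)    # 9 episodes -> 0.45
small_sample = min(3, max_fallback_sample)           # only 3 runs total -> 3
print(round(2 / small_sample, 2))                    # 2 episodes -> 0.67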
@db_operation("get_script_run_stats")
async def get_script_run_stats(

View file

@ -36,14 +36,14 @@ class BackgroundTaskExecutor(AsyncExecutor):
if organization is None:
raise OrganizationNotFound(organization_id)
step = await app.DATABASE.create_step(
step = await app.DATABASE.tasks.create_step(
task_id,
order=0,
retry_index=0,
organization_id=organization_id,
)
task = await app.DATABASE.update_task(
task = await app.DATABASE.tasks.update_task(
task_id,
status=TaskStatus.running,
organization_id=organization_id,
@ -51,7 +51,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
close_browser_on_completion = browser_session_id is None and not task.browser_address
run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id)
run_obj = await app.DATABASE.tasks.get_run(run_id=task_id, organization_id=organization_id)
engine = RunEngine.skyvern_v1
if run_obj and run_obj.task_run_type == RunType.openai_cua:
engine = RunEngine.openai_cua
@ -136,17 +136,17 @@ class BackgroundTaskExecutor(AsyncExecutor):
if organization is None:
raise OrganizationNotFound(organization_id)
task_v2 = await app.DATABASE.get_task_v2(task_v2_id=task_v2_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(task_v2_id=task_v2_id, organization_id=organization_id)
if not task_v2 or not task_v2.workflow_run_id:
raise ValueError("No task v2 or no workflow run associated with task v2")
# mark task v2 as queued
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2_id,
status=TaskV2Status.queued,
organization_id=organization_id,
)
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=task_v2.workflow_run_id,
status=WorkflowRunStatus.queued,
)
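Every call-site hunk in this section follows the same shape: the flat delegate on `app.DATABASE` becomes an attribute access on a typed repository, with identical arguments. Side by side, as a sketch:

# before: flat delegate
task = await app.DATABASE.update_task(
    task_id, status=TaskStatus.running, organization_id=organization_id
)
# after: typed repository attribute
task = await app.DATABASE.tasks.update_task(
    task_id, status=TaskStatus.running, organization_id=organization_id
)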

View file

@ -127,7 +127,7 @@ async def _save_log_artifacts(
return
log_json = json.dumps(log, cls=SkyvernJSONLogEncoder, indent=2)
log_artifact = await app.DATABASE.get_artifact_by_entity_id(
log_artifact = await app.DATABASE.artifacts.get_artifact_by_entity_id(
artifact_type=ArtifactType.SKYVERN_LOG_RAW,
step_id=step_id,
task_id=task_id,
@ -158,7 +158,7 @@ async def _save_log_artifacts(
formatted_log = SkyvernLogEncoder.encode(log)
formatted_log_artifact = await app.DATABASE.get_artifact_by_entity_id(
formatted_log_artifact = await app.DATABASE.artifacts.get_artifact_by_entity_id(
artifact_type=ArtifactType.SKYVERN_LOG,
step_id=step_id,
task_id=task_id,

View file

@ -304,7 +304,7 @@ async def run_task(
max_steps_override=run_request.max_steps,
browser_session_id=run_request.browser_session_id,
)
refreshed_task_v2 = await app.DATABASE.get_task_v2(
refreshed_task_v2 = await app.DATABASE.observer.get_task_v2(
task_v2_id=task_v2.observer_cruise_id, organization_id=current_org.organization_id
)
task_v2 = refreshed_task_v2 if refreshed_task_v2 else task_v2
@ -1416,7 +1416,7 @@ async def get_artifact(
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Artifact:
analytics.capture("skyvern-oss-artifact-get")
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id=artifact_id,
organization_id=current_org.organization_id,
)
@ -1488,14 +1488,14 @@ async def get_artifact_content(
status_code=http_status.HTTP_403_FORBIDDEN,
detail="Invalid or expired artifact URL",
)
artifact = await app.DATABASE.get_artifact_by_id_no_org(artifact_id=artifact_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id_no_org(artifact_id=artifact_id)
else:
# Standard org-auth path (existing behaviour).
current_org = await org_auth_service.get_current_org(
x_api_key=x_api_key,
authorization=authorization,
)
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id=artifact_id,
organization_id=current_org.organization_id,
)
@ -1542,7 +1542,7 @@ async def get_run_artifacts(
) -> Response:
analytics.capture("skyvern-oss-run-artifacts-get")
# Get artifacts as a list (not grouped by type)
artifacts = await app.DATABASE.get_artifacts_for_run(
artifacts = await app.DATABASE.artifacts.get_artifacts_for_run(
run_id=run_id,
organization_id=current_org.organization_id,
artifact_types=artifact_type,
@ -1642,7 +1642,9 @@ async def get_run_timeline(
# Handle task_v2 runs by getting their associated workflow_run_id
if run_response.run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(task_v2_id=run_id, organization_id=current_org.organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(
task_v2_id=run_id, organization_id=current_org.organization_id
)
if not task_v2:
raise HTTPException(
status_code=http_status.HTTP_404_NOT_FOUND,
@ -1902,7 +1904,7 @@ async def cancel_task(
x_api_key: Annotated[str | None, Header()] = None,
) -> None:
analytics.capture("skyvern-oss-agent-task-get")
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
task_obj = await app.DATABASE.tasks.get_task(task_id, organization_id=current_org.organization_id)
if not task_obj:
raise HTTPException(
status_code=http_status.HTTP_404_NOT_FOUND,
@ -1914,7 +1916,7 @@ async def cancel_task(
async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api_key: str | None = None) -> None:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -1929,7 +1931,7 @@ async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api
await app.PERSISTENT_SESSIONS_MANAGER.release_browser_session(workflow_run.browser_session_id, organization_id)
# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
child_workflow_runs = await app.DATABASE.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(
parent_workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -1949,7 +1951,7 @@ async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api
async def _continue_workflow_run(workflow_run_id: str, organization_id: str) -> None:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
status=WorkflowRunStatus.paused,
@ -2028,7 +2030,7 @@ async def retry_webhook(
x_api_key: Annotated[str | None, Header()] = None,
) -> TaskResponse:
analytics.capture("skyvern-oss-agent-task-retry-webhook")
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
task_obj = await app.DATABASE.tasks.get_task(task_id, organization_id=current_org.organization_id)
if not task_obj:
raise HTTPException(
status_code=http_status.HTTP_404_NOT_FOUND,
@ -2036,7 +2038,7 @@ async def retry_webhook(
)
# get latest step
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=current_org.organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(task_id, organization_id=current_org.organization_id)
if not latest_step:
return await app.agent.build_task_response(task=task_obj)
@ -2088,7 +2090,7 @@ async def get_tasks(
status_code=http_status.HTTP_400_BAD_REQUEST,
detail="only_standalone_tasks and workflow_run_id cannot be used together",
)
tasks = await app.DATABASE.get_tasks(
tasks = await app.DATABASE.tasks.get_tasks(
page,
page_size,
task_status=task_status,
@ -2137,7 +2139,7 @@ async def get_runs(
if page > 10:
return []
runs = await app.DATABASE.get_all_runs(
runs = await app.DATABASE.workflow_runs.get_all_runs(
current_org.organization_id, page=page, page_size=page_size, status=status, search_key=search_key
)
return ORJSONResponse([run.model_dump() for run in runs])
@ -2174,7 +2176,7 @@ async def get_runs_v2(
) -> Response:
analytics.capture("skyvern-oss-agent-runs-v2-get")
rows = await app.DATABASE.get_all_runs_v2(
rows = await app.DATABASE.workflow_runs.get_all_runs_v2(
current_org.organization_id,
page=page,
page_size=page_size,
@ -2208,7 +2210,7 @@ async def get_steps(
:return: List of steps for a task with pagination.
"""
analytics.capture("skyvern-oss-agent-task-steps-get")
steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
steps = await app.DATABASE.tasks.get_task_steps(task_id, organization_id=current_org.organization_id)
return ORJSONResponse([step.model_dump(exclude_none=True) for step in steps])
@ -2255,7 +2257,10 @@ async def get_artifacts(
params = {
entity_type_to_param[entity_type]: entity_id,
}
artifacts = await app.DATABASE.get_artifacts_by_entity_id(organization_id=current_org.organization_id, **params) # type: ignore
artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=current_org.organization_id,
**params, # type: ignore[arg-type]
)
signed_urls = await app.ARTIFACT_MANAGER.get_share_links_with_bundle_support(artifacts)
for i, artifact in enumerate(artifacts):
@ -2289,7 +2294,7 @@ async def get_step_artifacts(
:return: List of artifacts for a list of steps.
"""
analytics.capture("skyvern-oss-agent-task-step-artifacts-get")
artifacts = await app.DATABASE.get_artifacts_for_task_step(
artifacts = await app.DATABASE.artifacts.get_artifacts_for_task_step(
task_id,
step_id,
organization_id=current_org.organization_id,
@ -2318,7 +2323,7 @@ async def get_actions(
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[Action]:
analytics.capture("skyvern-oss-agent-task-actions-get")
actions = await app.DATABASE.get_task_actions(task_id, organization_id=current_org.organization_id)
actions = await app.DATABASE.tasks.get_task_actions(task_id, organization_id=current_org.organization_id)
return actions
@ -2602,7 +2607,7 @@ async def get_workflow_run_with_workflow_id(
)
return_dict = workflow_run_status_response.model_dump(by_alias=True)
browser_session = await app.DATABASE.get_persistent_browser_session_by_runnable_id(
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session_by_runnable_id(
runnable_id=workflow_run_id,
organization_id=current_org.organization_id,
)
@ -2966,7 +2971,7 @@ async def suggest(
)
try:
new_ai_suggestion = await app.DATABASE.create_ai_suggestion(
new_ai_suggestion = await app.DATABASE.workflow_params.create_ai_suggestion(
organization_id=current_org.organization_id,
ai_suggestion_type=ai_suggestion_type,
)
@ -3239,7 +3244,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
"""
# get task v2 by workflow run id
task_v2_obj = await app.DATABASE.get_task_v2_by_workflow_run_id(
task_v2_obj = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

View file

@ -143,7 +143,7 @@ async def list_browser_profiles(
include_deleted=include_deleted,
)
profiles = await app.DATABASE.list_browser_profiles(
profiles = await app.DATABASE.browser_sessions.list_browser_profiles(
organization_id=organization_id,
include_deleted=include_deleted,
)
@ -199,7 +199,7 @@ async def get_browser_profile(
browser_profile_id=profile_id,
)
profile = await app.DATABASE.get_browser_profile(
profile = await app.DATABASE.browser_sessions.get_browser_profile(
profile_id=profile_id,
organization_id=organization_id,
)
@ -264,7 +264,7 @@ async def delete_browser_profile(
)
try:
await app.DATABASE.delete_browser_profile(
await app.DATABASE.browser_sessions.delete_browser_profile(
profile_id=profile_id,
organization_id=organization_id,
)
@ -290,7 +290,9 @@ async def _create_profile_from_session(
description: str | None,
browser_session_id: str,
) -> BrowserProfile:
browser_session = await app.DATABASE.get_persistent_browser_session(browser_session_id, organization_id)
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
browser_session_id, organization_id
)
if browser_session is None:
LOG.warning(
"Browser session not found for profile creation",
@ -318,7 +320,7 @@ async def _create_profile_from_session(
)
try:
profile = await app.DATABASE.create_browser_profile(
profile = await app.DATABASE.browser_sessions.create_browser_profile(
organization_id=organization_id,
name=name,
description=description,
@ -334,7 +336,9 @@ async def _create_profile_from_session(
)
except Exception:
# Rollback: delete the profile if storage fails
await app.DATABASE.delete_browser_profile(profile.browser_profile_id, organization_id=organization_id)
await app.DATABASE.browser_sessions.delete_browser_profile(
profile.browser_profile_id, organization_id=organization_id
)
LOG.error(
"Failed to store browser profile artifacts, rolled back profile creation",
organization_id=organization_id,
@ -359,7 +363,7 @@ async def _create_profile_from_workflow_run(
description: str | None,
workflow_run_id: str,
) -> BrowserProfile:
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id=organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id, organization_id=organization_id)
if not workflow_run:
LOG.warning(
"Workflow run not found for profile creation",
@ -421,7 +425,7 @@ async def _create_profile_from_workflow_run(
)
try:
profile = await app.DATABASE.create_browser_profile(
profile = await app.DATABASE.browser_sessions.create_browser_profile(
organization_id=organization_id,
name=name,
description=description,
@ -443,7 +447,9 @@ async def _create_profile_from_workflow_run(
)
except Exception:
# Rollback: delete the profile if storage fails
await app.DATABASE.delete_browser_profile(profile.browser_profile_id, organization_id=organization_id)
await app.DATABASE.browser_sessions.delete_browser_profile(
profile.browser_profile_id, organization_id=organization_id
)
LOG.error(
"Failed to store browser profile artifacts, rolled back profile creation",
organization_id=organization_id,

View file

@ -42,7 +42,7 @@ async def get_browser_sessions_all(
"""Get all browser sessions for the organization"""
analytics.capture("skyvern-oss-agent-browser-sessions-get-all")
browser_sessions = await app.DATABASE.get_persistent_browser_sessions_history(
browser_sessions = await app.DATABASE.browser_sessions.get_persistent_browser_sessions_history(
current_org.organization_id,
page=page,
page_size=page_size,
@ -91,7 +91,7 @@ async def create_browser_session(
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> BrowserSessionResponse:
if browser_session_request.browser_profile_id:
profile = await app.DATABASE.get_browser_profile(
profile = await app.DATABASE.browser_sessions.get_browser_profile(
browser_session_request.browser_profile_id,
current_org.organization_id,
)

View file

@ -163,7 +163,7 @@ async def send_totp_code(
)
# validate task_id, workflow_id, workflow_run_id are valid ids in db if provided
if data.task_id:
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
task = await app.DATABASE.tasks.get_task(data.task_id, curr_org.organization_id)
if not task:
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
if data.workflow_id:
@ -171,7 +171,7 @@ async def send_totp_code(
if not workflow:
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
if data.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
if not workflow_run:
raise HTTPException(status_code=400, detail=f"Invalid workflow run id: {data.workflow_run_id}")
content = data.content.strip()
@ -679,7 +679,7 @@ async def test_credential(
# Check if the credential already has a browser profile
existing_browser_profile_id = credential.browser_profile_id
if existing_browser_profile_id:
profile = await app.DATABASE.get_browser_profile(
profile = await app.DATABASE.browser_sessions.get_browser_profile(
profile_id=existing_browser_profile_id,
organization_id=organization_id,
)
@ -891,7 +891,9 @@ async def get_test_credential_status(
) -> TestCredentialStatusResponse:
organization_id = current_org.organization_id
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:
raise HTTPException(status_code=404, detail=f"Workflow run {workflow_run_id} not found")
@ -1054,7 +1056,7 @@ async def _create_browser_profile_after_workflow(
try:
for _ in range(max_polls):
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:
@ -1130,7 +1132,7 @@ async def _create_browser_profile_after_workflow(
# Create the browser profile in DB
profile_name = f"Profile - {credential_name} ({credential_id})"
profile = await app.DATABASE.create_browser_profile(
profile = await app.DATABASE.browser_sessions.create_browser_profile(
organization_id=organization_id,
name=profile_name,
description=f"Browser profile from credential test for {credential_name}",

View file

@ -71,7 +71,7 @@ async def get_or_create_debug_session_by_user_and_workflow_permanent_id(
)
# Skip renewal for sessions that haven't started yet (browser still launching)
session = await app.DATABASE.get_persistent_browser_session(
session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
debug_session.browser_session_id,
current_org.organization_id,
)
@ -129,7 +129,7 @@ async def new_debug_session(
"""
if current_user_id:
debug_session = await app.DATABASE.debug.get_latest_debug_session_for_user(
debug_session = await app.DATABASE.debug.get_debug_session(
organization_id=current_org.organization_id,
user_id=current_user_id,
workflow_permanent_id=workflow_permanent_id,
@ -172,7 +172,7 @@ async def new_debug_session(
for debug_session in completed_debug_sessions:
try:
browser_session = await app.DATABASE.get_persistent_browser_session(
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
debug_session.browser_session_id,
current_org.organization_id,
)

View file

@ -52,13 +52,13 @@ async def _load_main_script_content(
script_revision_id: str,
) -> str | None:
"""Load the main.py content from a script revision, if it exists."""
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=script_revision_id,
organization_id=organization_id,
)
for f in script_files:
if f.file_path == "main.py" and f.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(f.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(f.artifact_id, organization_id)
if artifact:
data = await app.STORAGE.retrieve_artifact(artifact)
if data:
@ -81,7 +81,7 @@ async def get_script_blocks_response(
script_id: str | None = None,
version: int | None = None,
) -> ScriptBlocksResponse:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script_revision_id,
organization_id=organization_id,
)
@ -117,7 +117,7 @@ async def get_script_blocks_response(
)
continue
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script_revision_id,
file_id=script_file_id,
organization_id=organization_id,
@ -147,7 +147,7 @@ async def get_script_blocks_response(
)
continue
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id,
organization_id,
)
@ -261,7 +261,7 @@ async def get_script(
script_id=script_id,
)
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=current_org.organization_id,
)
@ -287,7 +287,7 @@ async def get_script_versions(
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> ScriptVersionListResponse:
"""List all versions of a script."""
scripts = await app.DATABASE.get_script_versions(
scripts = await app.DATABASE.scripts.get_script_versions(
script_id=script_id,
organization_id=current_org.organization_id,
)
@ -319,7 +319,7 @@ async def get_script_version_code(
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> ScriptBlocksResponse:
"""Get a specific version's code blocks."""
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=current_org.organization_id,
version=version,
@ -358,8 +358,8 @@ async def compare_script_versions(
organization_id = current_org.organization_id
base_script, compare_script = await asyncio.gather(
app.DATABASE.get_script(script_id=script_id, organization_id=organization_id, version=base),
app.DATABASE.get_script(script_id=script_id, organization_id=organization_id, version=compare),
app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id, version=base),
app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id, version=compare),
)
if not base_script:
raise HTTPException(status_code=404, detail=f"Base version {base} not found")
@ -418,7 +418,7 @@ async def get_script_version_detail(
"""Get full detail for a specific script version, including code blocks and metadata."""
organization_id = current_org.organization_id
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=organization_id,
version=version,
@ -436,7 +436,7 @@ async def get_script_version_detail(
script_id=script.script_id,
version=script.version,
),
app.DATABASE.get_fallback_episodes_count(
app.DATABASE.scripts.get_fallback_episodes_count(
organization_id=organization_id,
script_revision_id=script.script_revision_id,
),
@ -492,7 +492,7 @@ async def get_scripts(
page_size=page_size,
)
scripts = await app.DATABASE.get_scripts(
scripts = await app.DATABASE.scripts.get_scripts(
organization_id=current_org.organization_id,
page=page,
page_size=page_size,
@ -535,7 +535,7 @@ async def deploy_script(
try:
# Get the latest version of the script
latest_script = await app.DATABASE.get_latest_script_version(
latest_script = await app.DATABASE.scripts.get_latest_script_version(
script_id=script_id,
organization_id=current_org.organization_id,
)
@ -545,7 +545,7 @@ async def deploy_script(
# Create a new version of the script
new_version = latest_script.version + 1
new_script = await app.DATABASE.create_script(
new_script = await app.DATABASE.scripts.create_script(
organization_id=current_org.organization_id,
run_id=latest_script.run_id,
script_id=script_id,
@ -553,7 +553,7 @@ async def deploy_script(
)
# Fetch source files from the base revision to build the old->new ID mapping
source_files = await app.DATABASE.get_script_files(
source_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=latest_script.script_revision_id,
organization_id=current_org.organization_id,
)
@ -580,7 +580,7 @@ async def deploy_script(
file_path=file.path,
data=content_bytes,
)
new_file = await app.DATABASE.create_script_file(
new_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=current_org.organization_id,
@ -601,7 +601,7 @@ async def deploy_script(
for f in source_files:
if f.file_path in deployed_file_paths:
continue
new_file = await app.DATABASE.create_script_file(
new_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=current_org.organization_id,
@ -617,13 +617,13 @@ async def deploy_script(
old_to_new_file_id[f.file_id] = new_file.file_id
# Copy existing script blocks, re-pointing file IDs to the new revision's files
existing_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
existing_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=latest_script.script_revision_id,
organization_id=current_org.organization_id,
)
for sb in existing_blocks:
new_file_id = old_to_new_file_id.get(sb.script_file_id, sb.script_file_id) if sb.script_file_id else None
await app.DATABASE.create_script_block(
await app.DATABASE.scripts.create_script_block(
organization_id=current_org.organization_id,
script_id=new_script.script_id,
script_revision_id=new_script.script_revision_id,
@ -716,7 +716,7 @@ async def get_workflow_script_blocks(
include_main_script = True
workflow_run_id = block_script_request.workflow_run_id
if workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
@ -727,7 +727,7 @@ async def get_workflow_script_blocks(
# get_workflow_script() always resolves the latest version for a
# cache_key_value, but the Code tab should show the version that was
# active when this run executed (SKY-8448).
workflow_script = await app.DATABASE.get_workflow_script(
workflow_script = await app.DATABASE.scripts.get_workflow_script(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
@ -833,7 +833,7 @@ async def get_workflow_cache_key_values(
) -> ScriptCacheKeyValuesResponse:
# TODO(jdo): concurrent-ize
values = await app.DATABASE.get_workflow_cache_key_values(
values = await app.DATABASE.scripts.get_workflow_cache_key_values(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key=cache_key,
@ -842,13 +842,13 @@ async def get_workflow_cache_key_values(
filter=filter,
)
total_count = await app.DATABASE.get_workflow_cache_key_count(
total_count = await app.DATABASE.scripts.get_workflow_cache_key_count(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key=cache_key,
)
filtered_count = await app.DATABASE.get_workflow_cache_key_count(
filtered_count = await app.DATABASE.scripts.get_workflow_cache_key_count(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key=cache_key,
@ -889,7 +889,7 @@ async def list_workflow_scripts(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
workflow_scripts = await app.DATABASE.get_workflow_scripts_by_permanent_id(
workflow_scripts = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
statuses=[ScriptStatus.published],
@ -918,11 +918,11 @@ async def list_workflow_scripts(
# These are independent -- run in parallel.
rep_script_ids = [ws.script_id for ws in representatives]
version_stats, run_stats = await asyncio.gather(
app.DATABASE.get_script_version_stats(
app.DATABASE.scripts.get_script_version_stats(
organization_id=organization_id,
script_ids=rep_script_ids,
),
app.DATABASE.get_script_run_stats(
app.DATABASE.scripts.get_script_run_stats(
organization_id=organization_id,
script_ids=rep_script_ids,
),
@ -980,7 +980,7 @@ async def get_script_runs(
organization_id = current_org.organization_id
# Verify script exists
script = await app.DATABASE.get_script(script_id=script_id, organization_id=organization_id)
script = await app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id)
if not script:
raise HTTPException(status_code=404, detail="Script not found")
@ -989,7 +989,7 @@ async def get_script_runs(
if version is not None:
# Get all versions to determine the time window for this version
all_versions = await app.DATABASE.get_script_versions(
all_versions = await app.DATABASE.scripts.get_script_versions(
script_id=script_id,
organization_id=organization_id,
)
@ -1006,7 +1006,7 @@ async def get_script_runs(
if not version_found:
raise HTTPException(status_code=404, detail=f"Script version {version} not found")
runs, total_count, status_counts, avg_fallbacks_per_run = await app.DATABASE.get_workflow_runs_for_script(
runs, total_count, status_counts, avg_fallbacks_per_run = await app.DATABASE.scripts.get_workflow_runs_for_script(
organization_id=organization_id,
script_id=script_id,
page_size=page_size,
@ -1063,7 +1063,7 @@ async def delete_workflow_cache_key_value(
raise HTTPException(status_code=404, detail="Workflow not found")
# Delete the cache key value
deleted = await app.DATABASE.delete_workflow_cache_key_value(
deleted = await app.DATABASE.scripts.delete_workflow_cache_key_value(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=cache_key_value,
@ -1130,7 +1130,7 @@ async def clear_workflow_cache(
raise HTTPException(status_code=404, detail="Workflow not found")
# Clear database cache (soft delete)
deleted_count = await app.DATABASE.delete_workflow_scripts_by_permanent_id(
deleted_count = await app.DATABASE.scripts.delete_workflow_scripts_by_permanent_id(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
)
@ -1185,7 +1185,7 @@ async def pin_workflow_script(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
result = await app.DATABASE.pin_workflow_script(
result = await app.DATABASE.scripts.pin_workflow_script(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=data.cache_key_value,
@ -1233,7 +1233,7 @@ async def unpin_workflow_script(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
result = await app.DATABASE.unpin_workflow_script(
result = await app.DATABASE.scripts.unpin_workflow_script(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=data.cache_key_value,
@ -1299,7 +1299,7 @@ async def review_script_with_instructions(
run_parameter_values: dict[str, str] = {}
latest_script = None
if data.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=data.workflow_run_id,
organization_id=organization_id,
)
@ -1308,17 +1308,17 @@ async def review_script_with_instructions(
if workflow_run.workflow_permanent_id != workflow_permanent_id:
raise HTTPException(status_code=400, detail="Workflow run does not belong to this workflow")
# Look up the specific script used by this run
run_workflow_script = await app.DATABASE.get_workflow_script(
run_workflow_script = await app.DATABASE.scripts.get_workflow_script(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=data.workflow_run_id,
)
if run_workflow_script:
latest_script = await app.DATABASE.get_latest_script_version(
latest_script = await app.DATABASE.scripts.get_latest_script_version(
script_id=run_workflow_script.script_id,
organization_id=organization_id,
)
episodes = await app.DATABASE.get_fallback_episodes(
episodes = await app.DATABASE.scripts.get_fallback_episodes(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=data.workflow_run_id,
@ -1326,7 +1326,7 @@ async def review_script_with_instructions(
page_size=50,
)
try:
run_param_tuples = await app.DATABASE.get_workflow_run_parameters(
run_param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=data.workflow_run_id,
)
for wf_param, run_param in run_param_tuples:
@ -1436,7 +1436,7 @@ async def get_fallback_episodes(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
episodes = await app.DATABASE.get_fallback_episodes(
episodes = await app.DATABASE.scripts.get_fallback_episodes(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
page=page,
@ -1446,7 +1446,7 @@ async def get_fallback_episodes(
reviewed=reviewed,
fallback_type=fallback_type,
)
total_count = await app.DATABASE.get_fallback_episodes_count(
total_count = await app.DATABASE.scripts.get_fallback_episodes_count(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
@ -1485,7 +1485,7 @@ async def get_fallback_episode(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")
episode = await app.DATABASE.get_fallback_episode(
episode = await app.DATABASE.scripts.get_fallback_episode(
episode_id=episode_id,
organization_id=current_org.organization_id,
)
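
Every hunk in this file is the same mechanical rewrite: the flat AgentDB delegate replaced by the domain repository attribute. A minimal sketch of the before/after shape, assuming the FastAPI-style handlers above (get_script_or_404 is a hypothetical illustration, not a helper from this commit):

from fastapi import HTTPException

from skyvern.forge import app  # import path assumed


async def get_script_or_404(script_id: str, organization_id: str):
    # Before: script = await app.DATABASE.get_script(...)
    # After: the same query, routed through the typed scripts repository.
    script = await app.DATABASE.scripts.get_script(
        script_id=script_id,
        organization_id=organization_id,
    )
    if not script:
        raise HTTPException(status_code=404, detail="Script not found")
    return script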

View file

@ -56,7 +56,7 @@ async def run_sdk_action(
# Use existing workflow_run_id if provided, otherwise create a new one
if action_request.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=action_request.workflow_run_id,
organization_id=organization_id,
)
@ -90,12 +90,12 @@ async def run_sdk_action(
organization=organization,
version=None,
)
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
status=WorkflowRunStatus.completed,
)
task = await app.DATABASE.create_task(
task = await app.DATABASE.tasks.create_task(
organization_id=organization_id,
url=action_request.url,
navigation_goal=action.get_navigation_goal(),
@ -107,14 +107,14 @@ async def run_sdk_action(
browser_address=browser_address,
)
step = await app.DATABASE.create_step(
step = await app.DATABASE.tasks.create_step(
task.task_id,
order=0,
retry_index=0,
organization_id=organization.organization_id,
)
await app.DATABASE.create_workflow_run_block(
await app.DATABASE.observer.create_workflow_run_block(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=organization_id,
block_type=BlockType.ACTION,
@ -221,13 +221,13 @@ async def run_sdk_action(
model=action.model,
)
result = prompt_result
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
status=TaskStatus.completed,
)
except ScrapingFailed as e:
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
status=TaskStatus.failed,
@ -240,7 +240,7 @@ async def run_sdk_action(
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.reason or str(e))
except Exception as e:
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
status=TaskStatus.failed,
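
run_sdk_action now spans three repositories: workflow_runs for the run row, tasks for the task and its step, and observer for the run block. A condensed sketch of the lifecycle, using only calls visible in the hunks above (parameter lists abbreviated, imports elided, values illustrative):

task = await app.DATABASE.tasks.create_task(
    organization_id=organization_id,
    url=action_request.url,
)
step = await app.DATABASE.tasks.create_step(
    task.task_id, order=0, retry_index=0, organization_id=organization_id
)
await app.DATABASE.observer.create_workflow_run_block(
    workflow_run_id=workflow_run.workflow_run_id,
    organization_id=organization_id,
    block_type=BlockType.ACTION,
)
try:
    ...  # run the action against the page
    await app.DATABASE.tasks.update_task(
        task_id=task.task_id, organization_id=organization_id, status=TaskStatus.completed
    )
except Exception:
    await app.DATABASE.tasks.update_task(
        task_id=task.task_id, organization_id=organization_id, status=TaskStatus.failed
    )
    raise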

View file

@ -271,7 +271,7 @@ async def cdp_input_stream(
try:
deadline = time.monotonic() + 120
while True:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
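
The CDP input stream polls the run against a monotonic deadline. A self-contained sketch of that loop shape; the 120-second budget and the repository call come from the hunk above, while the exit conditions and the 1-second interval are assumptions:

import asyncio
import time

from skyvern.forge import app  # import path assumed


async def poll_until_final(workflow_run_id: str, organization_id: str, budget_sec: float = 120.0) -> None:
    # time.monotonic() is immune to wall-clock adjustments, which matters
    # for long-lived streams; the deadline bounds the whole polling loop.
    deadline = time.monotonic() + budget_sec
    while time.monotonic() < deadline:
        workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        # Exit conditions assumed: run is gone, or its status is terminal.
        if workflow_run is None or workflow_run.status.is_final():
            return
        await asyncio.sleep(1)  # poll interval assumed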

View file

@ -86,7 +86,7 @@ async def task_stream(
)
return
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
await websocket.send_json(
@ -210,7 +210,7 @@ async def workflow_run_streaming(
)
return
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -403,7 +403,7 @@ async def _local_screencast_for_workflow_run(
async def wait_for_running() -> str | None:
deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT
while True:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -423,14 +423,14 @@ async def _local_screencast_for_workflow_run(
await asyncio.sleep(1)
async def check_finalized() -> bool:
wr = await app.DATABASE.get_workflow_run(
wr = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return wr is None or wr.status.is_final()
async def get_current_status() -> str | None:
wr = await app.DATABASE.get_workflow_run(
wr = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -457,7 +457,7 @@ async def _local_screencast_for_task(
nonlocal task_workflow_run_id
deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT
while True:
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if not task:
return "not_found"
if task.status.is_final():
@ -475,11 +475,11 @@ async def _local_screencast_for_task(
await asyncio.sleep(1)
async def check_finalized() -> bool:
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
return task is None or task.status.is_final()
async def get_current_status() -> str | None:
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
return task.status if task else None
await _run_local_screencast(
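
Both local screencast paths define the same trio of closures over their entity (wait_for_running, check_finalized, get_current_status), differing only in which repository they query. A hedged sketch of the task flavor with the repeated lookup factored out (_fetch_task is a hypothetical name, not from this commit):

async def _fetch_task():
    # Shared lookup used by the closures; the repository call is from above.
    return await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)

async def check_finalized() -> bool:
    task = await _fetch_task()
    return task is None or task.status.is_final()

async def get_current_status() -> str | None:
    task = await _fetch_task()
    return task.status if task else None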

View file

@ -117,7 +117,7 @@ async def verify_task(
with it.
"""
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
@ -171,7 +171,7 @@ async def verify_workflow_run(
with it.
"""
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

View file

@ -66,7 +66,7 @@ class RunInfo:
async def _get_debug_artifact(organization_id: str, workflow_run_id: str) -> Artifact | None:
artifacts = await app.DATABASE.get_artifacts_for_run(
artifacts = await app.DATABASE.artifacts.get_artifacts_for_run(
run_id=workflow_run_id, organization_id=organization_id, artifact_types=[ArtifactType.VISIBLE_ELEMENTS_TREE]
)
return artifacts[0] if isinstance(artifacts, list) and artifacts else None
@ -76,7 +76,7 @@ async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None)
if not workflow_run_id:
return None
blocks = await app.DATABASE.get_workflow_run_blocks(
blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not blocks:
@ -651,7 +651,7 @@ async def workflow_copilot_chat_post(
)
if chat_request.workflow_copilot_chat_id:
chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
organization_id=organization.organization_id,
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)
@ -660,12 +660,12 @@ async def workflow_copilot_chat_post(
if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
else:
chat = await app.DATABASE.create_workflow_copilot_chat(
chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_permanent_id=chat_request.workflow_permanent_id,
)
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
)
global_llm_context = None
@ -719,20 +719,20 @@ async def workflow_copilot_chat_post(
return
if updated_workflow and chat.auto_accept is not True:
await app.DATABASE.update_workflow_copilot_chat(
await app.DATABASE.workflow_params.update_workflow_copilot_chat(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
proposed_workflow=updated_workflow.model_dump(mode="json"),
)
await app.DATABASE.create_workflow_copilot_chat_message(
await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
sender=WorkflowCopilotChatSender.USER,
content=chat_request.message,
)
assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
assistant_message = await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
sender=WorkflowCopilotChatSender.AI,
@ -791,12 +791,14 @@ async def workflow_copilot_chat_history(
workflow_permanent_id: str,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowCopilotChatHistoryResponse:
latest_chat = await app.DATABASE.get_latest_workflow_copilot_chat(
latest_chat = await app.DATABASE.workflow_params.get_latest_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_permanent_id=workflow_permanent_id,
)
if latest_chat:
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(latest_chat.workflow_copilot_chat_id)
chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
latest_chat.workflow_copilot_chat_id
)
else:
chat_messages = []
return WorkflowCopilotChatHistoryResponse(
@ -814,7 +816,7 @@ async def workflow_copilot_clear_proposed_workflow(
clear_request: WorkflowCopilotClearProposedWorkflowRequest,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> None:
updated_chat = await app.DATABASE.update_workflow_copilot_chat(
updated_chat = await app.DATABASE.workflow_params.update_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_copilot_chat_id=clear_request.workflow_copilot_chat_id,
proposed_workflow=None,
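
The copilot chat entry point is a get-or-create against the workflow_params repository followed by a history load. The consolidated flow, with the not-found and ownership checks elided:

if chat_request.workflow_copilot_chat_id:
    chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
        organization_id=organization.organization_id,
        workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
    )
else:
    chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
        organization_id=organization.organization_id,
        workflow_permanent_id=chat_request.workflow_permanent_id,
    )
chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
    workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
)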

View file

@ -207,7 +207,7 @@ class Block(BaseModel, abc.ABC):
parameter=self.output_parameter,
value=value,
)
await app.DATABASE.create_or_update_workflow_run_output_parameter(
await app.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=value,
@ -238,7 +238,7 @@ class Block(BaseModel, abc.ABC):
output_parameter_value = {"value": output_parameter_value}
if workflow_run_block_id:
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
output=output_parameter_value,
status=status,
@ -495,7 +495,7 @@ class Block(BaseModel, abc.ABC):
LOG.exception("Failed to generate description for the workflow run block", error=e)
if description:
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
description=description,
organization_id=organization_id,
@ -522,7 +522,7 @@ class Block(BaseModel, abc.ABC):
if isinstance(self, BaseTaskBlock):
engine = self.engine
workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.create_workflow_run_block(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
parent_workflow_run_block_id=parent_workflow_run_block_id,
@ -707,7 +707,9 @@ class BaseTaskBlock(Block):
"""
Returns the order and retry for the next task in the workflow run as a tuple.
"""
last_task_for_workflow_run = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
last_task_for_workflow_run = await app.DATABASE.tasks.get_last_task_for_workflow_run(
workflow_run_id=workflow_run_id
)
# If there is no previous task, the order will be 0 and the retry will be 0.
if last_task_for_workflow_run is None:
return 0, 0
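
Only the base case of the order/retry helper is visible in this hunk: with no prior task the tuple is (0, 0). The continuation falls outside the diff, so the branch logic and attribute names below are assumptions, not code from this commit:

async def get_next_order_and_retry(workflow_run_id: str) -> tuple[int, int]:
    last = await app.DATABASE.tasks.get_last_task_for_workflow_run(
        workflow_run_id=workflow_run_id
    )
    if last is None:
        return 0, 0  # first task in the run, as in the hunk above
    # Assumed continuation: a retry of a failed task keeps the same order and
    # bumps the retry counter; otherwise the next task bumps the order.
    if last.status == TaskStatus.failed:  # retry trigger is an assumption
        return last.order, last.retry + 1  # attribute names assumed
    return last.order + 1, 0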
@ -746,7 +748,7 @@ class BaseTaskBlock(Block):
This helper method consolidates the error detection logic that was previously
duplicated across multiple exception handlers in the execute method.
"""
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=organization_id,
@ -764,7 +766,7 @@ class BaseTaskBlock(Block):
if detected_errors:
# Only pass new errors — update_task() appends to existing errors
new_errors = [error.model_dump() for error in detected_errors]
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
errors=new_errors,
@ -867,7 +869,7 @@ class BaseTaskBlock(Block):
task_order=task_order,
task_retry=task_retry,
)
workflow_run_block = await app.DATABASE.update_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task.task_id,
organization_id=organization_id,
@ -1011,7 +1013,7 @@ class BaseTaskBlock(Block):
current_context.task_id = None
# Check task status
updated_task = await app.DATABASE.get_task(
updated_task = await app.DATABASE.tasks.get_task(
task_id=task.task_id, organization_id=workflow_run.organization_id
)
if not updated_task:
@ -1941,7 +1943,7 @@ class ForLoopBlock(Block):
)
try:
if block_output.workflow_run_block_id:
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=block_output.workflow_run_block_id,
organization_id=organization_id,
current_value=str(loop_over_value),
@ -2106,7 +2108,7 @@ class ForLoopBlock(Block):
organization_id=organization_id,
)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
loop_values=loop_over_values,
@ -2586,7 +2588,9 @@ class TextPromptBlock(Block):
artifacts_to_persist: list[tuple[ArtifactType, bytes]] = []
if workflow_run_block_id:
try:
workflow_run_block = await app.DATABASE.get_workflow_run_block(workflow_run_block_id, organization_id)
workflow_run_block = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id, organization_id
)
if workflow_run_block:
artifacts_to_persist.append((ArtifactType.LLM_PROMPT, prompt.encode("utf-8")))
except Exception as e:
@ -2645,7 +2649,7 @@ class TextPromptBlock(Block):
)
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
prompt=self.prompt,
@ -3444,7 +3448,7 @@ class SendEmailBlock(Block):
**kwargs: dict,
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
recipients=self.recipients,
@ -4140,7 +4144,7 @@ class WaitBlock(Block):
**kwargs: dict,
) -> BlockResult:
# TODO: we need to support interrupting the sleep when the workflow run fails or is cancelled/terminated
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
wait_sec=self.wait_sec,
@ -4246,7 +4250,7 @@ class HumanInteractionBlock(BaseTaskBlock):
organization_id=organization_id,
)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
recipients=self.recipients,
@ -4265,12 +4269,12 @@ class HumanInteractionBlock(BaseTaskBlock):
browser_session_id=browser_session_id,
)
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
status=WorkflowRunStatus.paused,
)
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -4372,7 +4376,7 @@ class HumanInteractionBlock(BaseTaskBlock):
organization_id=organization_id,
)
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -4574,7 +4578,7 @@ class TaskV2Block(Block):
organization = await app.DATABASE.organizations.get_organization(organization_id)
if not organization:
raise ValueError(f"Organization not found {organization_id}")
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id, organization_id)
if not workflow_run:
raise ValueError(f"WorkflowRun not found {workflow_run_id} when running TaskV2Block")
try:
@ -4588,15 +4592,15 @@ class TaskV2Block(Block):
totp_verification_url=resolved_totp_verification_url,
max_screenshot_scrolling_times=workflow_run.max_screenshot_scrolls,
)
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2.observer_cruise_id, status=TaskV2Status.queued, organization_id=organization_id
)
if task_v2.workflow_run_id:
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=task_v2.workflow_run_id,
status=WorkflowRunStatus.queued,
)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
block_workflow_run_id=task_v2.workflow_run_id,
@ -4642,7 +4646,9 @@ class TaskV2Block(Block):
failure_reason: str | None = None
task_v2_workflow_run_id = task_v2.workflow_run_id
if task_v2_workflow_run_id:
task_v2_workflow_run = await app.DATABASE.get_workflow_run(task_v2_workflow_run_id, organization_id)
task_v2_workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
task_v2_workflow_run_id, organization_id
)
if task_v2_workflow_run:
failure_reason = task_v2_workflow_run.failure_reason
@ -5237,7 +5243,7 @@ class PrintPageBlock(Block):
return None, None
try:
workflow_run_block = await app.DATABASE.get_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id,
organization_id=artifact_org_id,
)
@ -5269,7 +5275,7 @@ class PrintPageBlock(Block):
# Generate a downloadable URL for the artifact
artifact_url = None
try:
artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id=artifact_org_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(artifact_id, organization_id=artifact_org_id)
if artifact:
artifact_url = await app.ARTIFACT_MANAGER.get_share_link(artifact)
except Exception:
@ -6575,7 +6581,7 @@ class WorkflowTriggerBlock(Block):
f"Workflow trigger depth exceeds maximum of {self.MAX_TRIGGER_DEPTH}. "
"This may indicate a circular workflow trigger chain."
)
run = await app.DATABASE.get_workflow_run(current_run_id)
run = await app.DATABASE.workflow_runs.get_workflow_run(current_run_id)
if not run or not run.parent_workflow_run_id:
break
current_run_id = run.parent_workflow_run_id
@ -6721,7 +6727,7 @@ class WorkflowTriggerBlock(Block):
elif self.wait_for_completion:
# Sync mode: child runs inline in the same process, so it needs
# its own persistent session to avoid sharing the parent's browser.
parent_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id)
parent_workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id)
proxy_location = parent_workflow_run.proxy_location if parent_workflow_run else None
try:
child_browser_session = await app.PERSISTENT_SESSIONS_MANAGER.create_session(
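
Every block type above reports progress by patching the shared workflow_run_block row through the observer repository, each contributing its own columns (output, status, description, recipients, wait_sec, prompt, loop_values, and so on). The create-then-patch shape, condensed from the Block hunks:

workflow_run_block = await app.DATABASE.observer.create_workflow_run_block(
    workflow_run_id=workflow_run_id,
    organization_id=organization_id,
)
# ... the block executes ...
await app.DATABASE.observer.update_workflow_run_block(
    workflow_run_block_id=workflow_run_block.workflow_run_block_id,
    output=output_parameter_value,
    status=status,
)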

View file

@ -363,14 +363,14 @@ class WorkflowService:
target_labels = set(block_labels_to_disable)
for candidate in candidates:
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=candidate.script_id,
organization_id=organization_id,
)
if not script:
continue
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)
@ -400,7 +400,7 @@ class WorkflowService:
"""Remove cached run signatures for the supplied block groups to force regeneration."""
for group in groups:
for block in group.blocks_to_clear:
await app.DATABASE.update_script_block(
await app.DATABASE.scripts.update_script_block(
script_block_id=block.script_block_id,
organization_id=organization_id,
clear_run_signature=True,
@ -451,7 +451,7 @@ class WorkflowService:
if not artifact_ids or not organization_id:
return []
artifacts = await app.DATABASE.get_artifacts_by_ids(artifact_ids, organization_id)
artifacts = await app.DATABASE.artifacts.get_artifacts_by_ids(artifact_ids, organization_id)
if not artifacts:
return []
@ -988,7 +988,7 @@ class WorkflowService:
if browser_session:
browser_session_id = browser_session.persistent_browser_session_id
close_browser_on_completion = True
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
browser_session_id=browser_session_id,
)
@ -1064,7 +1064,7 @@ class WorkflowService:
# Refresh workflow_run from DB to pick up status/failure_reason
# set by _execute_workflow_blocks.
if refreshed_workflow_run := await app.DATABASE.get_workflow_run(
if refreshed_workflow_run := await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
):
@ -1110,7 +1110,7 @@ class WorkflowService:
should_trigger_reviewer = True
current_ctx = skyvern_context.current()
if current_ctx and current_ctx.script_id:
latest_script = await app.DATABASE.get_latest_script_version(
latest_script = await app.DATABASE.scripts.get_latest_script_version(
script_id=current_ctx.script_id,
organization_id=workflow.organization_id,
)
@ -1257,7 +1257,7 @@ class WorkflowService:
context.script_id = script.script_id
context.script_revision_id = script.script_revision_id
try:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)
@ -1271,7 +1271,7 @@ class WorkflowService:
if is_script_run:
# load the script files
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)
@ -1280,7 +1280,7 @@ class WorkflowService:
script_path = os.path.join(settings.TEMP_PATH, script.script_id, "main.py")
if os.path.exists(script_path):
# setup script run
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id
)
script_parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
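
get_workflow_run_parameters returns (WorkflowParameter, WorkflowRunParameter) tuples: the definition carries the key, the run row carries the value, so flattening to a dict is a one-liner that recurs several times in this file. A tiny worked example with illustrative values:

parameter_tuples = [
    (wf_param_recipient, run_param_a),  # .key == "recipient", .value == "a@b.co"
    (wf_param_retries, run_param_b),    # .key == "retries",   .value == "3"
]
script_parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
# -> {"recipient": "a@b.co", "retries": "3"}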
@ -1353,7 +1353,7 @@ class WorkflowService:
is_script_run = True
# Initialize RunContext with the browser page + parameters,
# same as the normal script loading path at line 1310.
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
script_parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
@ -1559,7 +1559,7 @@ class WorkflowService:
# browser-automation code like page.classify for pure-Python conditionals).
fallback_type = "conditional_agent" if isinstance(block, ConditionalBlock) else "full_block"
episode = await app.DATABASE.create_fallback_episode(
episode = await app.DATABASE.scripts.create_fallback_episode(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,
@ -1824,7 +1824,7 @@ class WorkflowService:
block_executed_with_code = False
try:
if refreshed_workflow_run := await app.DATABASE.get_workflow_run(
if refreshed_workflow_run := await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
):
@ -1884,12 +1884,12 @@ class WorkflowService:
# Persist the browser_profile_id on the workflow_run so
# subsequent blocks create / reuse a browser with the
# saved profile (cookies, localStorage, etc.).
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
browser_profile_id=resolved_browser_profile_id,
)
workflow_run = (
await app.DATABASE.get_workflow_run(
await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -1932,7 +1932,7 @@ class WorkflowService:
)
profile_loaded = False
# Clear the profile so the normal login path doesn't reuse it
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
browser_profile_id=None,
)
@ -2027,7 +2027,7 @@ class WorkflowService:
exc_info=True,
)
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -2215,12 +2215,12 @@ class WorkflowService:
fallback_wrb_id = workflow_run_block_result.workflow_run_block_id
if fallback_wrb_id:
try:
wrb = await app.DATABASE.get_workflow_run_block(
wrb = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id=fallback_wrb_id,
organization_id=organization_id,
)
if wrb and wrb.task_id:
actions = await app.DATABASE.get_task_actions(
actions = await app.DATABASE.tasks.get_task_actions(
task_id=wrb.task_id,
organization_id=organization_id,
)
@ -2232,7 +2232,7 @@ class WorkflowService:
exc_info=True,
)
await app.DATABASE.update_fallback_episode(
await app.DATABASE.scripts.update_fallback_episode(
episode_id=fallback_episode_id,
organization_id=organization_id,
agent_actions=agent_actions_summary,
@ -2306,7 +2306,7 @@ class WorkflowService:
}
)
cond_context = skyvern_context.current()
cond_episode = await app.DATABASE.create_fallback_episode(
cond_episode = await app.DATABASE.scripts.create_fallback_episode(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,
@ -2321,7 +2321,7 @@ class WorkflowService:
"expressions": expressions,
},
)
await app.DATABASE.update_fallback_episode(
await app.DATABASE.scripts.update_fallback_episode(
episode_id=cond_episode.episode_id,
organization_id=organization_id,
fallback_succeeded=True,
@ -2450,7 +2450,7 @@ class WorkflowService:
# falls back to default_value on the workflow parameter).
if run_param_tuples is None:
try:
run_param_tuples = await app.DATABASE.get_workflow_run_parameters(
run_param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run_id,
)
except Exception:
@ -2486,7 +2486,7 @@ class WorkflowService:
)
if db_cred and db_cred.browser_profile_id:
# Verify the browser profile still exists before using it
profile = await app.DATABASE.get_browser_profile(
profile = await app.DATABASE.browser_sessions.get_browser_profile(
profile_id=db_cred.browser_profile_id,
organization_id=organization_id,
)
@ -3186,7 +3186,7 @@ class WorkflowService:
if not block_run:
continue
output_parameter = await app.DATABASE.get_workflow_run_output_parameter_by_id(
output_parameter = await app.DATABASE.workflow_runs.get_workflow_run_output_parameter_by_id(
workflow_run_id=block_run.workflow_run_id, output_parameter_id=block_run.output_parameter_id
)
@ -3323,7 +3323,7 @@ class WorkflowService:
previous_blocks=current_definition.get("blocks", []),
new_blocks=new_definition.get("blocks", []),
)
candidates = await app.DATABASE.get_workflow_scripts_by_permanent_id(
candidates = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
)
@ -3396,7 +3396,7 @@ class WorkflowService:
if len(to_delete) > 0:
try:
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
await app.DATABASE.scripts.delete_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
script_ids=[s.script_id for s in to_delete],
@ -3458,7 +3458,7 @@ class WorkflowService:
search_key: str | None = None,
error_code: str | None = None,
) -> list[WorkflowRun]:
return await app.DATABASE.get_workflow_runs(
return await app.DATABASE.workflow_runs.get_workflow_runs(
organization_id=organization_id,
page=page,
page_size=page_size,
@ -3473,7 +3473,7 @@ class WorkflowService:
organization_id: str,
status: list[WorkflowRunStatus] | None = None,
) -> int:
return await app.DATABASE.get_workflow_runs_count(
return await app.DATABASE.workflow_runs.get_workflow_runs_count(
organization_id=organization_id,
status=status,
)
@ -3488,7 +3488,7 @@ class WorkflowService:
search_key: str | None = None,
error_code: str | None = None,
) -> list[WorkflowRun]:
return await app.DATABASE.get_workflow_runs_for_workflow_permanent_id(
return await app.DATABASE.workflow_runs.get_workflow_runs_for_workflow_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
page=page,
@ -3515,7 +3515,7 @@ class WorkflowService:
# validate the browser session or profile id
browser_profile_id = workflow_request.browser_profile_id
if workflow_request.browser_session_id:
browser_session = await app.DATABASE.get_persistent_browser_session(
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
session_id=workflow_request.browser_session_id,
organization_id=organization_id,
)
@ -3531,7 +3531,7 @@ class WorkflowService:
)
if browser_profile_id:
browser_profile = await app.DATABASE.get_browser_profile(
browser_profile = await app.DATABASE.browser_sessions.get_browser_profile(
browser_profile_id,
organization_id=organization_id,
)
@ -3579,7 +3579,7 @@ class WorkflowService:
browser_session_id=browser_session_id,
)
return await app.DATABASE.create_workflow_run(
return await app.DATABASE.workflow_runs.create_workflow_run(
workflow_permanent_id=workflow_permanent_id,
workflow_id=workflow_id,
organization_id=organization_id,
@ -3612,7 +3612,7 @@ class WorkflowService:
ai_fallback: bool | None = None,
failure_category: list[dict] | None = None,
) -> WorkflowRun:
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
status=status,
failure_reason=failure_reason,
@ -3659,7 +3659,7 @@ class WorkflowService:
) -> None:
"""Fire-and-forget: propagate workflow_run status to task_runs."""
try:
await app.DATABASE.sync_task_run_status(
await app.DATABASE.tasks.sync_task_run_status(
organization_id=workflow_run.organization_id,
run_id=workflow_run_id,
status=status.value,
@ -3667,12 +3667,12 @@ class WorkflowService:
finished_at=workflow_run.finished_at,
)
# Also sync task_v2 if this workflow_run backs an observer_cruise
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=workflow_run.organization_id,
)
if task_v2:
await app.DATABASE.sync_task_run_status(
await app.DATABASE.tasks.sync_task_run_status(
organization_id=workflow_run.organization_id,
run_id=task_v2.observer_cruise_id,
status=status.value,
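
Stitched together across these two hunks, the status sync fans out through the typed repos twice: once for the run's own task_runs row, and again for the observer cruise (task v2) if one backs this workflow run:

await app.DATABASE.tasks.sync_task_run_status(
    organization_id=workflow_run.organization_id,
    run_id=workflow_run_id,
    status=status.value,
)
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
    workflow_run_id=workflow_run_id,
    organization_id=workflow_run.organization_id,
)
if task_v2:
    await app.DATABASE.tasks.sync_task_run_status(
        organization_id=workflow_run.organization_id,
        run_id=task_v2.observer_cruise_id,
        status=status.value,
    )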
@ -3882,7 +3882,7 @@ class WorkflowService:
)
async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -3898,7 +3898,7 @@ class WorkflowService:
default_value: bool | int | float | str | dict | list | None = None,
description: str | None = None,
) -> WorkflowParameter:
return await app.DATABASE.create_workflow_parameter(
return await app.DATABASE.workflow_params.create_workflow_parameter(
workflow_id=workflow_id,
workflow_parameter_type=workflow_parameter_type,
key=key,
@ -3909,17 +3909,19 @@ class WorkflowService:
async def create_aws_secret_parameter(
self, workflow_id: str, aws_key: str, key: str, description: str | None = None
) -> AWSSecretParameter:
return await app.DATABASE.create_aws_secret_parameter(
return await app.DATABASE.workflow_params.create_aws_secret_parameter(
workflow_id=workflow_id, aws_key=aws_key, key=key, description=description
)
async def create_output_parameter(
self, workflow_id: str, key: str, description: str | None = None
) -> OutputParameter:
return await app.DATABASE.create_output_parameter(workflow_id=workflow_id, key=key, description=description)
return await app.DATABASE.workflow_params.create_output_parameter(
workflow_id=workflow_id, key=key, description=description
)
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)
return await app.DATABASE.workflow_params.get_workflow_parameters(workflow_id=workflow_id)
async def create_workflow_run_parameter(
self,
@ -3929,7 +3931,7 @@ class WorkflowService:
) -> WorkflowRunParameter:
value = self._serialize_workflow_run_parameter_value(workflow_parameter, value)
return await app.DATABASE.create_workflow_run_parameter(
return await app.DATABASE.workflow_runs.create_workflow_run_parameter(
workflow_run_id=workflow_run_id,
workflow_parameter=workflow_parameter,
value=value,
@ -3945,7 +3947,7 @@ class WorkflowService:
for workflow_parameter, value in workflow_parameter_values
]
return await app.DATABASE.create_workflow_run_parameters(
return await app.DATABASE.workflow_runs.create_workflow_run_parameters(
workflow_run_id=workflow_run_id,
workflow_parameter_values=serialized_workflow_parameter_values,
)
@ -3960,27 +3962,27 @@ class WorkflowService:
async def get_workflow_run_parameter_tuples(
self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
return await app.DATABASE.workflow_runs.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
@staticmethod
async def get_workflow_output_parameters(workflow_id: str) -> list[OutputParameter]:
return await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id)
return await app.DATABASE.workflow_params.get_workflow_output_parameters(workflow_id=workflow_id)
@staticmethod
async def get_workflow_run_output_parameters(
workflow_run_id: str,
) -> list[WorkflowRunOutputParameter]:
return await app.DATABASE.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id)
return await app.DATABASE.workflow_runs.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id)
@staticmethod
async def get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id: str,
workflow_run_id: str,
) -> list[tuple[OutputParameter, WorkflowRunOutputParameter]]:
workflow_run_output_parameters = await app.DATABASE.get_workflow_run_output_parameters(
workflow_run_output_parameters = await app.DATABASE.workflow_runs.get_workflow_run_output_parameters(
workflow_run_id=workflow_run_id
)
output_parameters = await app.DATABASE.get_workflow_output_parameters_by_ids(
output_parameters = await app.DATABASE.workflow_params.get_workflow_output_parameters_by_ids(
output_parameter_ids=[
workflow_run_output_parameter.output_parameter_id
for workflow_run_output_parameter in workflow_run_output_parameters
@ -3995,10 +3997,10 @@ class WorkflowService:
]
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
return await app.DATABASE.tasks.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
return await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
async def get_recent_task_screenshot_artifacts(
self,
@ -4015,7 +4017,7 @@ class WorkflowService:
artifacts: list[Artifact] = []
if task_id:
artifacts = (
await app.DATABASE.get_latest_n_artifacts(
await app.DATABASE.artifacts.get_latest_n_artifacts(
task_id=task_id,
artifact_types=artifact_types,
organization_id=organization_id,
@ -4024,13 +4026,13 @@ class WorkflowService:
or []
)
elif task_v2_id:
action_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
action_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_ACTION,
task_v2_id=task_v2_id,
limit=limit,
)
final_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
final_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_FINAL,
task_v2_id=task_v2_id,
@ -4077,10 +4079,10 @@ class WorkflowService:
seen_artifact_ids: set[str] = set()
if workflow_run_tasks is None:
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
for task in workflow_run_tasks[::-1]:
artifact = await app.DATABASE.get_latest_artifact(
artifact = await app.DATABASE.artifacts.get_latest_artifact(
task_id=task.task_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
organization_id=organization_id,
@ -4092,13 +4094,13 @@ class WorkflowService:
break
if len(screenshot_artifacts) < limit:
action_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
action_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_ACTION,
workflow_run_id=workflow_run_id,
limit=limit,
)
final_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
final_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_FINAL,
workflow_run_id=workflow_run_id,
@ -4173,11 +4175,11 @@ class WorkflowService:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
screenshot_urls: list[str] | None = await self.get_recent_workflow_screenshot_urls(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
@ -4201,7 +4203,7 @@ class WorkflowService:
LOG.warning("Timeout getting recordings", browser_session_id=workflow_run.browser_session_id)
if recording_url is None:
recording_artifact = await app.DATABASE.get_artifact_for_run(
recording_artifact = await app.DATABASE.artifacts.get_artifact_for_run(
run_id=task_v2.observer_cruise_id if task_v2 else workflow_run_id,
artifact_type=ArtifactType.RECORDING,
organization_id=organization_id,
@ -4239,7 +4241,9 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id,
)
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
workflow_parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run_id
)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
output_parameter_tuples: list[
tuple[OutputParameter, WorkflowRunOutputParameter]
@ -4270,7 +4274,7 @@ class WorkflowService:
# matching the task-level error format. Uses a lightweight query that only
# fetches blocks with non-null error_codes to avoid a full block load on
# every status poll.
block_errors = await app.DATABASE.get_workflow_run_block_errors(
block_errors = await app.DATABASE.workflow_runs.get_workflow_run_block_errors(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
for error_codes, failure_reason in block_errors:
@ -4286,13 +4290,13 @@ class WorkflowService:
total_steps = None
total_cost = None
if include_step_count or include_cost:
workflow_run_steps = await app.DATABASE.get_steps_by_task_ids(
workflow_run_steps = await app.DATABASE.tasks.get_steps_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks], organization_id=organization_id
)
total_steps = len(workflow_run_steps)
if include_cost:
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
text_prompt_blocks = [
@ -4353,7 +4357,7 @@ class WorkflowService:
# tasks into the parent list for debug artifact persistence, and collect
# child workflow_run IDs so cleanup_for_workflow_run can pop their orphaned
# entries from self.pages (child skips clean_up_workflow).
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
child_workflow_runs = await app.DATABASE.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(
parent_workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow_run.organization_id,
)
@ -4514,7 +4518,7 @@ class WorkflowService:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
webhook_failure_reason="",
)
@ -4528,7 +4532,7 @@ class WorkflowService:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
webhook_failure_reason=f"Webhook failed with status code {resp.status_code}, error message: {resp.text}",
)
@ -4611,7 +4615,7 @@ class WorkflowService:
workflow: Workflow,
workflow_run: WorkflowRun,
) -> None:
last_step = await app.DATABASE.get_latest_step(
last_step = await app.DATABASE.tasks.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
)
if not last_step:
@ -4679,7 +4683,7 @@ class WorkflowService:
workflow_id=workflow_id,
)
await app.DATABASE.save_workflow_definition_parameters(workflow_definition.parameters)
await app.DATABASE.workflow_params.save_workflow_definition_parameters(workflow_definition.parameters)
return workflow_definition
@ -4830,7 +4834,7 @@ class WorkflowService:
@staticmethod
async def create_output_parameter_for_block(workflow_id: str, block_yaml: BLOCK_YAML_TYPES) -> OutputParameter:
output_parameter_key = f"{block_yaml.label}_output"
return await app.DATABASE.create_output_parameter(
return await app.DATABASE.workflow_params.create_output_parameter(
workflow_id=workflow_id,
key=output_parameter_key,
description=f"Output parameter for block {block_yaml.label}",
@ -4875,7 +4879,7 @@ class WorkflowService:
"""
build the tree structure of the workflow run timeline
"""
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -4884,7 +4888,7 @@ class WorkflowService:
task_id_to_block: dict[str, WorkflowRunBlock] = {
block.task_id: block for block in workflow_run_blocks if block.task_id
}
actions = await app.DATABASE.get_tasks_actions(task_ids=task_ids, organization_id=organization_id)
actions = await app.DATABASE.tasks.get_tasks_actions(task_ids=task_ids, organization_id=organization_id)
for action in actions:
if not action.task_id:
continue
@ -5005,7 +5009,7 @@ class WorkflowService:
return None
cached_block_labels: set[str] = set()
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=existing_script.script_revision_id,
organization_id=workflow.organization_id,
)
@ -5096,7 +5100,7 @@ class WorkflowService:
return
# Get the latest version number so we can increment it
version_stats = await app.DATABASE.get_script_version_stats(
version_stats = await app.DATABASE.scripts.get_script_version_stats(
organization_id=workflow.organization_id,
script_ids=[existing_script.script_id],
)
@ -5117,7 +5121,7 @@ class WorkflowService:
)
# Create a new version of the SAME script_id instead of a new script
regenerated_script = await app.DATABASE.create_script(
regenerated_script = await app.DATABASE.scripts.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
script_id=existing_script.script_id,
@ -5135,7 +5139,7 @@ class WorkflowService:
# If generation failed (e.g. syntax error), clean up the empty script row
# to avoid orphaned versions that skip version numbers on next regeneration.
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=regenerated_script.script_revision_id,
organization_id=workflow.organization_id,
)
@ -5145,7 +5149,7 @@ class WorkflowService:
script_id=regenerated_script.script_id,
version=regenerated_script.version,
)
await app.DATABASE.soft_delete_script_by_revision(
await app.DATABASE.scripts.soft_delete_script_by_revision(
script_revision_id=regenerated_script.script_revision_id,
organization_id=workflow.organization_id,
)
@ -5194,7 +5198,7 @@ class WorkflowService:
await _regenerate_script()
return
created_script = await app.DATABASE.create_script(
created_script = await app.DATABASE.scripts.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
)
@ -5236,7 +5240,7 @@ class WorkflowService:
# Check if the script is pinned — skip auto-review for pinned scripts.
# Query by script_id (not workflow_run_id) because pinning is applied
# at the cache_key_value level and may not be on this run's row.
if await app.DATABASE.is_script_pinned(
if await app.DATABASE.scripts.is_script_pinned(
organization_id=workflow.organization_id,
script_id=script_id,
):
@ -5357,7 +5361,7 @@ class WorkflowService:
) -> None:
"""Run the script reviewer inside a lock. Episodes are scoped to the script version."""
# Double-check: re-query episodes after acquiring lock (another process may have reviewed them)
episodes = await app.DATABASE.get_unreviewed_episodes(
episodes = await app.DATABASE.scripts.get_unreviewed_episodes(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
script_revision_id=script_revision_id,
@ -5375,7 +5379,7 @@ class WorkflowService:
# Query stale branches for TTL-based pruning
stale_branches: list = []
try:
stale_branches = await app.DATABASE.get_stale_branches(
stale_branches = await app.DATABASE.scripts.get_stale_branches(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
stale_days=90,
@ -5393,7 +5397,7 @@ class WorkflowService:
# Use the latest version as the base (not the potentially-stale run revision)
reviewer_base_revision_id = script_revision_id
try:
latest = await app.DATABASE.get_latest_script_version(
latest = await app.DATABASE.scripts.get_latest_script_version(
script_id=script_id,
organization_id=workflow.organization_id,
)
@ -5405,7 +5409,7 @@ class WorkflowService:
# Fetch historical (already-reviewed) episodes for cross-run context
historical_episodes: list = []
try:
historical_episodes = await app.DATABASE.get_recent_reviewed_episodes(
historical_episodes = await app.DATABASE.scripts.get_recent_reviewed_episodes(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
limit=20,
@ -5457,7 +5461,7 @@ class WorkflowService:
# use context.parameters['recipient'] instead of a literal string).
run_parameter_values: dict[str, str] = {}
try:
run_param_tuples = await app.DATABASE.get_workflow_run_parameters(
run_param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
for wf_param, run_param in run_param_tuples:
@ -5513,7 +5517,7 @@ class WorkflowService:
)
# Still mark episodes as reviewed
for episode in episodes:
await app.DATABASE.mark_episode_reviewed(
await app.DATABASE.scripts.mark_episode_reviewed(
episode_id=episode.episode_id,
organization_id=workflow.organization_id,
reviewer_output=None,
@ -5523,7 +5527,7 @@ class WorkflowService:
# Get the base script to create a new version from
base_script = None
if script_revision_id:
base_script = await app.DATABASE.get_script_revision(
base_script = await app.DATABASE.scripts.get_script_revision(
script_revision_id=script_revision_id,
organization_id=workflow.organization_id,
)
@ -5558,7 +5562,7 @@ class WorkflowService:
# Mark all episodes as reviewed
for episode in episodes:
await app.DATABASE.mark_episode_reviewed(
await app.DATABASE.scripts.mark_episode_reviewed(
episode_id=episode.episode_id,
organization_id=workflow.organization_id,
reviewer_output=str(updated_blocks) if updated_blocks else None,

View file

@ -18,7 +18,7 @@ async def get_action_history(
"""
# Get action results from the last history_window steps
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
steps = await app.DATABASE.tasks.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
# the last step is always the newly created one and it should be excluded from the history window
window_steps = steps[-1 - history_window : -1]
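# Worked example of the slice above: with 10 steps and history_window=3,
# steps[-4:-1] keeps steps 7..9 and drops the newly created step 10.
assert list(range(1, 11))[-1 - 3 : -1] == [7, 8, 9]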
if current_step:

View file

@ -53,12 +53,12 @@ async def check_running_tasks_or_workflows() -> tuple[bool, int, int]:
stale_threshold = settings.CLEANUP_STALE_TASK_THRESHOLD_HOURS
# Check tasks
active_tasks, stale_tasks = await app.DATABASE.get_running_tasks_info_globally(
active_tasks, stale_tasks = await app.DATABASE.tasks.get_running_tasks_info_globally(
stale_threshold_hours=stale_threshold
)
# Check workflow runs
active_workflows, stale_workflows = await app.DATABASE.get_running_workflow_runs_info_globally(
active_workflows, stale_workflows = await app.DATABASE.workflow_runs.get_running_workflow_runs_info_globally(
stale_threshold_hours=stale_threshold
)

View file

@ -10,12 +10,12 @@ from skyvern.services import task_v1_service, task_v2_service, webhook_service,
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
# try to see if it's a workflow run id for task v2
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
if task_v2:
run = await app.DATABASE.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
if not run:
return None
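# Condensed sketch of the lookup chain above (same repo methods as in this
# diff): runs table by run_id, then task v2 by workflow_run_id, then runs
# table again by the task's observer_cruise_id.
async def _resolve_run_sketch(run_id: str, organization_id: str | None):
    run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
    if not run:
        task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
        if task_v2:
            run = await app.DATABASE.tasks.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
    return run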
@ -73,7 +73,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
step_count=task_v1_response.step_count,
)
elif run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run.run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run.run_id, organization_id=organization_id)
if not task_v2:
return None
return await task_v2_service.build_task_v2_run_response(task_v2)
@ -83,7 +83,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=task_id)
task = await app.agent.update_task(task, status=TaskStatus.canceled)
@ -91,7 +91,7 @@ async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_k
async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(task_id, organization_id=organization_id)
if not task_v2:
raise TaskNotFound(task_id=task_id)
await task_v2_service.mark_task_v2_as_canceled(
@ -102,7 +102,7 @@ async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> No
async def cancel_workflow_run(
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
) -> None:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -110,7 +110,7 @@ async def cancel_workflow_run(
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
child_workflow_runs = await app.DATABASE.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(
parent_workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@ -128,7 +128,7 @@ async def cancel_workflow_run(
async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -156,7 +156,7 @@ async def retry_run_webhook(
) -> None:
"""Retry sending the webhook for a run."""
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -175,19 +175,19 @@ async def retry_run_webhook(
return
if run.task_run_type in [RunType.task_v1, RunType.openai_cua, RunType.anthropic_cua, RunType.ui_tars]:
task = await app.DATABASE.get_task(run_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(run_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=run_id)
latest_step = await app.DATABASE.get_latest_step(run_id, organization_id=organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(run_id, organization_id=organization_id)
if latest_step:
await app.agent.execute_task_webhook(task=task, api_key=api_key)
elif run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run_id, organization_id=organization_id)
if not task_v2:
raise TaskNotFound(task_id=run_id)
await task_v2_service.send_task_v2_webhook(task_v2)
elif run.task_run_type == RunType.workflow_run:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=run_id,
organization_id=organization_id,
)

View file

@ -184,7 +184,7 @@ class ScriptReviewer:
async def _load_run_params(run_id: str) -> tuple[str, dict[str, str]]:
try:
param_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=run_id)
param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(workflow_run_id=run_id)
params = {
wf_param.key: str(run_param.value)
for wf_param, run_param in param_tuples
@ -213,7 +213,7 @@ class ScriptReviewer:
triaged_episodes.append(episode)
else:
# Mark as reviewed so we don't re-triage on every run
await app.DATABASE.mark_episode_reviewed(
await app.DATABASE.scripts.mark_episode_reviewed(
episode_id=episode.episode_id,
organization_id=organization_id,
reviewer_output="TRIAGE: not_code_fixable — skipped",
@ -1081,20 +1081,20 @@ class ScriptReviewer:
return None
try:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script_revision_id,
organization_id=organization_id,
)
for sb in script_blocks:
if sb.script_block_label == block_label and sb.script_file_id:
# Load the code from the script file
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script_revision_id,
file_id=sb.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id=script_file.artifact_id,
organization_id=organization_id,
)
@ -1124,7 +1124,7 @@ class ScriptReviewer:
if not script_revision_id:
return {}
try:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script_revision_id,
organization_id=organization_id,
)
@ -1132,13 +1132,13 @@ class ScriptReviewer:
for sb in script_blocks:
if not sb.script_file_id:
continue
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script_revision_id,
file_id=sb.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id=script_file.artifact_id,
organization_id=organization_id,
)

View file

@ -128,7 +128,7 @@ async def build_file_tree(
try:
if pending:
# get the script file object
script_file = await app.DATABASE.get_script_file_by_path(
script_file = await app.DATABASE.scripts.get_script_file_by_path(
script_revision_id=script_revision_id,
file_path=file.path,
organization_id=organization_id,
@ -140,7 +140,7 @@ async def build_file_tree(
script_file_id=script_file.file_id,
)
continue
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
# override the actual file in the storage
asyncio.create_task(app.STORAGE.store_artifact(artifact, content_bytes))
@ -153,7 +153,7 @@ async def build_file_tree(
data=content_bytes,
)
# update the artifact_id in the script file
await app.DATABASE.update_script_file(
await app.DATABASE.scripts.update_script_file(
script_file_id=script_file.file_id,
organization_id=organization_id,
artifact_id=artifact_id,
@ -174,7 +174,7 @@ async def build_file_tree(
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
@ -202,7 +202,7 @@ async def build_file_tree(
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
@ -264,10 +264,10 @@ async def create_script(
)
try:
if run_id and not await app.DATABASE.get_run(run_id=run_id, organization_id=organization_id):
if run_id and not await app.DATABASE.tasks.get_run(run_id=run_id, organization_id=organization_id):
raise HTTPException(status_code=404, detail=f"Run_id {run_id} not found")
script = await app.DATABASE.create_script(
script = await app.DATABASE.scripts.create_script(
organization_id=organization_id,
run_id=run_id,
)
@ -306,7 +306,7 @@ async def load_scripts(
# retrieve the artifact
if not file.artifact_id:
continue
artifact = await app.DATABASE.get_artifact_by_id(file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(file.artifact_id, organization_id)
if not artifact:
LOG.error("Artifact not found", artifact_id=file.artifact_id, script_id=script.script_id)
continue
@ -345,7 +345,7 @@ async def execute_script(
background_tasks: BackgroundTasks | None = None,
) -> None:
# step 1: get the script revision
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=organization_id,
)
@ -353,7 +353,7 @@ async def execute_script(
raise ScriptNotFound(script_id=script_id)
# step 2: get the script files
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=script.script_revision_id, organization_id=organization_id
)
@ -362,7 +362,7 @@ async def execute_script(
# step 4: execute the script
if workflow_run_id and not parameters:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
LOG.info("Script run Parameters is using workflow run parameters", parameters=parameters)
@ -453,7 +453,7 @@ async def _create_workflow_block_run_and_task(
current_value_str = str(cv) if cv is not None else None
current_index_val = context.loop_metadata.get("current_index")
workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.create_workflow_run_block(
workflow_run_id=workflow_run_id,
parent_workflow_run_block_id=context.parent_workflow_run_block_id,
organization_id=organization_id,
@ -486,7 +486,7 @@ async def _create_workflow_block_run_and_task(
# Without this, the LLM only sees a generic navigation_goal and can
# falsely mark login as complete when credentials were never entered.
task_complete_criterion = DEFAULT_LOGIN_COMPLETE_CRITERION if block_type == BlockType.LOGIN else None
task = await app.DATABASE.create_task(
task = await app.DATABASE.tasks.create_task(
# HACK: url was changed to str | None to support a None url; url is not used in the script right now.
url=url or "",
title=f"Script {block_type.value} task",
@ -508,7 +508,7 @@ async def _create_workflow_block_run_and_task(
task_id = task.task_id
# create a single step for the task
step = await app.DATABASE.create_step(
step = await app.DATABASE.tasks.create_step(
task_id=task_id,
order=0,
retry_index=0,
@ -525,7 +525,7 @@ async def _create_workflow_block_run_and_task(
)
# Update workflow run block with task_id
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task_id,
organization_id=organization_id,
@ -616,7 +616,7 @@ async def _record_output_parameter_value(
parameter=output_parameter,
value=output,
)
await app.DATABASE.create_or_update_workflow_run_output_parameter(
await app.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=output_parameter.output_parameter_id,
value=output,
@ -672,7 +672,7 @@ async def _update_workflow_block(
errors=errors,
)
await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
step_id=step_id,
task_id=task_id,
organization_id=context.organization_id,
@ -680,7 +680,7 @@ async def _update_workflow_block(
is_last=is_last,
output=step_output,
)
updated_task = await app.DATABASE.update_task(
updated_task = await app.DATABASE.tasks.update_task(
task_id=task_id,
organization_id=context.organization_id,
status=task_status,
@ -719,7 +719,7 @@ async def _update_workflow_block(
final_output = task_output.model_dump()
step_for_billing: Step | None = None
if step_id:
step_for_billing = await app.DATABASE.get_step(
step_for_billing = await app.DATABASE.tasks.get_step(
step_id=step_id,
organization_id=context.organization_id,
)
@ -746,7 +746,7 @@ async def _update_workflow_block(
# final_output is already set to `output` at line 596.
pass
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=context.organization_id if context else None,
status=status,
@ -855,7 +855,7 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
return
try:
script_block = await app.DATABASE.get_script_block_by_label(
script_block = await app.DATABASE.scripts.get_script_block_by_label(
organization_id=context.organization_id,
script_revision_id=context.script_revision_id,
script_block_label=cache_key,
@ -873,7 +873,7 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
return
try:
source_block = await app.DATABASE.get_workflow_run_block(
source_block = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=context.organization_id,
)
@ -886,7 +886,9 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
try:
# actions are ordered by created_at
actions = await app.DATABASE.get_task_actions_hydrated(task_id=task_id, organization_id=context.organization_id)
actions = await app.DATABASE.tasks.get_task_actions_hydrated(
task_id=task_id, organization_id=context.organization_id
)
except Exception:
return
@ -938,7 +940,7 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
)
step = None
if step_id:
step = await app.DATABASE.get_step(step_id=step_id, organization_id=context.organization_id)
step = await app.DATABASE.tasks.get_step(step_id=step_id, organization_id=context.organization_id)
llm_response = await app.SCRIPT_GENERATION_LLM_API_HANDLER(
prompt=merged_prompt,
prompt_name="merged-block-inputs",
@ -1102,7 +1104,7 @@ async def _fallback_to_ai_run(
workflow_run_id=workflow_run_id,
)
# 1. fail the previous step
previous_step = await app.DATABASE.update_step(
previous_step = await app.DATABASE.tasks.update_step(
step_id=script_step_id,
task_id=task_id,
organization_id=organization_id,
@ -1112,7 +1114,7 @@ async def _fallback_to_ai_run(
organization = await app.DATABASE.organizations.get_organization(organization_id=organization_id)
if not organization:
raise Exception(f"Organization is missing organization_id={organization_id}")
task = await app.DATABASE.get_task(task_id=context.task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=context.task_id, organization_id=organization_id)
if not task:
raise Exception(f"Task is missing task_id={context.task_id}")
workflow = await app.DATABASE.workflows.get_workflow(
@ -1120,7 +1122,7 @@ async def _fallback_to_ai_run(
)
if not workflow:
return
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:
@ -1157,7 +1159,7 @@ async def _fallback_to_ai_run(
if detected_errors:
task_errors = task.errors or []
task_errors.extend([error.model_dump() for error in detected_errors])
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task_id,
organization_id=organization_id,
errors=task_errors,
@ -1241,7 +1243,7 @@ async def _fallback_to_ai_run(
# _fallback_to_ai_run is only called for TaskBlock-style blocks (navigation,
# extraction, action, login, download), never for ConditionalBlock, so
# fallback_type is always "full_block" here.
episode = await app.DATABASE.create_fallback_episode(
episode = await app.DATABASE.scripts.create_fallback_episode(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
@ -1262,7 +1264,7 @@ async def _fallback_to_ai_run(
)
# 2. create a new step for ai run
ai_step = await app.DATABASE.create_step(
ai_step = await app.DATABASE.tasks.create_step(
task_id=task_id,
organization_id=organization_id,
order=previous_step.order + 1,
@ -1323,7 +1325,7 @@ async def _fallback_to_ai_run(
# update workflow run to indicate that there's a script run
if workflow_run_id:
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
ai_fallback_triggered=True,
)
@ -1332,7 +1334,7 @@ async def _fallback_to_ai_run(
if workflow_run_block_id:
# refresh the task
failure_reason = None
refreshed_task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
refreshed_task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if refreshed_task:
task = refreshed_task
if task.status in [TaskStatus.terminated, TaskStatus.failed]:
@ -1389,7 +1391,7 @@ async def _fallback_to_ai_run(
if not fallback_succeeded and task.failure_reason:
agent_actions_summary["failure_reason"] = str(task.failure_reason)[:2000]
try:
actions = await app.DATABASE.get_task_actions(
actions = await app.DATABASE.tasks.get_task_actions(
task_id=task_id,
organization_id=organization_id,
)
@ -1397,7 +1399,7 @@ async def _fallback_to_ai_run(
except Exception:
LOG.debug("Could not fetch actions for fallback episode", exc_info=True)
await app.DATABASE.update_fallback_episode(
await app.DATABASE.scripts.update_fallback_episode(
episode_id=fallback_episode_id,
organization_id=organization_id,
agent_actions=agent_actions_summary,
@ -1468,7 +1470,9 @@ async def _regenerate_script_block_after_ai_fallback(
cache_key_value = ""
if workflow.cache_key:
try:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run_id
)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
cache_key_value = jinja_sandbox_env.from_string(workflow.cache_key).render(parameters)
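# Illustrative render, assuming jinja_sandbox_env is a
# jinja2.sandbox.SandboxedEnvironment (template and values are hypothetical):
rendered = jinja_sandbox_env.from_string("{{ env }}-{{ region }}").render({"env": "prod", "region": "us"})
assert rendered == "prod-us"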
except Exception as e:
@ -1479,7 +1483,7 @@ async def _regenerate_script_block_after_ai_fallback(
if not cache_key_value:
cache_key_value = cache_key # Fallback
existing_script, _is_pinned = await app.DATABASE.get_workflow_script_by_cache_key_value(
existing_script, _is_pinned = await app.DATABASE.scripts.get_workflow_script_by_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=cache_key_value,
@ -1501,7 +1505,7 @@ async def _regenerate_script_block_after_ai_fallback(
)
# Create a new script version
new_script = await app.DATABASE.create_script(
new_script = await app.DATABASE.scripts.create_script(
organization_id=organization_id,
run_id=workflow_run_id,
script_id=current_script.script_id, # Use same script_id for versioning
@ -1509,14 +1513,14 @@ async def _regenerate_script_block_after_ai_fallback(
)
# deprecate the current workflow script
await app.DATABASE.delete_workflow_cache_key_value(
await app.DATABASE.scripts.delete_workflow_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=cache_key_value,
)
# Create workflow script mapping for the new version
await app.DATABASE.create_workflow_script(
await app.DATABASE.scripts.create_workflow_script(
organization_id=organization_id,
script_id=new_script.script_id,
workflow_permanent_id=workflow.workflow_permanent_id,
@ -1527,7 +1531,7 @@ async def _regenerate_script_block_after_ai_fallback(
)
# Get all existing script blocks from the previous version
existing_script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
existing_script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=current_script.script_revision_id,
organization_id=organization_id,
)
@ -1553,7 +1557,7 @@ async def _regenerate_script_block_after_ai_fallback(
# Copy the existing block to the new version
# Get the script file content for this block and copy a new script block for it
if existing_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=current_script.script_revision_id,
file_id=existing_block.script_file_id,
organization_id=organization_id,
@ -1561,7 +1565,9 @@ async def _regenerate_script_block_after_ai_fallback(
if script_file and script_file.artifact_id:
# Retrieve the artifact content
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
script_file.artifact_id, organization_id
)
if artifact:
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if file_content:
@ -1650,7 +1656,7 @@ async def _get_block_definition_by_label(
if not final_dump:
return None
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if task:
task_dump = task.model_dump()
final_dump.update({k: v for k, v in task_dump.items() if k not in final_dump})
@ -1681,7 +1687,7 @@ async def _generate_block_code_from_task(
return ""
try:
# Now regenerate only the specific block that fell back to AI
task_actions = await app.DATABASE.get_task_actions_hydrated(
task_actions = await app.DATABASE.tasks.get_task_actions_hydrated(
task_id=task_id,
organization_id=organization_id,
)
@ -2505,13 +2511,13 @@ async def run_script(
context.script_revision_id = script_revision_id
if workflow_run_id and organization_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
# update workflow run to indicate that there's a script run
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
ai_fallback_triggered=False,
)

View file

@ -11,14 +11,14 @@ async def is_cua_task(
if task.workflow_run_id:
# it's a task-based block; look up the block run to see if it's a CUA task
block = await app.DATABASE.get_workflow_run_block_by_task_id(
block = await app.DATABASE.observer.get_workflow_run_block_by_task_id(
task_id=task.task_id,
organization_id=task.organization_id,
)
if block.engine is not None and block.engine in CUA_ENGINES:
return True
run = await app.DATABASE.get_run(
run = await app.DATABASE.tasks.get_run(
run_id=task.task_id,
organization_id=task.organization_id,
)

View file

@ -25,11 +25,11 @@ async def generate_task(user_prompt: str, organization: Organization) -> TaskGen
user_prompt_hash = hash_object.hexdigest()
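# Hedged sketch of the digest above (the exact algorithm is an assumption;
# any stable hash of the prompt works as a cache lookup key):
import hashlib
assert len(hashlib.sha256(b"book a flight to SFO").hexdigest()) == 64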
# check if the same user_prompt was seen within the past X hours
# in the future, we can use a vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
existing_task_generation = await app.DATABASE.workflow_params.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
)
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
new_task_generation = await app.DATABASE.workflow_params.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
@ -53,7 +53,7 @@ async def generate_task(user_prompt: str, organization: Organization) -> TaskGen
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
# generate a TaskGenerationModel
task_generation = await app.DATABASE.create_task_generation(
task_generation = await app.DATABASE.workflow_params.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
@ -94,7 +94,7 @@ async def run_task(
run_type = RunType.anthropic_cua
elif engine == RunEngine.ui_tars:
run_type = RunType.ui_tars
await app.DATABASE.create_task_run(
await app.DATABASE.tasks.create_task_run(
task_run_type=run_type,
organization_id=organization.organization_id,
run_id=created_task.task_id,
@ -123,15 +123,15 @@ async def run_task(
async def get_task_v1_response(task_id: str, organization_id: str | None = None) -> TaskResponse:
task_obj = await app.DATABASE.get_task(task_id, organization_id=organization_id)
task_obj = await app.DATABASE.tasks.get_task(task_id, organization_id=organization_id)
if not task_obj:
raise TaskNotFound(task_id=task_id)
# get step count efficiently via COUNT query
step_count = await app.DATABASE.get_task_step_count(task_id, organization_id)
step_count = await app.DATABASE.tasks.get_task_step_count(task_id, organization_id)
# get latest step
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(task_id, organization_id=organization_id)
if not latest_step:
return await app.agent.build_task_response(task=task_obj, step_count=step_count)

View file

@ -119,14 +119,14 @@ async def _summarize_max_steps_failure_reason(
screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=str(task_v2.url), draw_boxes=False)
run_blocks = await app.DATABASE.get_workflow_run_blocks(
run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=task_v2.workflow_run_id,
organization_id=organization_id,
)
history = [f"{idx + 1}. {block.description} -- {block.status}" for idx, block in enumerate(run_blocks[::-1])]
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=task_v2.workflow_run_id,
@ -197,7 +197,7 @@ async def _handle_task_v2_termination(
)
# Create a dedicated termination thought for UI visibility
termination_thought = await app.DATABASE.create_thought(
termination_thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2_id,
organization_id=organization_id,
workflow_run_id=workflow_run_id,
@ -216,7 +216,7 @@ async def _handle_task_v2_termination(
if source:
output["source"] = source
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=termination_thought.observer_thought_id,
organization_id=organization_id,
output=output,
@ -265,7 +265,7 @@ async def initialize_task_v2(
browser_address: str | None = None,
run_with: str | None = None,
) -> TaskV2:
task_v2 = await app.DATABASE.create_task_v2(
task_v2 = await app.DATABASE.observer.create_task_v2(
prompt=user_prompt,
url=user_url if user_url else None,
organization_id=organization.organization_id,
@ -329,7 +329,7 @@ async def initialize_task_v2(
# update observer cruise
try:
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=new_workflow.workflow_id,
@ -337,7 +337,7 @@ async def initialize_task_v2(
organization_id=organization.organization_id,
)
if create_task_run:
await app.DATABASE.create_task_run(
await app.DATABASE.tasks.create_task_run(
task_run_type=RunType.task_v2,
organization_id=organization.organization_id,
run_id=task_v2.observer_cruise_id,
@ -369,7 +369,7 @@ async def initialize_task_v2_metadata(
current_browser_url: str | None,
user_url: str | None,
) -> TaskV2:
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=organization.organization_id,
thought_type=ThoughtType.metadata,
@ -403,7 +403,7 @@ async def initialize_task_v2_metadata(
raise UrlGenerationFailure()
try:
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=organization.organization_id,
workflow_run_id=workflow_run.workflow_run_id,
@ -422,7 +422,7 @@ async def initialize_task_v2_metadata(
organization_id=organization.organization_id,
title=metadata.workflow_title,
)
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
@ -430,11 +430,11 @@ async def initialize_task_v2_metadata(
url=metadata.url,
organization_id=organization.organization_id,
)
task_run = await app.DATABASE.get_run(
task_run = await app.DATABASE.tasks.get_run(
run_id=task_v2.observer_cruise_id, organization_id=organization.organization_id
)
if task_run:
await app.DATABASE.update_task_run(
await app.DATABASE.tasks.update_task_run(
organization_id=organization.organization_id,
run_id=task_v2.observer_cruise_id,
title=metadata.workflow_title,
@ -465,7 +465,7 @@ async def run_task_v2(
) -> TaskV2:
organization_id = organization.organization_id
try:
task_v2 = await app.DATABASE.get_task_v2(task_v2_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(task_v2_id, organization_id=organization_id)
except Exception:
LOG.error(
"Failed to get task v2",
@ -643,7 +643,7 @@ async def run_task_v2_helper(
)
)
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2_id, organization_id=organization_id, status=TaskV2Status.running
)
await app.WORKFLOW_SERVICE.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
@ -802,7 +802,7 @@ async def run_task_v2_helper(
task_history=task_history,
local_datetime=datetime.now(context.tz_info).isoformat(),
)
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2_id,
organization_id=organization_id,
workflow_run_id=workflow_run.workflow_run_id,
@ -837,7 +837,7 @@ async def run_task_v2_helper(
plan = task_v2_response.get("plan", "")
task_type = task_v2_response.get("task_type", "")
# Create and save task thought
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=organization_id,
thought=thoughts,
@ -1055,7 +1055,7 @@ async def run_task_v2_helper(
task_history=task_history,
local_datetime=datetime.now(context.tz_info).isoformat(),
)
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2_id,
organization_id=organization_id,
workflow_run_id=workflow_run_id,
@ -1082,7 +1082,7 @@ async def run_task_v2_helper(
termination_reason = completion_resp.get("termination_reason")
completion_failure_categories = completion_resp.get("failure_categories")
thought_content = completion_resp.get("thoughts", "")
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=organization_id,
thought=thought_content,
@ -1128,8 +1128,8 @@ async def run_task_v2_helper(
return workflow, workflow_run, task_v2
# total step number validation
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
total_step_count = await app.DATABASE.get_total_unique_step_order_count_by_task_ids(
workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
total_step_count = await app.DATABASE.tasks.get_total_unique_step_order_count_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks],
organization_id=organization_id,
)
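# Worked example (assumption: retried steps reuse their original order, per
# the retry_index field): orders [0, 1, 1, 2] count as 3 unique steps, so
# retries are not double-counted against the max-step limit.
assert len({0, 1, 1, 2}) == 3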
@ -1314,7 +1314,7 @@ async def _generate_loop_task(
plan=plan,
)
data_extraction_thought = f"Going to generate a list of values to go through based on the plan: {plan}."
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=workflow_run_id,
@ -1383,7 +1383,7 @@ async def _generate_loop_task(
raise
# update the thought
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=task_v2.organization_id,
output=output_value_obj,
@ -1441,7 +1441,7 @@ async def _generate_loop_task(
is_link=is_loop_value_link,
loop_values=loop_values,
)
thought_task_in_loop = await app.DATABASE.create_thought(
thought_task_in_loop = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=workflow_run_id,
@ -1461,7 +1461,7 @@ async def _generate_loop_task(
data_extraction_goal = task_in_loop_metadata_response.get("data_extraction_goal")
data_extraction_schema = task_in_loop_metadata_response.get("data_schema")
thought_content = task_in_loop_metadata_response.get("thoughts")
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought_task_in_loop.observer_thought_id,
organization_id=task_v2.organization_id,
thought=thought_content,
@ -1675,7 +1675,7 @@ async def _generate_goto_url_task(
async def get_thought_timelines(*, task_v2_id: str, organization_id: str) -> list[WorkflowRunTimeline]:
thoughts = await app.DATABASE.get_thoughts(
thoughts = await app.DATABASE.observer.get_thoughts(
task_v2_id=task_v2_id,
organization_id=organization_id,
thought_types=[
@ -1695,7 +1695,7 @@ async def get_thought_timelines(*, task_v2_id: str, organization_id: str) -> lis
async def get_task_v2(task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
return await app.DATABASE.get_task_v2(task_v2_id, organization_id=organization_id)
return await app.DATABASE.observer.get_task_v2(task_v2_id, organization_id=organization_id)
async def _update_task_v2_status(
@ -1706,7 +1706,7 @@ async def _update_task_v2_status(
output: dict[str, Any] | None = None,
failure_category: list[dict] | None = None,
) -> TaskV2:
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id,
organization_id=organization_id,
status=status,
@ -1939,7 +1939,7 @@ async def _summarize_task_v2(
context: SkyvernContext,
screenshots: list[bytes] | None = None,
) -> TaskV2:
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=task_v2.workflow_run_id,
@ -1966,7 +1966,7 @@ async def _summarize_task_v2(
summary_description = task_v2_summary_resp.get("description")
summarized_output = task_v2_summary_resp.get("output")
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=task_v2.organization_id,
thought=summary_description,
@ -2081,7 +2081,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
webhook_failure_reason="",
@ -2094,7 +2094,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
webhook_failure_reason=f"Webhook failed with status code {resp.status_code}, error message: {resp.text}",

View file

@ -267,13 +267,13 @@ async def replay_run_webhook(
async def _build_webhook_payload(organization_id: str, run_id: str) -> _WebhookPayload:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
# Attempt to resolve task v2 runs that may not yet be in the runs table.
task_v2 = await app.DATABASE.get_task_v2(run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run_id, organization_id=organization_id)
if task_v2:
return await _build_task_v2_payload(task_v2)
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=run_id,
organization_id=organization_id,
)
@ -300,7 +300,7 @@ async def _build_webhook_payload(organization_id: str, run_id: str) -> _WebhookP
run_type_str=run_type,
)
if run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run.run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run.run_id, organization_id=organization_id)
if not task_v2:
raise SkyvernHTTPException(
f"Task v2 run {run_id} missing task record",
@ -314,7 +314,7 @@ async def _build_webhook_payload(organization_id: str, run_id: str) -> _WebhookP
async def _build_task_payload(organization_id: str, run_id: str, run_type_str: str) -> _WebhookPayload:
task: Task | None = await app.DATABASE.get_task(run_id, organization_id=organization_id)
task: Task | None = await app.DATABASE.tasks.get_task(run_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=run_id)
if not task.status.is_final():
@ -324,7 +324,7 @@ async def _build_task_payload(organization_id: str, run_id: str, run_type_str: s
status=task.status,
)
raise WebhookReplayError(f"Run {run_id} has not reached a terminal state (status={task.status}).")
latest_step = await app.DATABASE.get_latest_step(run_id, organization_id=organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(run_id, organization_id=organization_id)
task_response = await app.agent.build_task_response(task=task, last_step=latest_step)
payload_dict = json.loads(task_response.model_dump_json(exclude={"request"}))
@ -382,7 +382,7 @@ async def _build_workflow_payload(
organization_id: str,
workflow_run_id: str,
) -> _WebhookPayload:
workflow_run: WorkflowRun | None = await app.DATABASE.get_workflow_run(
workflow_run: WorkflowRun | None = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

View file

@ -147,10 +147,12 @@ async def generate_or_update_pending_workflow_script(
script_id = context.script_id
script = None
if script_id:
script = await app.DATABASE.get_script(script_id=script_id, organization_id=organization_id)
script = await app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id)
if not script:
script = await app.DATABASE.create_script(organization_id=organization_id, run_id=workflow_run.workflow_run_id)
script = await app.DATABASE.scripts.create_script(
organization_id=organization_id, run_id=workflow_run.workflow_run_id
)
if context:
context.script_id = script.script_id
context.script_revision_id = script.script_revision_id
@ -184,7 +186,7 @@ async def get_workflow_script(
rendered_cache_key_value = ""
try:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
@ -296,7 +298,7 @@ async def get_workflow_script_by_cache_key_value(
return _workflow_script_cache[cache_key_tuple]
# Cache miss - fetch from database
script, is_pinned = await app.DATABASE.get_workflow_script_by_cache_key_value(
script, is_pinned = await app.DATABASE.scripts.get_workflow_script_by_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=cache_key_value,
@ -311,7 +313,7 @@ async def get_workflow_script_by_cache_key_value(
return script, is_pinned
return await app.DATABASE.get_workflow_script_by_cache_key_value(
return await app.DATABASE.scripts.get_workflow_script_by_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=cache_key_value,
@ -331,7 +333,7 @@ async def get_latest_published_script(
variants), this returns the script with the highest version number to ensure
the most recently reviewed code is selected.
"""
workflow_scripts = await app.DATABASE.get_workflow_scripts_by_permanent_id(
workflow_scripts = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
statuses=[ScriptStatus.published],
@ -344,7 +346,7 @@ async def get_latest_published_script(
# TODO: add a bulk get_latest_script_versions() if this becomes a bottleneck.
best: Script | None = None
for ws in workflow_scripts:
script = await app.DATABASE.get_latest_script_version(
script = await app.DATABASE.scripts.get_latest_script_version(
script_id=ws.script_id,
organization_id=organization_id,
)
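# Sketch of the selection rule from the docstring above (version is assumed
# to be an integer attribute on Script):
if script and (best is None or script.version > best.version):
    best = script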
@ -362,7 +364,7 @@ async def _load_cached_script_block_sources(
"""
cached_blocks: dict[str, ScriptBlockSource] = {}
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)
@ -373,13 +375,13 @@ async def _load_cached_script_block_sources(
code_str: str | None = None
if script_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script.script_revision_id,
file_id=script_block.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if isinstance(file_content, bytes):
@ -539,7 +541,7 @@ async def generate_workflow_script(
status = ScriptStatus.published
if pending:
status = ScriptStatus.pending
existing_pending_workflow_script = await app.DATABASE.get_workflow_script(
existing_pending_workflow_script = await app.DATABASE.scripts.get_workflow_script(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run.workflow_run_id,
@ -547,7 +549,7 @@ async def generate_workflow_script(
)
if not existing_pending_workflow_script:
# Record the workflow->script mapping for cache lookup
await app.DATABASE.create_workflow_script(
await app.DATABASE.scripts.create_workflow_script(
organization_id=workflow.organization_id,
script_id=script.script_id,
workflow_permanent_id=workflow.workflow_permanent_id,
@ -965,13 +967,13 @@ async def _find_main_py_content(script_id: str, organization_id: str, base_revis
"""
async def _load_main_py_from_revision(revision_id: str) -> str | None:
files = await app.DATABASE.get_script_files(
files = await app.DATABASE.scripts.get_script_files(
script_revision_id=revision_id,
organization_id=organization_id,
)
for f in files:
if f.file_path == "main.py" and f.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(f.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(f.artifact_id, organization_id)
if artifact:
content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if content:
@ -984,7 +986,7 @@ async def _find_main_py_content(script_id: str, organization_id: str, base_revis
return result
# Fall back to v1 (bootstrapping: base was created before this fix)
v1_script = await app.DATABASE.get_script(
v1_script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=organization_id,
version=1,
@ -1022,7 +1024,7 @@ async def create_script_version_from_review(
# _trigger_script_reviewer() already gates on is_script_pinned(), but
# that check can be bypassed when the skyvern context is missing.
# Guard here so no code path can mutate a pinned script.
if await app.DATABASE.is_script_pinned(
if await app.DATABASE.scripts.is_script_pinned(
organization_id=organization_id,
script_id=base_script.script_id,
):
@ -1035,7 +1037,7 @@ async def create_script_version_from_review(
return None
# Create a new script version
new_script = await app.DATABASE.create_script(
new_script = await app.DATABASE.scripts.create_script(
organization_id=organization_id,
script_id=base_script.script_id,
version=base_script.version + 1,
@ -1043,7 +1045,7 @@ async def create_script_version_from_review(
)
# Copy existing script blocks from the base revision
existing_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
existing_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=base_script.script_revision_id,
organization_id=organization_id,
)
@ -1066,7 +1068,7 @@ async def create_script_version_from_review(
file_path=file_path,
data=content_bytes,
)
new_file = await app.DATABASE.create_script_file(
new_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,
@ -1090,7 +1092,7 @@ async def create_script_version_from_review(
)
# Create script block entry pointing to the new file
await app.DATABASE.create_script_block(
await app.DATABASE.scripts.create_script_block(
organization_id=organization_id,
script_id=new_script.script_id,
script_revision_id=new_script.script_revision_id,
@ -1103,7 +1105,7 @@ async def create_script_version_from_review(
)
else:
# Copy existing block as-is
await app.DATABASE.create_script_block(
await app.DATABASE.scripts.create_script_block(
organization_id=organization_id,
script_id=new_script.script_id,
script_revision_id=new_script.script_revision_id,
@ -1136,7 +1138,7 @@ async def create_script_version_from_review(
file_path="main.py",
data=patched_bytes,
)
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,
@ -1158,19 +1160,19 @@ async def create_script_version_from_review(
# Copy non-block files (e.g., .skyvern metadata) from the base revision
# or v1 — whichever has the full file set
source_files = await app.DATABASE.get_script_files(
source_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=base_script.script_revision_id,
organization_id=organization_id,
)
if not any(f.file_path != "main.py" and not f.file_path.startswith("blocks/") for f in source_files):
# Base revision has no non-block files, fall back to v1
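# Example of the predicate above: with only main.py and blocks/*, any(...)
# is False and we fall back to v1; a .skyvern metadata file would flip it.
paths = ["main.py", "blocks/login.py"]
assert not any(p != "main.py" and not p.startswith("blocks/") for p in paths)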
v1_script = await app.DATABASE.get_script(
v1_script = await app.DATABASE.scripts.get_script(
script_id=base_script.script_id,
organization_id=organization_id,
version=1,
)
if v1_script:
source_files = await app.DATABASE.get_script_files(
source_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=v1_script.script_revision_id,
organization_id=organization_id,
)
@ -1180,7 +1182,7 @@ async def create_script_version_from_review(
# Skip main.py (already patched) and updated block files (already created)
if f.file_path == "main.py" or f.file_path in updated_block_file_paths:
continue
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,
@ -1208,7 +1210,7 @@ async def create_script_version_from_review(
)
else:
# No workflow run — look up the existing cache key value from the base script
existing_ws = await app.DATABASE.get_workflow_scripts_by_permanent_id(
existing_ws = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
statuses=[ScriptStatus.published],
@ -1219,7 +1221,7 @@ async def create_script_version_from_review(
rendered_cache_key_value = ws.cache_key_value
break
await app.DATABASE.create_workflow_script(
await app.DATABASE.scripts.create_workflow_script(
organization_id=organization_id,
script_id=new_script.script_id,
workflow_permanent_id=workflow_permanent_id,

View file

@ -55,7 +55,7 @@ async def prepare_workflow(
version=version,
)
await app.DATABASE.create_task_run(
await app.DATABASE.tasks.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=organization.organization_id,
run_id=workflow_run.workflow_run_id,
@ -120,7 +120,7 @@ async def run_workflow(
async def get_workflow_run_response(
workflow_run_id: str, organization_id: str | None = None
) -> WorkflowRunResponse | None:
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id=organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id, organization_id=organization_id)
if not workflow_run:
return None
workflow_run_resp = await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id(

View file

@ -24,14 +24,14 @@ async def _retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPag
# V0: use the previous action plan if there is a completed task with the same url and navigation goal
# get completed task with the same url and navigation goal
# TODO(kerem): don't use step_order, get all the previous actions instead
cached_actions = await app.DATABASE.retrieve_action_plan(task=task)
cached_actions = await app.DATABASE.workflow_params.retrieve_action_plan(task=task)
if not cached_actions:
LOG.info("No cached actions found for the task, fallback to no-cache mode")
return []
# Get the existing actions for this task from the database. Then find the actions that are already executed by looking at
# the source_action_id field for this task's actions.
previous_actions = await app.DATABASE.get_previous_actions_for_task(task_id=task.task_id)
previous_actions = await app.DATABASE.tasks.get_previous_actions_for_task(task_id=task.task_id)
executed_cached_actions = []
remaining_cached_actions = []

View file

@ -409,7 +409,7 @@ class ActionHandler:
page=page,
action=action,
)
persisted_action = await app.DATABASE.create_action(action=action)
persisted_action = await app.DATABASE.workflow_params.create_action(action=action)
action.action_id = persisted_action.action_id
return results
@ -545,7 +545,7 @@ class ActionHandler:
exc_info=True,
)
persisted_action = await app.DATABASE.create_action(action=action)
persisted_action = await app.DATABASE.workflow_params.create_action(action=action)
action.action_id = persisted_action.action_id
@staticmethod

View file

@ -847,7 +847,7 @@ async def generate_cua_fallback_actions(
assistant_message=assistant_message,
reasoning=reasoning,
)
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task.task_id,
organization_id=task.organization_id,
extracted_information=assistant_message,

View file

@ -45,7 +45,7 @@ async def validate_session_for_renewal(
Validate a specific browser session for renewal; raise if validation fails.
"""
browser_session = await database.get_persistent_browser_session(
browser_session = await database.browser_sessions.get_persistent_browser_session(
session_id=session_id,
organization_id=organization_id,
)
@ -117,7 +117,7 @@ async def renew_session(database: AgentDB, session_id: str, organization_id: str
minutes_diff = floor((new_timeout_datetime - current_timeout_datetime).total_seconds() / 60)
new_timeout_minutes = current_timeout_minutes + minutes_diff
browser_session = await database.update_persistent_browser_session(
browser_session = await database.browser_sessions.update_persistent_browser_session(
session_id,
organization_id=organization_id,
timeout_minutes=new_timeout_minutes,
@ -138,7 +138,7 @@ async def renew_session(database: AgentDB, session_id: str, organization_id: str
async def update_status(
db: AgentDB, session_id: str, organization_id: str, status: str
) -> PersistentBrowserSession | None:
persistent_browser_session = await db.get_persistent_browser_session(session_id, organization_id)
persistent_browser_session = await db.browser_sessions.get_persistent_browser_session(session_id, organization_id)
if not persistent_browser_session:
LOG.warning(
@ -167,7 +167,7 @@ async def update_status(
)
completed_at = datetime.now(timezone.utc) if is_final_status(status) else None
persistent_browser_session = await db.update_persistent_browser_session(
persistent_browser_session = await db.browser_sessions.update_persistent_browser_session(
session_id,
status=status,
organization_id=organization_id,
@ -215,7 +215,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
LOG.info("Begin browser session", browser_session_id=browser_session_id)
persistent_browser_session = await self.database.get_persistent_browser_session(
persistent_browser_session = await self.database.browser_sessions.get_persistent_browser_session(
browser_session_id, organization_id
)
@ -246,11 +246,13 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
self, runnable_id: str, organization_id: str
) -> PersistentBrowserSession | None:
"""Get a specific browser session by runnable ID."""
return await self.database.get_persistent_browser_session_by_runnable_id(runnable_id, organization_id)
return await self.database.browser_sessions.get_persistent_browser_session_by_runnable_id(
runnable_id, organization_id
)
async def get_active_sessions(self, organization_id: str) -> list[PersistentBrowserSession]:
"""Get all active sessions for an organization."""
return await self.database.get_active_persistent_browser_sessions(organization_id)
return await self.database.browser_sessions.get_active_persistent_browser_sessions(organization_id)
async def get_browser_state(self, session_id: str, organization_id: str | None = None) -> BrowserState | None:
"""Get a specific browser session's state by session ID."""
@ -263,7 +265,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
async def get_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession | None:
"""Get a specific browser session by session ID."""
return await self.database.get_persistent_browser_session(session_id, organization_id)
return await self.database.browser_sessions.get_persistent_browser_session(session_id, organization_id)
async def create_session(
self,
@ -283,7 +285,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"Creating new browser session",
organization_id=organization_id,
)
session = await self.database.create_persistent_browser_session(
session = await self.database.browser_sessions.create_persistent_browser_session(
organization_id=organization_id,
runnable_type=runnable_type,
runnable_id=runnable_id,
@ -350,7 +352,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
await browser_state.close()
return
# Set started_at so renewal knows the browser is live
await self.database.update_persistent_browser_session(
await self.database.browser_sessions.update_persistent_browser_session(
session_id,
organization_id=organization_id,
started_at=datetime.now(timezone.utc),
@ -375,7 +377,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
organization_id: str,
) -> None:
"""Occupy a specific browser session."""
await self.database.occupy_persistent_browser_session(
await self.database.browser_sessions.occupy_persistent_browser_session(
session_id=session_id,
runnable_type=runnable_type,
runnable_id=runnable_id,
@ -410,7 +412,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
async def release_browser_session(self, session_id: str, organization_id: str) -> None:
"""Release a specific browser session."""
await self.database.release_persistent_browser_session(session_id, organization_id)
await self.database.browser_sessions.release_persistent_browser_session(session_id, organization_id)
async def close_session(self, organization_id: str, browser_session_id: str) -> None:
"""Close a specific browser session."""
@ -491,13 +493,13 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
session_id=browser_session_id,
)
await self.database.close_persistent_browser_session(browser_session_id, organization_id)
await self.database.browser_sessions.close_persistent_browser_session(browser_session_id, organization_id)
if settings.BROWSER_STREAMING_MODE == "cdp":
await self.database.archive_browser_session_address(browser_session_id, organization_id)
await self.database.browser_sessions.archive_browser_session_address(browser_session_id, organization_id)
async def close_all_sessions(self, organization_id: str) -> None:
"""Close all browser sessions for an organization."""
browser_sessions = await self.database.get_active_persistent_browser_sessions(organization_id)
browser_sessions = await self.database.browser_sessions.get_active_persistent_browser_sessions(organization_id)
for browser_session in browser_sessions:
await self.close_session(organization_id, browser_session.persistent_browser_session_id)
@ -505,17 +507,17 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"""Close sessions left active by a previous process."""
if settings.BROWSER_STREAMING_MODE != "cdp":
return
stale_sessions = await self.database.get_uncompleted_persistent_browser_sessions()
stale_sessions = await self.database.browser_sessions.get_uncompleted_persistent_browser_sessions()
for db_session in stale_sessions:
LOG.info(
"Closing stale browser session from previous run",
session_id=db_session.persistent_browser_session_id,
organization_id=db_session.organization_id,
)
await self.database.close_persistent_browser_session(
await self.database.browser_sessions.close_persistent_browser_session(
db_session.persistent_browser_session_id, db_session.organization_id
)
await self.database.archive_browser_session_address(
await self.database.browser_sessions.archive_browser_session_address(
db_session.persistent_browser_session_id, db_session.organization_id
)
@ -524,7 +526,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"""Close all browser sessions across all organizations."""
LOG.info("Closing PersistentSessionsManager")
if cls.instance:
active_sessions = await cls.instance.database.get_all_active_persistent_browser_sessions()
active_sessions = await cls.instance.database.browser_sessions.get_all_active_persistent_browser_sessions()
for db_session in active_sessions:
await cls.instance.close_session(db_session.organization_id, db_session.persistent_browser_session_id)
LOG.info("PersistentSessionsManager is closed")

View file

@ -22,21 +22,21 @@ def create_forge_stub_app() -> ForgeApp:
fake_app_module.AGENT_FUNCTION.validate_block_execution = AsyncMock()
fake_app_module.AGENT_FUNCTION.validate_code_block = AsyncMock()
fake_app_module.agent = _LazyNamespace()
fake_app_module.DATABASE.update_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.get_last_task_for_workflow_run = AsyncMock()
fake_app_module.DATABASE.get_workflow_run = AsyncMock()
fake_app_module.DATABASE.get_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.get_task = AsyncMock()
fake_app_module.DATABASE.update_task = AsyncMock()
fake_app_module.DATABASE.update_task_v2 = AsyncMock()
fake_app_module.DATABASE.organizations = _LazyNamespace()
fake_app_module.DATABASE.workflows = _LazyNamespace()
fake_app_module.DATABASE.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.update_workflow_run = AsyncMock()
fake_app_module.DATABASE.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.update_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.observer.update_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.observer.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.tasks.get_last_task_for_workflow_run = AsyncMock()
fake_app_module.DATABASE.workflow_runs.get_workflow_run = AsyncMock()
fake_app_module.DATABASE.observer.get_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.tasks.get_task = AsyncMock()
fake_app_module.DATABASE.tasks.update_task = AsyncMock()
fake_app_module.DATABASE.observer.update_task_v2 = AsyncMock()
fake_app_module.DATABASE.organizations.get_organization = AsyncMock()
fake_app_module.DATABASE.workflows.get_workflow = AsyncMock()
fake_app_module.DATABASE.observer.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.workflow_runs.update_workflow_run = AsyncMock()
fake_app_module.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.observer.update_workflow_run_block = AsyncMock()
fake_app_module.LLM_API_HANDLER = AsyncMock()
fake_app_module.SECONDARY_LLM_API_HANDLER = AsyncMock()
fake_app_module.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
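_LazyNamespace is referenced here but defined elsewhere in the test suite, so the following is only a plausible minimal reading of it: a container that materializes an AsyncMock per attribute on first access. The real helper may differ.

# Hypothetical sketch: the assumption is that _LazyNamespace auto-creates
# and caches a child mock the first time an attribute is read, so tests
# can assert against e.g. DATABASE.organizations.get_organization later.
from unittest.mock import AsyncMock

class _LazyNamespace:
    def __getattr__(self, name: str) -> AsyncMock:
        # __getattr__ only fires for missing attributes; setattr caches
        # the mock so repeated access returns the same object.
        mock = AsyncMock()
        setattr(self, name, mock)
        return mock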

View file

@ -196,7 +196,10 @@ def test_agent_db_has_typed_repo_attributes():
assert isinstance(db.tasks, TasksRepository)
assert isinstance(db.credentials, CredentialRepository)
assert hasattr(db, "get_task") # backward compat delegate
assert hasattr(db, "workflows")
# Migrated domains no longer have delegates on AgentDB:
assert not hasattr(db, "create_workflow")
assert not hasattr(db, "get_organization")
assert not hasattr(db, "get_credential")
def test_agent_db_delegates_route_to_repositories():

View file

@ -242,10 +242,10 @@ def setup_parallel_verification_mocks(
extract_action: Any | None = None,
) -> ParallelVerificationMocks:
create_step_mock = AsyncMock(return_value=next_step)
monkeypatch.setattr(app.DATABASE, "create_step", create_step_mock)
monkeypatch.setattr(app.DATABASE.tasks, "create_step", create_step_mock)
get_task_steps_mock = AsyncMock(return_value=[step])
monkeypatch.setattr(app.DATABASE, "get_task_steps", get_task_steps_mock)
monkeypatch.setattr(app.DATABASE.tasks, "get_task_steps", get_task_steps_mock)
sleep_mock = AsyncMock(return_value=None)
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", sleep_mock)

View file

@ -53,7 +53,7 @@ def mock_app():
mock = MagicMock()
mock.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})
mock.DATABASE = MagicMock()
mock.DATABASE.get_step = AsyncMock(return_value=MagicMock())
mock.DATABASE.tasks.get_step = AsyncMock(return_value=MagicMock())
return mock

View file

@ -41,8 +41,8 @@ async def test_cached_content_removed_from_non_extract_prompts() -> None:
# Ensure app dependencies referenced inside the handler resolve to async mocks.
forge_module.app.ARTIFACT_MANAGER = MagicMock()
forge_module.app.DATABASE = MagicMock()
forge_module.app.DATABASE.update_step = AsyncMock()
forge_module.app.DATABASE.update_thought = AsyncMock()
forge_module.app.DATABASE.tasks.update_step = AsyncMock()
forge_module.app.DATABASE.observer.update_thought = AsyncMock()
with (
patch("skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", return_value=mock_config),

View file

@ -109,9 +109,9 @@ async def test_batch_actions_preserve_per_task_ordering() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block_a, run_block_b])
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task_a, mock_task_b])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=all_actions_descending)
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(return_value=[run_block_a, run_block_b])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task_a, mock_task_b])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=all_actions_descending)
result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")
@ -187,9 +187,9 @@ async def test_batch_actions_without_reverse_would_be_wrong() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=actions_descending)
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=actions_descending)
result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")
@ -250,9 +250,9 @@ async def test_batch_actions_preserve_none_element_id() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[action_extract])
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=[action_extract])
result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")

View file

@ -99,7 +99,7 @@ class TestScreenshotAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager
await ScriptSkyvernPage._create_screenshot_after_execution()
@ -132,7 +132,7 @@ class TestScreenshotAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager
await ScriptSkyvernPage._create_screenshot_after_execution()
@ -170,7 +170,7 @@ class TestHtmlActionAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager
mock_frame = MagicMock()
@ -210,7 +210,7 @@ class TestHtmlActionAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager
mock_frame = MagicMock()
@ -249,7 +249,7 @@ class TestFinalScreenshotBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager
await ScriptSkyvernPage._create_final_screenshot()
@ -283,7 +283,7 @@ class TestFinalScreenshotBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager
await ScriptSkyvernPage._create_final_screenshot()
@ -317,9 +317,9 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_step = AsyncMock()
mock_app.DATABASE.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.tasks.update_step = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.STORAGE.get_downloaded_files = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_task_screenshot_artifacts = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_workflow_screenshot_artifacts = AsyncMock(return_value=[])
@ -348,9 +348,9 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_step = AsyncMock()
mock_app.DATABASE.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.tasks.update_step = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.STORAGE.get_downloaded_files = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_task_screenshot_artifacts = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_workflow_screenshot_artifacts = AsyncMock(return_value=[])
@ -379,7 +379,7 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.WORKFLOW_SERVICE.send_workflow_response = AsyncMock()
mock_run_ctx.get_run_context.return_value = None
@ -407,9 +407,9 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_step = AsyncMock()
mock_app.DATABASE.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.tasks.update_step = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.STORAGE.get_downloaded_files = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_task_screenshot_artifacts = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_workflow_screenshot_artifacts = AsyncMock(return_value=[])

View file

@ -1,11 +1,10 @@
"""Tests that update_credential() accepts user_context and save_browser_session_intent
on both CredentialRepository and CredentialsMixin."""
on CredentialRepository."""
from unittest.mock import MagicMock, patch
import pytest
from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin
from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
from tests.unit.conftest import MockAsyncSessionCtx, make_mock_session
@ -15,13 +14,6 @@ def _make_credential_repo(mock_credential: MagicMock) -> CredentialRepository:
return CredentialRepository(session_factory=lambda: MockAsyncSessionCtx(mock_session))
def _make_credential_mixin(mock_credential: MagicMock) -> CredentialsMixin:
mock_session = make_mock_session(mock_credential)
mixin = CredentialsMixin.__new__(CredentialsMixin)
mixin.Session = lambda: MockAsyncSessionCtx(mock_session) # type: ignore[assignment]
return mixin
# --- CredentialRepository tests ---
@ -75,40 +67,3 @@ async def test_repo_update_credential_unset_params_not_applied() -> None:
assert mock_credential.user_context == "existing"
assert mock_credential.save_browser_session_intent is True
# --- CredentialsMixin tests ---
@pytest.mark.asyncio
async def test_mixin_update_credential_accepts_user_context() -> None:
mock_credential = MagicMock()
mock_credential.name = "test"
mock_credential.user_context = None
mixin = _make_credential_mixin(mock_credential)
with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()):
await mixin.update_credential(
credential_id="cred_123",
organization_id="org_123",
user_context="Click SSO button first",
)
assert mock_credential.user_context == "Click SSO button first"
@pytest.mark.asyncio
async def test_mixin_update_credential_accepts_save_browser_session_intent() -> None:
mock_credential = MagicMock()
mock_credential.name = "test"
mock_credential.save_browser_session_intent = False
mixin = _make_credential_mixin(mock_credential)
with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()):
await mixin.update_credential(
credential_id="cred_123",
organization_id="org_123",
save_browser_session_intent=True,
)
assert mock_credential.save_browser_session_intent is True

View file

@ -444,7 +444,7 @@ async def test_handle_action_navigates_back_from_blank_page_after_download() ->
mock_app = MagicMock()
mock_app.BROWSER_MANAGER.get_for_task.return_value = browser_state
mock_app.DATABASE.create_action = AsyncMock(return_value=action)
mock_app.DATABASE.workflow_params.create_action = AsyncMock(return_value=action)
mock_app.STORAGE = MagicMock()
with (
@ -517,7 +517,7 @@ async def test_handle_action_does_not_navigate_back_when_page_url_unchanged() ->
mock_app = MagicMock()
mock_app.BROWSER_MANAGER.get_for_task.return_value = browser_state
mock_app.DATABASE.create_action = AsyncMock(return_value=action)
mock_app.DATABASE.workflow_params.create_action = AsyncMock(return_value=action)
mock_app.STORAGE = MagicMock()
with (

View file

@ -84,7 +84,7 @@ async def test_navigate_failure_with_error_detection(agent, mock_browser_state):
with create_error_detection_mocks(detected_errors):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
# Simulate FailedToNavigateToUrl scenario
from skyvern.exceptions import FailedToNavigateToUrl
@ -103,8 +103,8 @@ async def test_navigate_failure_with_error_detection(agent, mock_browser_state):
assert result is True
# Verify errors were stored
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 1
assert call_kwargs["errors"][0]["error_code"] == "page_not_found"
@ -138,9 +138,9 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.BROWSER_MANAGER.get_for_task.return_value = mock_browser_state
mock_app.DATABASE.get_task_steps = AsyncMock(return_value=[step, step, step])
mock_app.DATABASE.get_task = AsyncMock(return_value=task)
mock_app.DATABASE.update_task = AsyncMock(return_value=task)
mock_app.DATABASE.tasks.get_task_steps = AsyncMock(return_value=[step, step, step])
mock_app.DATABASE.tasks.get_task = AsyncMock(return_value=task)
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=task)
# create_step is awaited in the handle_failed_step retry branch; use an AsyncMock so the await never receives a bare MagicMock
next_step = make_step(
now,
@ -151,9 +151,9 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
retry_index=step.retry_index + 1,
output=None,
)
mock_app.DATABASE.create_step = AsyncMock(return_value=next_step)
mock_app.DATABASE.tasks.create_step = AsyncMock(return_value=next_step)
# Async mock that forwards to mock_app.DATABASE.update_task so we never await MagicMock inside real update_task
# Async mock that forwards to mock_app.DATABASE.tasks.update_task so the real update_task never awaits a bare MagicMock
async def mock_update_task(
_self,
task,
@ -173,7 +173,9 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
updates["errors"] = errors
if failure_category is not None:
updates["failure_category"] = failure_category
return await mock_app.DATABASE.update_task(task.task_id, organization_id=task.organization_id, **updates)
return await mock_app.DATABASE.tasks.update_task(
task.task_id, organization_id=task.organization_id, **updates
)
with patch.object(ForgeAgent, "summary_failure_reason_for_max_retries", mock_summary):
with patch.object(ForgeAgent, "update_task", mock_update_task):
@ -182,8 +184,8 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
assert result is None # No next step when max retries exceeded
# Verify errors include both system and user-defined errors
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
errors = call_kwargs["errors"]
assert len(errors) == 2
@ -220,7 +222,7 @@ async def test_scraping_failure_with_error_detection(agent, mock_browser_state):
with create_error_detection_mocks(detected_errors):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
# Simulate ScrapingFailed scenario
from skyvern.exceptions import ScrapingFailed
@ -233,8 +235,8 @@ async def test_scraping_failure_with_error_detection(agent, mock_browser_state):
assert result is True
# Verify errors were stored
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 1
assert call_kwargs["errors"][0]["error_code"] == "login_required"
@ -268,12 +270,12 @@ async def test_multiple_failures_accumulate_errors(agent, mock_browser_state):
with create_error_detection_mocks(first_detected):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
await agent.fail_task(task, step, "First failure", mock_browser_state)
# Only new errors are passed — DB handles appending to existing ones
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 1
assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
@ -304,15 +306,15 @@ async def test_error_detection_with_workflow_task(agent, mock_browser_state):
with create_error_detection_mocks(detected_errors):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
result = await agent.fail_task(task, step, "Workflow task failed", mock_browser_state)
assert result is True
# Verify errors were stored for workflow task
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert call_kwargs["task_id"] == task.task_id
# workflow_run_id is not passed in the update call, only task_id and errors
assert "workflow_run_id" not in call_kwargs
@ -350,7 +352,7 @@ async def test_error_detection_performance_doesnt_block_failure(agent, mock_brow
mock_detect.side_effect = slow_detection
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
import time
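The mock_update_task forwarding coroutine earlier in this file exists because a plain MagicMock is not awaitable; routing through AsyncMock (or a real async def) is what keeps the awaits from blowing up. A self-contained illustration, independent of Skyvern:

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def demo() -> None:
    ok = AsyncMock(return_value="updated")
    print(await ok())            # fine: calling an AsyncMock yields an awaitable

    broken = MagicMock()
    try:
        await broken()           # TypeError: MagicMock can't be used in 'await'
    except TypeError as exc:
        print(exc)

asyncio.run(demo())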

View file

@ -72,7 +72,7 @@ async def test_fail_task_with_error_code_mapping_detects_errors(agent, mock_brow
mock_detect.return_value = detected_errors
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
@ -87,8 +87,8 @@ async def test_fail_task_with_error_code_mapping_detects_errors(agent, mock_brow
)
# Verify task errors were updated in database
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert call_kwargs["task_id"] == task.task_id
assert call_kwargs["organization_id"] == task.organization_id
assert len(call_kwargs["errors"]) == 1
@ -112,7 +112,7 @@ async def test_fail_task_without_error_code_mapping(agent, mock_browser_state):
new_callable=AsyncMock,
) as mock_detect:
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
@ -122,7 +122,7 @@ async def test_fail_task_without_error_code_mapping(agent, mock_browser_state):
mock_detect.assert_not_called()
# Verify database update was NOT called for errors
mock_app.DATABASE.update_task.assert_not_called()
mock_app.DATABASE.tasks.update_task.assert_not_called()
@pytest.mark.asyncio
@ -150,7 +150,7 @@ async def test_fail_task_without_browser_state(agent):
mock_detect.return_value = []
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
# Call without browser_state
result = await agent.fail_task(task, step, "Task failed", browser_state=None)
@ -188,7 +188,7 @@ async def test_fail_task_without_step(agent, mock_browser_state):
new_callable=AsyncMock,
) as mock_detect:
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
# Call without step
result = await agent.fail_task(task, None, "Task failed", mock_browser_state)
@ -228,7 +228,7 @@ async def test_fail_task_error_detection_fails_gracefully(agent, mock_browser_st
mock_detect.side_effect = Exception("Detection failed")
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
# Should not raise exception
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
@ -270,14 +270,14 @@ async def test_fail_task_multiple_errors_detected(agent, mock_browser_state):
mock_detect.return_value = detected_errors
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
assert result is True
# Verify only new errors were passed (DB handles appending to existing errors)
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 2
assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
assert call_kwargs["errors"][1]["error_code"] == "address_invalid"
@ -308,14 +308,14 @@ async def test_fail_task_no_errors_detected(agent, mock_browser_state):
mock_detect.return_value = []
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
assert result is True
# Database update for errors should not be called
mock_app.DATABASE.update_task.assert_not_called()
mock_app.DATABASE.tasks.update_task.assert_not_called()
@pytest.mark.asyncio

View file

@ -336,7 +336,7 @@ async def test_transform_forloop_block_with_mocked_db() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(
return_value=[
mock_forloop_run_block,
mock_child_run_block,
@ -345,8 +345,8 @@ async def test_transform_forloop_block_with_mocked_db() -> None:
# B1 optimization: Mock batch methods instead of individual queries
mock_task.task_id = "task_extraction_789"
mock_action.task_id = "task_extraction_789"
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[mock_action])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=[mock_action])
# Call the transformation
result = await transform_workflow_run_to_code_gen_input(

View file

@ -0,0 +1,15 @@
"""
Verify the no-direct-db-delegates hook script passes on the current codebase.
"""
import subprocess
def test_hook_passes_on_current_codebase() -> None:
"""The hook should pass cleanly — all legacy files are in the allowlist."""
result = subprocess.run(
["./scripts/check_no_direct_db_delegates.sh"],
capture_output=True,
text=True,
)
assert result.returncode == 0, f"Hook unexpectedly failed; new direct delegate calls or script errors:\n{result.stdout}{result.stderr}"

View file

@ -348,7 +348,7 @@ async def test_agent_step_persists_artifacts_when_using_speculative_plan(
monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.post_action_execution", AsyncMock())
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", AsyncMock(return_value=None))
monkeypatch.setattr("skyvern.forge.agent.random.uniform", lambda *_args, **_kwargs: 0)
monkeypatch.setattr("skyvern.forge.agent.app.DATABASE.create_action", AsyncMock())
monkeypatch.setattr("skyvern.forge.agent.app.DATABASE.workflow_params.create_action", AsyncMock())
monkeypatch.setattr(
"skyvern.forge.agent.app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached",
AsyncMock(return_value=False),

View file

@ -9,7 +9,7 @@ from skyvern.forge.sdk.routes import agent_protocol
@pytest.mark.asyncio
async def test_get_runs_v2_serializes_mapping_rows_from_database(monkeypatch: pytest.MonkeyPatch) -> None:
mock_database = SimpleNamespace(
mock_workflow_runs = SimpleNamespace(
get_all_runs_v2=AsyncMock(
return_value=[
{
@ -28,6 +28,7 @@ async def test_get_runs_v2_serializes_mapping_rows_from_database(monkeypatch: py
]
)
)
mock_database = SimpleNamespace(workflow_runs=mock_workflow_runs)
monkeypatch.setattr(agent_protocol.app, "DATABASE", mock_database)
response = await agent_protocol.get_runs_v2(
@ -37,7 +38,7 @@ async def test_get_runs_v2_serializes_mapping_rows_from_database(monkeypatch: py
search_key="abc",
)
mock_database.get_all_runs_v2.assert_awaited_once_with(
mock_workflow_runs.get_all_runs_v2.assert_awaited_once_with(
"org_123",
page=2,
page_size=5,
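Design note on this change: building the repository namespace first and nesting it keeps the fake strict. Unlike MagicMock, SimpleNamespace raises AttributeError for any attribute the route touches that the test did not define, which is useful for proving the handler really reads the new workflow_runs path. A small sketch:

from types import SimpleNamespace
from unittest.mock import AsyncMock

mock_workflow_runs = SimpleNamespace(get_all_runs_v2=AsyncMock(return_value=[]))
mock_database = SimpleNamespace(workflow_runs=mock_workflow_runs)

# mock_database.workflow_runs.get_all_runs_v2 is patched and awaitable;
# any other attribute (e.g. mock_database.tasks) raises AttributeError,
# so an unmigrated call path fails the test loudly.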

View file

@ -163,7 +163,7 @@ class TestTriggerScriptReviewerCap:
self.mock_cache.get_lock = MagicMock(return_value=mock_lock)
# Mock is_script_pinned to return False (not pinned) so tests reach the cap logic
self._pin_patcher = patch(
"skyvern.forge.sdk.workflow.service.app.DATABASE.is_script_pinned",
"skyvern.forge.sdk.workflow.service.app.DATABASE.scripts.is_script_pinned",
new_callable=AsyncMock,
return_value=False,
)

View file

@ -421,12 +421,12 @@ async def test_terminate_calls_handler_and_raises(mock_scraped_page, mock_ai):
return_value=mock_context,
),
patch(
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_task",
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_task",
new_callable=AsyncMock,
return_value=mock_task,
),
patch(
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_step",
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_step",
new_callable=AsyncMock,
return_value=mock_step,
),
@ -481,12 +481,12 @@ async def test_terminate_raises_even_when_task_not_found(mock_scraped_page, mock
return_value=mock_context,
),
patch(
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_task",
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_task",
new_callable=AsyncMock,
return_value=None,
),
patch(
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_step",
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_step",
new_callable=AsyncMock,
return_value=None,
),
@ -538,12 +538,12 @@ async def test_terminate_raises_even_when_handler_fails(mock_scraped_page, mock_
return_value=mock_context,
),
patch(
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_task",
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_task",
new_callable=AsyncMock,
return_value=mock_task,
),
patch(
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_step",
"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_step",
new_callable=AsyncMock,
return_value=mock_step,
),

View file

@ -526,8 +526,9 @@ def _make_app_mocks() -> tuple[MagicMock, MagicMock]:
mock_storage.store_artifact = AsyncMock()
mock_database = MagicMock()
mock_database.bulk_create_artifacts = AsyncMock()
mock_database.update_action_screenshot_artifact_id = AsyncMock()
mock_database.artifacts = MagicMock()
mock_database.artifacts.bulk_create_artifacts = AsyncMock()
mock_database.artifacts.update_action_screenshot_artifact_id = AsyncMock()
return mock_storage, mock_database
@ -557,9 +558,9 @@ class TestFlushStepArchive:
await manager.flush_step_archive(step.step_id)
mock_storage.store_artifact.assert_awaited_once()
mock_database.bulk_create_artifacts.assert_awaited_once()
mock_database.artifacts.bulk_create_artifacts.assert_awaited_once()
# The artifact list should include the parent + 6 member rows (scrape produces 6 entries)
call_args = mock_database.bulk_create_artifacts.call_args[0][0]
call_args = mock_database.artifacts.bulk_create_artifacts.call_args[0][0]
assert len(call_args) == 7 # 1 parent + 6 members
@pytest.mark.asyncio
@ -612,7 +613,7 @@ class TestFlushStepArchive:
# store_artifact and bulk_create_artifacts should only be called once
assert mock_storage.store_artifact.await_count == 1
assert mock_database.bulk_create_artifacts.await_count == 1
assert mock_database.artifacts.bulk_create_artifacts.await_count == 1
@pytest.mark.asyncio
async def test_flush_nonexistent_step_id_is_noop(self) -> None:
@ -626,7 +627,7 @@ class TestFlushStepArchive:
await manager.flush_step_archive("nonexistent_step_id")
mock_storage.store_artifact.assert_not_awaited()
mock_database.bulk_create_artifacts.assert_not_awaited()
mock_database.artifacts.bulk_create_artifacts.assert_not_awaited()
@pytest.mark.asyncio
async def test_flush_applies_pending_screenshot_updates(self) -> None:
@ -652,7 +653,7 @@ class TestFlushStepArchive:
mock_app.DATABASE = mock_database
await manager.flush_step_archive(step.step_id)
mock_database.update_action_screenshot_artifact_id.assert_awaited_once_with(
mock_database.artifacts.update_action_screenshot_artifact_id.assert_awaited_once_with(
organization_id="org_1",
action_id="action_1",
screenshot_artifact_id="art_1",
@ -682,10 +683,10 @@ class TestFlushStepArchive:
await manager.flush_step_archive(step.step_id)
# Reset call counts to detect any additional calls from wait_for_upload_aiotasks
mock_storage.store_artifact.reset_mock()
mock_database.bulk_create_artifacts.reset_mock()
mock_database.artifacts.bulk_create_artifacts.reset_mock()
# Simulate the end-of-task flush fallback
await manager.wait_for_upload_aiotasks([step.task_id])
# The fallback should find nothing to flush — no extra uploads
mock_storage.store_artifact.assert_not_awaited()
mock_database.bulk_create_artifacts.assert_not_awaited()
mock_database.artifacts.bulk_create_artifacts.assert_not_awaited()
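One subtlety behind this hunk: attribute access on a MagicMock auto-creates MagicMock children, so mock_database.artifacts would exist either way, but methods the code awaits must be AsyncMocks. A compact sketch of the distinction (nothing here beyond stock unittest.mock behavior):

from unittest.mock import AsyncMock, MagicMock

mock_database = MagicMock()
# mock_database.artifacts is auto-created as a MagicMock on first access,
# so the explicit MagicMock() assignment is mostly documentation...
mock_database.artifacts = MagicMock()
# ...but awaited repository methods must be AsyncMocks, or the await
# raises TypeError at the call site.
mock_database.artifacts.bulk_create_artifacts = AsyncMock()
mock_database.artifacts.update_action_screenshot_artifact_id = AsyncMock()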

View file

@ -13,7 +13,7 @@ async def test_initialize_task_v2_populates_task_run_url_when_user_url_is_known(
organization = SimpleNamespace(organization_id="org_123")
user_url = "https://example.com"
app.DATABASE.create_task_v2.return_value = SimpleNamespace(
app.DATABASE.observer.create_task_v2.return_value = SimpleNamespace(
observer_cruise_id="tsk_123",
workflow_run_id=None,
url=user_url,
@ -24,14 +24,14 @@ async def test_initialize_task_v2_populates_task_run_url_when_user_url_is_known(
title=DEFAULT_WORKFLOW_TITLE,
)
app.WORKFLOW_SERVICE.setup_workflow_run.return_value = SimpleNamespace(workflow_run_id="wr_123")
app.DATABASE.update_task_v2.return_value = SimpleNamespace(
app.DATABASE.observer.update_task_v2.return_value = SimpleNamespace(
observer_cruise_id="tsk_123",
workflow_run_id="wr_123",
workflow_id="wf_123",
workflow_permanent_id="wpid_123",
url=user_url,
)
app.DATABASE.create_task_run.return_value = SimpleNamespace(run_id="tsk_123")
app.DATABASE.tasks.create_task_run.return_value = SimpleNamespace(run_id="tsk_123")
await initialize_task_v2(
organization=organization,
@ -40,7 +40,7 @@ async def test_initialize_task_v2_populates_task_run_url_when_user_url_is_known(
create_task_run=True,
)
app.DATABASE.create_task_run.assert_awaited_once_with(
app.DATABASE.tasks.create_task_run.assert_awaited_once_with(
task_run_type=RunType.task_v2,
organization_id="org_123",
run_id="tsk_123",

View file

@ -20,7 +20,7 @@ async def test_org_email_bitwarden_auth_falls_back_to_global_credentials(
)
)
class FakeOrgRepo:
class FakeOrganizationsRepo:
async def get_valid_org_auth_token(self, organization_id: str, token_type: str) -> object:
assert organization_id == "org-1"
assert token_type == "bitwarden_credential"
@ -28,7 +28,7 @@ async def test_org_email_bitwarden_auth_falls_back_to_global_credentials(
class FakeDatabase:
def __init__(self) -> None:
self.organizations = FakeOrgRepo()
self.organizations = FakeOrganizationsRepo()
fake_app = SimpleNamespace(DATABASE=FakeDatabase())
monkeypatch.setattr(cm, "app", fake_app)

View file

@ -186,8 +186,8 @@ async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: p
ctx.values["flag"] = True
monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: ctx)
app.DATABASE.update_workflow_run_block.reset_mock()
app.DATABASE.create_or_update_workflow_run_output_parameter.reset_mock()
app.DATABASE.observer.update_workflow_run_block.reset_mock()
app.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter.reset_mock()
result = await block.execute(
workflow_run_id="run-1",
@ -202,7 +202,7 @@ async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: p
assert ctx.blocks_metadata["cond"]["branch_taken"] == "next"
# Get the actual call arguments
call_args = app.DATABASE.update_workflow_run_block.call_args
call_args = app.DATABASE.observer.update_workflow_run_block.call_args
assert call_args.kwargs["workflow_run_block_id"] == "wrb-1"
assert call_args.kwargs["output"] == metadata
assert call_args.kwargs["status"] == BlockStatus.completed

View file

@ -246,7 +246,7 @@ class TestBuildBlockResultPassesErrorCodes:
from skyvern.forge import app
app.DATABASE.update_workflow_run_block.reset_mock()
app.DATABASE.observer.update_workflow_run_block.reset_mock()
result = await block.build_block_result(
success=False,
@ -257,8 +257,8 @@ class TestBuildBlockResultPassesErrorCodes:
error_codes=["FILE_PARSER_ERROR"],
)
app.DATABASE.update_workflow_run_block.assert_called_once()
call_kwargs = app.DATABASE.update_workflow_run_block.call_args[1]
app.DATABASE.observer.update_workflow_run_block.assert_called_once()
call_kwargs = app.DATABASE.observer.update_workflow_run_block.call_args[1]
assert call_kwargs["error_codes"] == ["FILE_PARSER_ERROR"]
assert result.error_codes == ["FILE_PARSER_ERROR"]
@ -268,7 +268,7 @@ class TestBuildBlockResultPassesErrorCodes:
from skyvern.forge import app
app.DATABASE.update_workflow_run_block.reset_mock()
app.DATABASE.observer.update_workflow_run_block.reset_mock()
result = await block.build_block_result(
success=True,
@ -278,7 +278,7 @@ class TestBuildBlockResultPassesErrorCodes:
organization_id="org_test",
)
app.DATABASE.update_workflow_run_block.assert_called_once()
call_kwargs = app.DATABASE.update_workflow_run_block.call_args[1]
app.DATABASE.observer.update_workflow_run_block.assert_called_once()
call_kwargs = app.DATABASE.observer.update_workflow_run_block.call_args[1]
assert call_kwargs["error_codes"] is None
assert result.error_codes == []
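These assertions rely on reset_mock() to scrub call history accumulated by earlier tests against the shared module-level mock; without the reset, assert_called_once would see stale calls. A tiny standalone sketch of the idiom:

from unittest.mock import MagicMock

shared = MagicMock()
shared("from an earlier test")
shared.reset_mock()               # drop the accumulated call history
shared("the call under test")
shared.assert_called_once()       # passes: only the post-reset call counts
assert shared.call_args[0][0] == "the call under test"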

View file

@ -145,7 +145,7 @@ class TestCheckTriggerDepth:
mock_run = MagicMock()
mock_run.parent_workflow_run_id = None
with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
mock_app.DATABASE.get_workflow_run = AsyncMock(return_value=mock_run)
mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(return_value=mock_run)
depth = await block._check_trigger_depth("wr_current")
assert depth == 0
@ -158,7 +158,7 @@ class TestCheckTriggerDepth:
run_no_parent.parent_workflow_run_id = None
with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
mock_app.DATABASE.get_workflow_run = AsyncMock(side_effect=[run_with_parent, run_no_parent])
mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(side_effect=[run_with_parent, run_no_parent])
depth = await block._check_trigger_depth("wr_current")
assert depth == 1
@ -172,7 +172,7 @@ class TestCheckTriggerDepth:
runs.append(run)
with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
mock_app.DATABASE.get_workflow_run = AsyncMock(side_effect=runs)
mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(side_effect=runs)
with pytest.raises(InvalidWorkflowDefinition, match="depth exceeds maximum"):
await block._check_trigger_depth("wr_current")
@ -186,7 +186,7 @@ class TestCheckTriggerDepth:
runs.append(run)
with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
mock_app.DATABASE.get_workflow_run = AsyncMock(side_effect=runs)
mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(side_effect=runs)
depth = await block._check_trigger_depth("wr_current")
assert depth == block.MAX_TRIGGER_DEPTH - 1
@ -194,7 +194,7 @@ class TestCheckTriggerDepth:
async def test_run_not_found_returns_zero(self) -> None:
block = _make_block()
with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
mock_app.DATABASE.get_workflow_run = AsyncMock(return_value=None)
mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(return_value=None)
depth = await block._check_trigger_depth("wr_nonexistent")
assert depth == 0

View file

@ -43,7 +43,7 @@ def _make_session(status: str) -> PersistentBrowserSession:
async def test_rejects_update_when_already_final(desired_status: str):
"""A finalized session must not accept any status update."""
db = AsyncMock()
db.get_persistent_browser_session.return_value = _make_session(
db.browser_sessions.get_persistent_browser_session.return_value = _make_session(
PersistentBrowserSessionStatus.completed,
)
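This file leans on the fact that attribute children of an AsyncMock are themselves AsyncMocks, so the nested repository path needs no manual wiring before a return value is set. A compact illustration, independent of the code under test:

import asyncio
from unittest.mock import AsyncMock

db = AsyncMock()
# browser_sessions and its method are auto-created as AsyncMocks, so the
# nested path can be configured in one line.
db.browser_sessions.get_persistent_browser_session.return_value = {"status": "completed"}

async def demo() -> None:
    session = await db.browser_sessions.get_persistent_browser_session("pbs_1", "org_1")
    assert session == {"status": "completed"}

asyncio.run(demo())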