Mirror of https://github.com/Skyvern-AI/skyvern.git, synced 2026-04-26 10:41:14 +00:00

AgentDB Phase 7: Migrate remaining 8 domains to typed repos (#5366)

This commit is contained in: parent 58fed69496, commit 26b8f4d73e

84 changed files with 811 additions and 8804 deletions
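The change is mechanical: every call to a backward-compatible delegate on AgentDB moves to the matching typed repository attribute. A minimal before/after sketch of what each call-site change below looks like (the surrounding function is hypothetical; app.DATABASE and the repository names come from the diff itself):

    # Hypothetical call site, for orientation only.
    from skyvern.forge import app

    async def fetch_step(step_id: str, organization_id: str):
        # Before: legacy delegate on AgentDB; its *args: Any -> Any signature erases types.
        step = await app.DATABASE.get_step(step_id, organization_id)
        # After: typed repository attribute; same behavior, checkable by mypy.
        step = await app.DATABASE.tasks.get_step(step_id, organization_id)
        return step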
scripts/check_no_direct_db_delegates.sh (new executable file, +60)

@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+# Detect direct calls to AgentDB backward-compatible delegate methods.
+# New code must use repository attributes (e.g. db.tasks.create_task)
+# instead of the legacy delegates (e.g. db.create_task).
+#
+# Called by tests/unit/test_no_direct_db_delegates.py
+
+set -euo pipefail
+
+AGENT_DB="skyvern/forge/sdk/db/agent_db.py"
+
+# Extract delegate method names from agent_db.py.
+# These are the "async def <name>" lines inside the delegate section (after line 170).
+delegate_methods=$(
+    awk 'NR > 170 && /^    async def / { gsub(/.*async def /,""); gsub(/\(.*/,""); print }' "$AGENT_DB" \
+        | sort -u
+)
+
+if [ -z "$delegate_methods" ]; then
+    echo "ERROR: Could not extract delegate methods from $AGENT_DB" >&2
+    exit 1
+fi
+
+# Build a grep alternation pattern for delegate method names.
+methods_pattern=$(echo "$delegate_methods" | paste -sd'|' -)
+
+# Search for direct delegate calls on known AgentDB access patterns:
+#   app.DATABASE.<method>( — should be app.DATABASE.<repo>.<method>(
+#   REPLICA_DATABASE.<method>( — should be REPLICA_DATABASE.<repo>.<method>(
+# Exclude the delegate file itself and tests.
+db_pattern="(DATABASE|REPLICA_DATABASE)\.(${methods_pattern})\("
+
+# Legacy files that still use direct delegates (grandfathered in).
+ALLOWLIST=(
+    "$AGENT_DB"
+    "tests/"
+    "run_streaming.py"
+)
+
+exclude_args=()
+for allowed in "${ALLOWLIST[@]}"; do
+    exclude_args+=(":!${allowed}")
+done
+
+violations=$(
+    git grep -n -E "$db_pattern" -- '*.py' "${exclude_args[@]}" \
+        2>/dev/null \
+        || true
+)
+
+if [ -n "$violations" ]; then
+    echo "Direct AgentDB delegate calls found. Use repository attributes instead."
+    echo "  e.g. app.DATABASE.tasks.create_task(...) not app.DATABASE.create_task(...)"
+    echo ""
+    echo "$violations"
+    exit 1
+fi
+
+echo "No direct delegate calls found."
+exit 0
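The header says the script is driven by tests/unit/test_no_direct_db_delegates.py, which is not part of this diff. A minimal sketch of how such a wrapper could invoke the script (the test body and assertions are assumptions):

    # Hypothetical pytest wrapper; the real test file is not shown in this diff.
    import subprocess

    def test_no_direct_db_delegates() -> None:
        result = subprocess.run(
            ["bash", "scripts/check_no_direct_db_delegates.sh"],
            capture_output=True,
            text=True,
        )
        # The script exits 0 when clean, and exits 1 after printing the
        # offending lines when a direct app.DATABASE.<method>( call exists
        # outside the allowlist.
        assert result.returncode == 0, result.stdout + result.stderr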
@@ -3296,13 +3296,13 @@ async def create_or_update_script_block(
     block_code_bytes = block_code if isinstance(block_code, bytes) else block_code.encode("utf-8")
     try:
         # Step 3: Create script block in database
-        script_block = await app.DATABASE.get_script_block_by_label(
+        script_block = await app.DATABASE.scripts.get_script_block_by_label(
             organization_id=organization_id,
             script_revision_id=script_revision_id,
             script_block_label=block_label,
         )
         if not script_block:
-            script_block = await app.DATABASE.create_script_block(
+            script_block = await app.DATABASE.scripts.create_script_block(
                 script_revision_id=script_revision_id,
                 script_id=script_id,
                 organization_id=organization_id,
@@ -3318,7 +3318,7 @@ async def create_or_update_script_block(
             for value in [run_signature, workflow_run_id, workflow_run_block_id, input_fields, requires_agent]
         ):
             # Update metadata when new values are provided
-            script_block = await app.DATABASE.update_script_block(
+            script_block = await app.DATABASE.scripts.update_script_block(
                 script_block_id=script_block.script_block_id,
                 organization_id=organization_id,
                 run_signature=run_signature,
@@ -3336,13 +3336,13 @@ async def create_or_update_script_block(
         # Create artifact and upload to S3
        artifact_id = None
        if update and script_block.script_file_id:
-            script_file = await app.DATABASE.get_script_file_by_id(
+            script_file = await app.DATABASE.scripts.get_script_file_by_id(
                 script_revision_id,
                 script_block.script_file_id,
                 organization_id,
             )
             if script_file and script_file.artifact_id:
-                artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
+                artifact = await app.DATABASE.artifacts.get_artifact_by_id(script_file.artifact_id, organization_id)
                 if artifact:
                     asyncio.create_task(app.STORAGE.store_artifact(artifact, block_code_bytes))
                 else:
@@ -3357,7 +3357,7 @@ async def create_or_update_script_block(
         )

         # Create script file record
-        script_file = await app.DATABASE.create_script_file(
+        script_file = await app.DATABASE.scripts.create_script_file(
             script_revision_id=script_revision_id,
             script_id=script_id,
             organization_id=organization_id,
@@ -3371,7 +3371,7 @@ async def create_or_update_script_block(
         )

         # update script block with script file id
-        await app.DATABASE.update_script_block(
+        await app.DATABASE.scripts.update_script_block(
             script_block_id=script_block.script_block_id,
             organization_id=organization_id,
             script_file_id=script_file.file_id,
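All five hunks above sit inside one create-or-update flow for a script block: look the block up by label, create it if missing, update metadata when new values are provided, write the script file record, then link the file back onto the block. Condensed for orientation (not a verbatim excerpt; error handling and most keyword arguments are elided):

    # Condensed from the hunks above; '...' marks elided arguments.
    script_block = await app.DATABASE.scripts.get_script_block_by_label(
        organization_id=organization_id,
        script_revision_id=script_revision_id,
        script_block_label=block_label,
    )
    if not script_block:
        script_block = await app.DATABASE.scripts.create_script_block(...)
    script_file = await app.DATABASE.scripts.create_script_file(...)
    await app.DATABASE.scripts.update_script_block(
        script_block_id=script_block.script_block_id,
        organization_id=organization_id,
        script_file_id=script_file.file_id,
    )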
@@ -174,7 +174,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
 
         organization_id = context.organization_id if context else None
         step_id = context.step_id if context else None
-        step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
+        step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
         single_click_prompt = prompt_engine.load_prompt(
             template="single-click-action",
             navigation_goal=intention,
@@ -199,7 +199,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
                 "The element may not exist on the current page."
             )
         task_id = context.task_id if context else None
-        task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
+        task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
         if organization_id and task and step:
             actions = parse_actions(
                 task, step.step_id, step.order, self.scraped_page, json_response.get("actions", [])
@@ -244,8 +244,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
         task_id = context.task_id
         step_id = context.step_id
         workflow_run_id = context.workflow_run_id
-        task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
-        step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
+        task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
+        step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
 
         if intention:
             try:
@@ -373,8 +373,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
         task_id = context.task_id
         step_id = context.step_id
         workflow_run_id = context.workflow_run_id
-        task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
-        step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
+        task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
+        step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
 
         if intention:
             try:
@@ -478,8 +478,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
         option_value = value or ""
         context = skyvern_context.current()
         if context and context.task_id and context.step_id and context.organization_id:
-            task = await app.DATABASE.get_task(context.task_id, organization_id=context.organization_id)
-            step = await app.DATABASE.get_step(context.step_id, organization_id=context.organization_id)
+            task = await app.DATABASE.tasks.get_task(context.task_id, organization_id=context.organization_id)
+            step = await app.DATABASE.tasks.get_step(context.step_id, organization_id=context.organization_id)
             if intention and task and step:
                 try:
                     prompt = context.prompt if context else None
@@ -610,7 +610,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
         step = None
         organization_id = context.organization_id if context else None
         if context and context.organization_id and context.step_id:
-            step = await app.DATABASE.get_step(
+            step = await app.DATABASE.tasks.get_step(
                 step_id=context.step_id,
                 organization_id=context.organization_id,
             )
@@ -662,7 +662,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
         if not context or not context.organization_id or not context.workflow_permanent_id:
             return
         block_label = self.current_label or "unknown"
-        await app.DATABASE.record_branch_hit(
+        await app.DATABASE.scripts.record_branch_hit(
             organization_id=context.organization_id,
             workflow_permanent_id=context.workflow_permanent_id,
             block_label=block_label,
@@ -734,7 +734,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
         # Record an element fallback episode for the feedback loop
         if context.workflow_run_id and context.workflow_permanent_id:
             try:
-                await app.DATABASE.create_fallback_episode(
+                await app.DATABASE.scripts.create_fallback_episode(
                     organization_id=context.organization_id,
                     workflow_permanent_id=context.workflow_permanent_id,
                     workflow_run_id=context.workflow_run_id,
@@ -792,7 +792,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
         )
         step = None
         if context and context.organization_id and context.task_id and context.step_id:
-            step = await app.DATABASE.get_step(
+            step = await app.DATABASE.tasks.get_step(
                 step_id=context.step_id,
                 organization_id=context.organization_id,
             )
@@ -868,7 +868,7 @@ class RealSkyvernPageAi(SkyvernPageAi):
 
         step = None
         if context.organization_id and context.task_id and context.step_id:
-            step = await app.DATABASE.get_step(
+            step = await app.DATABASE.tasks.get_step(
                 step_id=context.step_id,
                 organization_id=context.organization_id,
             )
@@ -944,8 +944,8 @@ class RealSkyvernPageAi(SkyvernPageAi):
         task_id = context.task_id
         step_id = context.step_id
 
-        task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
-        step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
+        task = await app.DATABASE.tasks.get_task(task_id, organization_id) if task_id and organization_id else None
+        step = await app.DATABASE.tasks.get_step(step_id, organization_id) if step_id and organization_id else None
 
         if not task or not step:
             LOG.warning("ai_act: missing task or step", task_id=task_id, step_id=step_id)
@@ -82,7 +82,7 @@ class ScriptSkyvernPage(SkyvernPage):
     async def _get_or_create_browser_state(cls, browser_session_id: str | None = None) -> BrowserState:
         context = skyvern_context.current()
         if context and context.workflow_run_id and context.organization_id:
-            workflow_run = await app.DATABASE.get_workflow_run(
+            workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
                 workflow_run_id=context.workflow_run_id, organization_id=context.organization_id
             )
             if workflow_run:
@@ -101,7 +101,7 @@ class ScriptSkyvernPage(SkyvernPage):
     async def _get_browser_state(cls) -> BrowserState | None:
         context = skyvern_context.current()
         if context and context.workflow_run_id and context.organization_id:
-            workflow_run = await app.DATABASE.get_workflow_run(
+            workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
                 workflow_run_id=context.workflow_run_id, organization_id=context.organization_id
             )
             if workflow_run:
@@ -153,7 +153,7 @@ class ScriptSkyvernPage(SkyvernPage):
         organization_id = context.organization_id
         download_timeout = BROWSER_DOWNLOAD_TIMEOUT
         if context.task_id:
-            task = await app.DATABASE.get_task(context.task_id, organization_id=organization_id)
+            task = await app.DATABASE.tasks.get_task(context.task_id, organization_id=organization_id)
             if task and task.download_timeout:
                 download_timeout = task.download_timeout
         await check_downloading_files_and_wait_for_download_to_complete(
@@ -392,7 +392,7 @@ class ScriptSkyvernPage(SkyvernPage):
 
         except Exception:
             LOG.warning("Failed to generate action reasoning, using fallback", action_type=action_type)
-        await app.DATABASE.update_action_reasoning(
+        await app.DATABASE.workflow_params.update_action_reasoning(
             organization_id=organization_id,
             action_id=action_id,
             reasoning=reasoning,
@@ -496,7 +496,7 @@ class ScriptSkyvernPage(SkyvernPage):
             created_by="script",
         )
 
-        created_action = await app.DATABASE.create_action(action)
+        created_action = await app.DATABASE.workflow_params.create_action(action)
         # Skip LLM reasoning in script mode — use static string instead.
         # Build a descriptive label from the selector for the timeline.
         if context and context.script_mode:
@@ -521,7 +521,7 @@ class ScriptSkyvernPage(SkyvernPage):
             else:
                 label = selector[:60]
             reasoning = f"Script execution: {label}" if label else "Script execution"
-            await app.DATABASE.update_action_reasoning(
+            await app.DATABASE.workflow_params.update_action_reasoning(
                 organization_id=str(context.organization_id),
                 action_id=str(created_action.action_id),
                 reasoning=reasoning,
@@ -586,7 +586,7 @@ class ScriptSkyvernPage(SkyvernPage):
             screenshot = await browser_state.take_post_action_screenshot(scrolling_number=0)
 
             if screenshot:
-                step = await app.DATABASE.get_step(
+                step = await app.DATABASE.tasks.get_step(
                     context.step_id,
                     organization_id=context.organization_id,
                 )
@@ -642,7 +642,7 @@ class ScriptSkyvernPage(SkyvernPage):
             html = await skyvern_frame.get_content()
 
             if html:
-                step = await app.DATABASE.get_step(
+                step = await app.DATABASE.tasks.get_step(
                     context.step_id,
                     organization_id=context.organization_id,
                 )
@@ -690,7 +690,7 @@ class ScriptSkyvernPage(SkyvernPage):
             screenshot = await browser_state.take_fullpage_screenshot()
 
             if screenshot:
-                step = await app.DATABASE.get_step(
+                step = await app.DATABASE.tasks.get_step(
                     context.step_id,
                     organization_id=context.organization_id,
                 )
@@ -981,8 +981,8 @@ class ScriptSkyvernPage(SkyvernPage):
             await app.AGENT_FUNCTION.auto_solve_captchas(self.page)
             return None
 
-        task = await app.DATABASE.get_task(context.task_id, context.organization_id)
-        step = await app.DATABASE.get_step(context.step_id, context.organization_id)
+        task = await app.DATABASE.tasks.get_task(context.task_id, context.organization_id)
+        step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
         if task and step:
             solve_captcha_handler = ActionHandler._handled_action_types[ActionType.SOLVE_CAPTCHA]
             action = SolveCaptchaAction(
@@ -1014,15 +1014,15 @@ class ScriptSkyvernPage(SkyvernPage):
         if context.script_mode:
             print(" ⏭ Skipping complete() verification (--no-verify)")
             return
-        task = await app.DATABASE.get_task(context.task_id, context.organization_id)
-        step = await app.DATABASE.get_step(context.step_id, context.organization_id)
+        task = await app.DATABASE.tasks.get_task(context.task_id, context.organization_id)
+        step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
         if task and step:
             # CRITICAL: Update step.output with actions_and_results BEFORE validation
             # This ensures complete_verify() can access action history (including download info)
             # when checking if the goal was achieved
             await self._update_step_output_before_complete(context)
             # Refresh step to get updated output for validation
-            step = await app.DATABASE.get_step(context.step_id, context.organization_id)
+            step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
             if not step:
                 return
 
@@ -1062,8 +1062,8 @@ class ScriptSkyvernPage(SkyvernPage):
             msg += ": " + "; ".join(errors)
             raise ScriptTerminationException(msg)
 
-        task = await app.DATABASE.get_task(context.task_id, context.organization_id)
-        step = await app.DATABASE.get_step(context.step_id, context.organization_id)
+        task = await app.DATABASE.tasks.get_task(context.task_id, context.organization_id)
+        step = await app.DATABASE.tasks.get_step(context.step_id, context.organization_id)
         if task and step:
             action = TerminateAction(
                 organization_id=context.organization_id,
@@ -1116,7 +1116,7 @@ class ScriptSkyvernPage(SkyvernPage):
             errors=errors,
         )
 
-        await app.DATABASE.update_step(
+        await app.DATABASE.tasks.update_step(
            step_id=context.step_id,
            task_id=context.task_id,
            organization_id=context.organization_id,
@@ -71,7 +71,7 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz
     workflow_definition_blocks = workflow.workflow_definition.blocks
 
     # get workflow run blocks for task execution data
-    workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
+    workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
         workflow_run_id=workflow_run_id, organization_id=organization_id
     )
     workflow_run_blocks.sort(key=lambda x: x.created_at)
@@ -93,7 +93,7 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz
     if all_task_ids:
         task_ids_list = list(all_task_ids)
         # Single query for all tasks
-        tasks = await app.DATABASE.get_tasks_by_ids(task_ids=task_ids_list, organization_id=organization_id)
+        tasks = await app.DATABASE.tasks.get_tasks_by_ids(task_ids=task_ids_list, organization_id=organization_id)
         tasks_by_id = {task.task_id: task for task in tasks}
         LOG.debug(
             "Batch fetched tasks for code gen",
@@ -102,7 +102,9 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz
         )
 
         # Single query for all actions (returns desc order for timeline; reverse for chronological)
-        all_actions = await app.DATABASE.get_tasks_actions(task_ids=task_ids_list, organization_id=organization_id)
+        all_actions = await app.DATABASE.tasks.get_tasks_actions(
+            task_ids=task_ids_list, organization_id=organization_id
+        )
         all_actions.reverse()
         for action in all_actions:
             if action.task_id:
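The comments in the last two hunks explain the batching strategy: one query fetches all tasks, one fetches all actions, and the action list arrives newest-first (timeline order) and is reversed once for chronological processing. The loop body is cut off at the hunk boundary; a plausible continuation, assuming the goal is to group actions per task, looks like:

    # Hypothetical continuation of the truncated loop above.
    actions_by_task_id: dict[str, list] = {}
    for action in all_actions:
        if action.task_id:
            # Chronological within each task, since all_actions was reversed.
            actions_by_task_id.setdefault(action.task_id, []).append(action)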
@@ -204,7 +204,7 @@ class ForgeAgent:
             LOG.info("No browser state found for workflow run, setting task url to empty string")
             task_url = ""
 
-        task = await app.DATABASE.create_task(
+        task = await app.DATABASE.tasks.create_task(
             url=task_url,
             task_type=task_block.task_type,
             complete_criterion=task_block.complete_criterion,
@@ -244,13 +244,13 @@ class ForgeAgent:
             task_retry=task_retry,
         )
         # Update task status to running
-        task = await app.DATABASE.update_task(
+        task = await app.DATABASE.tasks.update_task(
             task_id=task.task_id,
             organization_id=task.organization_id,
             status=TaskStatus.running,
         )
 
-        step = await app.DATABASE.create_step(
+        step = await app.DATABASE.tasks.create_step(
             task.task_id,
             order=0,
             retry_index=0,
@@ -270,14 +270,14 @@ class ForgeAgent:
         totp_verification_url = str(task_request.totp_verification_url) if task_request.totp_verification_url else None
         # validate browser session id
         if task_request.browser_session_id:
-            browser_session = await app.DATABASE.get_persistent_browser_session(
+            browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
                 session_id=task_request.browser_session_id,
                 organization_id=organization_id,
             )
             if not browser_session:
                 raise BrowserSessionNotFound(browser_session_id=task_request.browser_session_id)
 
-        task = await app.DATABASE.create_task(
+        task = await app.DATABASE.tasks.create_task(
             url=str(task_request.url),
             title=task_request.title,
             webhook_callback_url=webhook_callback_url,
@@ -345,7 +345,7 @@ class ForgeAgent:
 
         workflow_run: WorkflowRun | None = None
         if task.workflow_run_id:
-            workflow_run = await app.DATABASE.get_workflow_run(
+            workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
                 workflow_run_id=task.workflow_run_id,
                 organization_id=organization.organization_id,
             )
@@ -381,7 +381,9 @@ class ForgeAgent:
             )
             return step, None, None
 
-        refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id)
+        refreshed_task = await app.DATABASE.tasks.get_task(
+            task_id=task.task_id, organization_id=organization.organization_id
+        )
         if refreshed_task:
             task = refreshed_task
 
@@ -413,7 +415,7 @@ class ForgeAgent:
             or settings.MAX_STEPS_PER_RUN
         )
         if max_steps_per_run and task.max_steps_per_run != max_steps_per_run:
-            await app.DATABASE.update_task(
+            await app.DATABASE.tasks.update_task(
                 task_id=task.task_id,
                 organization_id=organization.organization_id,
                 max_steps_per_run=max_steps_per_run,
@@ -938,7 +940,7 @@ class ForgeAgent:
         # Only pass new errors — update_task() appends to existing errors
         if detected_errors:
             new_errors = [error.model_dump() for error in detected_errors]
-            await app.DATABASE.update_task(
+            await app.DATABASE.tasks.update_task(
                 task_id=task.task_id,
                 organization_id=task.organization_id,
                 errors=new_errors,
@@ -1321,7 +1323,7 @@ class ForgeAgent:
                     action_order=action_idx,
                 )
                 detailed_agent_step_output.actions_and_results[action_idx] = (action, [action_result])
-                action.action_id = (await app.DATABASE.create_action(action=action)).action_id
+                action.action_id = (await app.DATABASE.workflow_params.create_action(action=action)).action_id
                 await self.record_artifacts_after_action(task, step, browser_state, engine, action)
                 break
 
@@ -1570,7 +1572,7 @@ class ForgeAgent:
         ):
             working_page = await browser_state.must_get_working_page()
             # refresh task in case the extracted information is updated previously
-            refreshed_task = await app.DATABASE.get_task(task.task_id, task.organization_id)
+            refreshed_task = await app.DATABASE.tasks.get_task(task.task_id, task.organization_id)
             assert refreshed_task is not None
             task = refreshed_task
             extract_action = await self.create_extract_action(task, step, scraped_page)
@@ -1664,7 +1666,7 @@ class ForgeAgent:
             cached_tokens = first_response.usage.input_tokens_details.cached_tokens or 0
             reasoning_tokens = first_response.usage.output_tokens_details.reasoning_tokens or 0
             llm_cost = (3.0 / 1000000) * input_tokens + (12.0 / 1000000) * output_tokens
-            await app.DATABASE.update_step(
+            await app.DATABASE.tasks.update_step(
                 task_id=task.task_id,
                 step_id=step.step_id,
                 organization_id=task.organization_id,
@@ -1777,7 +1779,7 @@ class ForgeAgent:
             cached_tokens = current_response.usage.input_tokens_details.cached_tokens or 0
             reasoning_tokens = current_response.usage.output_tokens_details.reasoning_tokens or 0
             llm_cost = (3.0 / 1000000) * input_tokens + (12.0 / 1000000) * output_tokens
-            await app.DATABASE.update_step(
+            await app.DATABASE.tasks.update_step(
                 task_id=task.task_id,
                 step_id=step.step_id,
                 organization_id=task.organization_id,
@@ -2107,7 +2109,7 @@ class ForgeAgent:
             or incremental_reasoning_tokens is not None
             or incremental_cached_tokens is not None
         ):
-            await app.DATABASE.update_step(
+            await app.DATABASE.tasks.update_step(
                 task_id=step.task_id,
                 step_id=step.step_id,
                 organization_id=step.organization_id,
@@ -2452,7 +2454,7 @@ class ForgeAgent:
                 )
             else:
                 try:
-                    await app.DATABASE.update_action_screenshot_artifact_id(
+                    await app.DATABASE.artifacts.update_action_screenshot_artifact_id(
                         organization_id=action.organization_id,
                         action_id=action.action_id,
                         screenshot_artifact_id=screenshot_artifact_id,
@@ -3406,7 +3408,7 @@ class ForgeAgent:
         Find the last successful ScrapeAction for the task and return the extracted information.
         """
         # TODO: make sure we can get extracted information with the ExtractAction change
-        steps = await app.DATABASE.get_task_steps(
+        steps = await app.DATABASE.tasks.get_task_steps(
             task_id=task.task_id,
             organization_id=task.organization_id,
         )
@@ -3439,7 +3441,7 @@ class ForgeAgent:
         Find the TerminateAction for the task and return the reasoning.
         # TODO (kerem): Also return meaningful exceptions when we add them [WYV-311]
         """
-        steps = await app.DATABASE.get_task_steps(
+        steps = await app.DATABASE.tasks.get_task_steps(
             task_id=task.task_id,
             organization_id=task.organization_id,
        )
@@ -3475,7 +3477,9 @@ class ForgeAgent:
         """
         # refresh the task from the db to get the latest status
         try:
-            refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
+            refreshed_task = await app.DATABASE.tasks.get_task(
+                task_id=task.task_id, organization_id=task.organization_id
+            )
             if not refreshed_task:
                 LOG.error("Failed to get task from db when clean up task", task_id=task.task_id)
                 raise TaskNotFound(task_id=task.task_id)
@@ -3579,7 +3583,7 @@ class ForgeAgent:
                 task_id=task.task_id,
             )
             return
-        last_step = await app.DATABASE.get_latest_step(task.task_id, organization_id=task.organization_id)
+        last_step = await app.DATABASE.tasks.get_latest_step(task.task_id, organization_id=task.organization_id)
 
         task_response = await self.build_task_response(task=task, last_step=last_step)
         # try to build the new TaskRunResponse for backward compatibility
@@ -3622,7 +3626,7 @@ class ForgeAgent:
                 resp_code=resp.status_code,
                 resp_text=resp.text,
             )
-            await app.DATABASE.update_task(
+            await app.DATABASE.tasks.update_task(
                 task_id=task.task_id,
                 organization_id=task.organization_id,
                 webhook_failure_reason="",
@@ -3635,7 +3639,7 @@ class ForgeAgent:
                 resp_code=resp.status_code,
                 resp_text=resp.text,
            )
-            await app.DATABASE.update_task(
+            await app.DATABASE.tasks.update_task(
                 task_id=task.task_id,
                 organization_id=task.organization_id,
                 webhook_failure_reason=f"Webhook failed with status code {resp.status_code}, error message: {resp.text}",
@@ -3665,7 +3669,7 @@ class ForgeAgent:
         downloaded_files: list[FileInfo] | None = None
 
         # get the artifact of the screenshot and get the screenshot_url
-        screenshot_artifact = await app.DATABASE.get_artifact(
+        screenshot_artifact = await app.DATABASE.artifacts.get_artifact(
             task_id=task.task_id,
             step_id=last_step.step_id,
             artifact_type=ArtifactType.SCREENSHOT_FINAL,
@@ -3689,9 +3693,11 @@ class ForgeAgent:
             LOG.warning("Timeout getting recordings", browser_session_id=task.browser_session_id)
 
         if recording_url is None:
-            first_step = await app.DATABASE.get_first_step(task_id=task.task_id, organization_id=task.organization_id)
+            first_step = await app.DATABASE.tasks.get_first_step(
+                task_id=task.task_id, organization_id=task.organization_id
+            )
             if first_step:
-                recording_artifact = await app.DATABASE.get_artifact(
+                recording_artifact = await app.DATABASE.artifacts.get_artifact(
                     task_id=task.task_id,
                     step_id=first_step.step_id,
                     artifact_type=ArtifactType.RECORDING,
@@ -3701,7 +3707,7 @@ class ForgeAgent:
                     recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
 
         # get the artifact of the last TASK_RESPONSE_ACTION_SCREENSHOT_COUNT screenshots and get the screenshot_url
-        latest_action_screenshot_artifacts = await app.DATABASE.get_latest_n_artifacts(
+        latest_action_screenshot_artifacts = await app.DATABASE.artifacts.get_latest_n_artifacts(
            task_id=task.task_id,
            organization_id=task.organization_id,
            artifact_types=[ArtifactType.SCREENSHOT_ACTION],
@@ -3737,7 +3743,7 @@ class ForgeAgent:
             )
 
         if need_browser_log:
-            browser_console_log = await app.DATABASE.get_latest_artifact(
+            browser_console_log = await app.DATABASE.artifacts.get_latest_artifact(
                 task_id=task.task_id,
                 artifact_types=[ArtifactType.BROWSER_CONSOLE_LOG],
                 organization_id=task.organization_id,
@@ -3746,7 +3752,7 @@ class ForgeAgent:
                 browser_console_log_url = await app.ARTIFACT_MANAGER.get_share_link(browser_console_log)
 
         # get the latest task from the db to get the latest status, extracted_information, and failure_reason
-        task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
+        task_from_db = await app.DATABASE.tasks.get_task(task_id=task.task_id, organization_id=task.organization_id)
         if not task_from_db:
             LOG.error("Failed to get task from db when sending task response")
             raise TaskNotFound(task_id=task.task_id)
@@ -3886,7 +3892,7 @@ class ForgeAgent:
 
         await save_step_logs(step.step_id)
 
-        return await app.DATABASE.update_step(
+        return await app.DATABASE.tasks.update_step(
             task_id=step.task_id,
             step_id=step.step_id,
             organization_id=step.organization_id,
@@ -3904,7 +3910,7 @@ class ForgeAgent:
         failure_category: list[dict[str, Any]] | None = None,
     ) -> Task:
         # refresh task from db to get the latest status
-        task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id)
+        task_from_db = await app.DATABASE.tasks.get_task(task_id=task.task_id, organization_id=task.organization_id)
         if task_from_db:
             task = task_from_db
 
@@ -3944,7 +3950,7 @@ class ForgeAgent:
 
         await save_task_logs(task.task_id)
         LOG.info("Updating task in db", task_id=task.task_id, diff=update_comparison)
-        return await app.DATABASE.update_task(
+        return await app.DATABASE.tasks.update_task(
             task.task_id,
             organization_id=task.organization_id,
             **updates,
@@ -3990,7 +3996,7 @@ class ForgeAgent:
             name=f"verify_goal_{step.step_id}",
         )
 
-        next_step = await app.DATABASE.create_step(
+        next_step = await app.DATABASE.tasks.create_step(
             task_id=task.task_id,
             order=step.order + 1,
             retry_index=0,
@@ -4297,7 +4303,7 @@ class ForgeAgent:
             step_order=step.order,
             step_retry=step.retry_index,
         )
-        next_step = await app.DATABASE.create_step(
+        next_step = await app.DATABASE.tasks.create_step(
             task_id=task.task_id,
             organization_id=task.organization_id,
             order=step.order,
@@ -4316,7 +4322,7 @@ class ForgeAgent:
         llm_errors: list[str] = []
 
         try:
-            steps = await app.DATABASE.get_task_steps(
+            steps = await app.DATABASE.tasks.get_task_steps(
                 task_id=task.task_id, organization_id=organization.organization_id
             )
             for step_cnt, step in enumerate(steps):
@@ -4434,7 +4440,7 @@ class ForgeAgent:
         steps_without_actions = 0
 
         try:
-            steps = await app.DATABASE.get_task_steps(
+            steps = await app.DATABASE.tasks.get_task_steps(
                 task_id=task.task_id, organization_id=organization.organization_id
             )
 
@@ -4722,7 +4728,7 @@ class ForgeAgent:
             step_order=step.order,
             step_retry=step.retry_index,
         )
-        next_step = await app.DATABASE.create_step(
+        next_step = await app.DATABASE.tasks.create_step(
             task_id=task.task_id,
             order=step.order + 1,
             retry_index=0,
@@ -4843,7 +4849,7 @@ class ForgeAgent:
         if not otp_value and (task.totp_verification_url or task.totp_identifier) and task.organization_id:
             workflow_id = workflow_permanent_id = None
             if task.workflow_run_id:
-                workflow_run = await app.DATABASE.get_workflow_run(task.workflow_run_id)
+                workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(task.workflow_run_id)
                 if workflow_run:
                     workflow_id = workflow_run.workflow_id
                     workflow_permanent_id = workflow_run.workflow_permanent_id
@@ -4891,7 +4897,7 @@ class ForgeAgent:
 
     @staticmethod
     async def get_task_errors(task: Task) -> list[UserDefinedError]:
-        steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
+        steps = await app.DATABASE.tasks.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
         errors = []
         for step in steps:
             if step.output and step.output.errors:
@@ -4907,7 +4913,7 @@ class ForgeAgent:
             step_errors = detailed_step_output.extract_errors() or []
             task_errors.extend([error.model_dump() for error in step_errors])
 
-        return await app.DATABASE.update_task(
+        return await app.DATABASE.tasks.update_task(
             task_id=task.task_id,
             organization_id=task.organization_id,
             errors=task_errors,
@@ -4962,7 +4968,7 @@ class ForgeAgent:
         """
         Run the extraction flow when a task with a data extraction goal completes during parallel verification.
         """
-        refreshed_task = await app.DATABASE.get_task(task.task_id, task.organization_id)
+        refreshed_task = await app.DATABASE.tasks.get_task(task.task_id, task.organization_id)
         if refreshed_task:
             task = refreshed_task
 
@@ -452,7 +452,7 @@ class AgentFunction:
         if not has_valid_step_status:
             reasons.append(f"invalid_step_status:{step.status}")
         # can't execute if the task has another step that is running
-        steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
+        steps = await app.DATABASE.tasks.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
         has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
         if not has_no_running_steps:
             reasons.append(f"another_step_is_running_for_task:{task.task_id}")
@@ -836,7 +836,7 @@ class LLMAPIHandlerFactory:
                 if cached_tokens == 0:
                     cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
                 if step and not is_speculative_step:
-                    await app.DATABASE.update_step(
+                    await app.DATABASE.tasks.update_step(
                         task_id=step.task_id,
                         step_id=step.step_id,
                         organization_id=step.organization_id,
@@ -847,7 +847,7 @@ class LLMAPIHandlerFactory:
                         incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
                     )
                 if thought:
-                    await app.DATABASE.update_thought(
+                    await app.DATABASE.observer.update_thought(
                         thought_id=thought.observer_thought_id,
                         organization_id=thought.organization_id,
                         input_token_count=prompt_tokens if prompt_tokens > 0 else None,
@@ -1302,7 +1302,7 @@ class LLMAPIHandlerFactory:
                 _log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens)
 
                 if step and not is_speculative_step:
-                    await app.DATABASE.update_step(
+                    await app.DATABASE.tasks.update_step(
                         task_id=step.task_id,
                         step_id=step.step_id,
                         organization_id=step.organization_id,
@@ -1313,7 +1313,7 @@ class LLMAPIHandlerFactory:
                         incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
                     )
                 if thought:
-                    await app.DATABASE.update_thought(
+                    await app.DATABASE.observer.update_thought(
                         thought_id=thought.observer_thought_id,
                         organization_id=thought.organization_id,
                         input_token_count=prompt_tokens if prompt_tokens > 0 else None,
@@ -1704,7 +1704,7 @@ class LLMCaller:
 
         call_stats = await self.get_call_stats(response)
         if step and not is_speculative_step:
-            await app.DATABASE.update_step(
+            await app.DATABASE.tasks.update_step(
                 task_id=step.task_id,
                 step_id=step.step_id,
                 organization_id=step.organization_id,
@@ -1715,7 +1715,7 @@ class LLMCaller:
                incremental_cached_tokens=call_stats.cached_tokens,
            )
         if thought:
-            await app.DATABASE.update_thought(
+            await app.DATABASE.observer.update_thought(
                 thought_id=thought.observer_thought_id,
                 organization_id=thought.organization_id,
                 input_token_count=call_stats.input_tokens,
@@ -189,7 +189,7 @@ class ArtifactManager:
         if not workflow_run_block_id and context:
             workflow_run_block_id = context.parent_workflow_run_block_id
 
-        artifact = await app.DATABASE.create_artifact(
+        artifact = await app.DATABASE.artifacts.create_artifact(
             artifact_id,
             artifact_type,
             uri,
@@ -521,7 +521,7 @@ class ArtifactManager:
         artifact_models = [artifact_data.artifact_model for artifact_data in request.artifacts]
 
         # Bulk insert artifacts
-        artifacts = await app.DATABASE.bulk_create_artifacts(artifact_models)
+        artifacts = await app.DATABASE.artifacts.bulk_create_artifacts(artifact_models)
 
         # Fire and forget upload tasks
         for artifact, artifact_data in zip(artifacts, request.artifacts):
@@ -823,7 +823,7 @@ class ArtifactManager:
     ) -> None:
         if not artifact_id or not organization_id:
             return None
-        artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id)
+        artifact = await app.DATABASE.artifacts.get_artifact_by_id(artifact_id, organization_id)
         if not artifact:
             return
         # Fire and forget
@@ -1178,12 +1178,12 @@ class ArtifactManager:
             )
             for artifact_type, filename, artifact_id in accumulator.member_types
         ]
-        await app.DATABASE.bulk_create_artifacts([parent_model, *member_models])
+        await app.DATABASE.artifacts.bulk_create_artifacts([parent_model, *member_models])
 
         # Apply deferred action.screenshot_artifact_id updates now that artifact rows exist.
         for organization_id, action_id, artifact_id in accumulator.pending_action_screenshot_updates:
             try:
-                await app.DATABASE.update_action_screenshot_artifact_id(
+                await app.DATABASE.artifacts.update_action_screenshot_artifact_id(
                     organization_id=organization_id,
                     action_id=action_id,
                     screenshot_artifact_id=artifact_id,
@@ -1267,7 +1267,7 @@ class ArtifactManager:
             )
             for filename, (artifact_type, _) in entries.items()
         ]
-        await app.DATABASE.bulk_create_artifacts([parent_model, *member_models])
+        await app.DATABASE.artifacts.bulk_create_artifacts([parent_model, *member_models])
 
     async def wait_for_upload_aiotasks(self, primary_keys: list[str]) -> None:
         try:
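The "fire and forget" comments above describe scheduling uploads without awaiting them; the diff shows only the database half. A minimal sketch of that scheduling pattern with asyncio (the upload coroutine is a stand-in, not Skyvern's API):

    import asyncio

    _background_tasks: set[asyncio.Task] = set()

    async def _upload(artifact, data: bytes) -> None:
        ...  # stand-in for the storage upload

    def fire_and_forget_upload(artifact, data: bytes) -> None:
        # Keep a strong reference to the task so the event loop does not
        # garbage-collect it before the upload finishes.
        task = asyncio.create_task(_upload(artifact, data))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)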
skyvern/forge/sdk/db/agent_db.py

@@ -14,6 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
 from skyvern.config import settings
 from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB
 from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError  # noqa: F401
+from skyvern.forge.sdk.db.models import PersistentBrowserSessionModel
 from skyvern.forge.sdk.db.repositories.artifacts import ArtifactsRepository
 from skyvern.forge.sdk.db.repositories.browser_sessions import BrowserSessionsRepository
 from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
@@ -169,9 +170,6 @@ class AgentDB(BaseAlchemyDB):
 
     # ======================================================================
     # Backward-compatible delegate methods
-    # TODO(SKY-62): These delegates erase type information (*args: Any -> Any).
-    # Migrate callers to use typed repository attributes directly
-    # (e.g., db.tasks.get_task(...) instead of db.get_task(...)), then remove.
     # ======================================================================
 
     # -- Task delegates --
@@ -529,17 +527,17 @@ class AgentDB(BaseAlchemyDB):
     async def close_persistent_browser_session(self, *args: Any, **kwargs: Any) -> Any:
         return await self.browser_sessions.close_persistent_browser_session(*args, **kwargs)
 
-    async def get_all_active_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.browser_sessions.get_all_active_persistent_browser_sessions(*args, **kwargs)
+    async def get_all_active_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
+        return await self.browser_sessions.get_all_active_persistent_browser_sessions()
 
     async def archive_browser_session_address(self, *args: Any, **kwargs: Any) -> Any:
         return await self.browser_sessions.archive_browser_session_address(*args, **kwargs)
 
-    async def get_uncompleted_persistent_browser_sessions(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.browser_sessions.get_uncompleted_persistent_browser_sessions(*args, **kwargs)
+    async def get_uncompleted_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
+        return await self.browser_sessions.get_uncompleted_persistent_browser_sessions()
 
     async def get_debug_session_by_browser_session_id(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.browser_sessions.get_debug_session_by_browser_session_id(*args, **kwargs)
+        return await self.debug.get_debug_session_by_browser_session_id(*args, **kwargs)
 
     # -- Schedule delegates --
 
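Two of the browser-session delegates above trade the *args: Any, **kwargs: Any -> Any shape for concrete signatures, which is the payoff of the migration: a type checker can now validate call sites. A small illustration (the caller is hypothetical):

    # Hypothetical caller; `db` is an AgentDB instance.
    # With the old Any-typed delegate, a stray argument like
    #     await db.get_all_active_persistent_browser_sessions("oops")
    # passed type checking. With the typed signature it is rejected
    # statically, and the result type is known:
    sessions = await db.get_all_active_persistent_browser_sessions()
    first: PersistentBrowserSessionModel | None = sessions[0] if sessions else None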
skyvern/forge/sdk/db/mixins/__init__.py (deleted file)

@@ -1,39 +0,0 @@
-# DEPRECATED: These mixin classes have been superseded by the repository pattern
-# in skyvern/forge/sdk/db/repositories/. AgentDB no longer inherits from these
-# mixins — it composes repository instances instead. These files are retained
-# temporarily to avoid breaking any downstream code that may import from here.
-# TODO(SKY-62): Remove after 2026-Q2 once all imports are verified migrated.
-from skyvern.forge.sdk.db.mixins.artifacts import ArtifactsMixin
-from skyvern.forge.sdk.db.mixins.base import BaseAlchemyDB, read_retry
-from skyvern.forge.sdk.db.mixins.browser_sessions import BrowserSessionsMixin
-from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin
-from skyvern.forge.sdk.db.mixins.debug import DebugMixin
-from skyvern.forge.sdk.db.mixins.folders import FoldersMixin
-from skyvern.forge.sdk.db.mixins.observer import ObserverMixin
-from skyvern.forge.sdk.db.mixins.organizations import OrganizationsMixin
-from skyvern.forge.sdk.db.mixins.otp import OTPMixin
-from skyvern.forge.sdk.db.mixins.schedules import SchedulesMixin
-from skyvern.forge.sdk.db.mixins.scripts import ScriptsMixin
-from skyvern.forge.sdk.db.mixins.tasks import TasksMixin
-from skyvern.forge.sdk.db.mixins.workflow_parameters import WorkflowParametersMixin
-from skyvern.forge.sdk.db.mixins.workflow_runs import WorkflowRunsMixin
-from skyvern.forge.sdk.db.mixins.workflows import WorkflowsMixin
-
-__all__ = [
-    "ArtifactsMixin",
-    "BaseAlchemyDB",
-    "BrowserSessionsMixin",
-    "CredentialsMixin",
-    "DebugMixin",
-    "FoldersMixin",
-    "ObserverMixin",
-    "OrganizationsMixin",
-    "OTPMixin",
-    "SchedulesMixin",
-    "ScriptsMixin",
-    "TasksMixin",
-    "WorkflowParametersMixin",
-    "WorkflowRunsMixin",
-    "WorkflowsMixin",
-    "read_retry",
-]
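The deleted header records the design shift this commit completes: AgentDB used to inherit a mixin per domain, putting every method on one class surface; it now composes one typed repository object per domain. A schematic contrast (the constructor wiring is an assumption; only the class and repository names appear in this diff):

    # Before: mixin inheritance, one flat method namespace (illustrative subset).
    class AgentDB(ArtifactsMixin, TasksMixin, BaseAlchemyDB):
        ...

    # After: composition, each domain behind a typed attribute (assumed wiring).
    class AgentDB(BaseAlchemyDB):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self.artifacts = ArtifactsRepository(...)
            self.tasks = TasksRepository(...)
            # ...one repository per domain: scripts, observer, workflow_runs, etc.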
@ -1,452 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import and_, delete, or_, select, update
|
||||
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db._error_handling import db_operation
|
||||
from skyvern.forge.sdk.db.models import ActionModel, ArtifactModel
|
||||
from skyvern.forge.sdk.db.utils import convert_to_artifact
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
||||
from skyvern.forge.sdk.schemas.runs import Run
|
||||
|
||||
import structlog
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ArtifactsMixin:
|
||||
"""Database operations for artifact management."""
|
||||
|
||||
Session: _SessionFactory
|
||||
debug_enabled: bool
|
||||
|
||||
async def get_run(self, run_id: str, organization_id: str | None = None) -> Run | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@db_operation("create_artifact")
|
||||
async def create_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
artifact_type: str,
|
||||
uri: str,
|
||||
organization_id: str,
|
||||
step_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
task_v2_id: str | None = None,
|
||||
run_id: str | None = None,
|
||||
thought_id: str | None = None,
|
||||
ai_suggestion_id: str | None = None,
|
||||
) -> Artifact:
|
||||
async with self.Session() as session:
|
||||
new_artifact = ArtifactModel(
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
uri=uri,
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
observer_cruise_id=task_v2_id,
|
||||
observer_thought_id=thought_id,
|
||||
run_id=run_id,
|
||||
ai_suggestion_id=ai_suggestion_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(new_artifact)
|
||||
await session.commit()
|
||||
await session.refresh(new_artifact)
|
||||
return convert_to_artifact(new_artifact, self.debug_enabled)
|
||||
|
||||
@db_operation("bulk_create_artifacts")
|
||||
async def bulk_create_artifacts(
|
||||
self,
|
||||
artifact_models: list[ArtifactModel],
|
||||
) -> list[Artifact]:
|
||||
"""
|
||||
Bulk create multiple artifacts in a single database transaction.
|
||||
|
||||
Args:
|
||||
artifact_models: List of ArtifactModel instances to insert
|
||||
|
||||
Returns:
|
||||
List of created Artifact objects
|
||||
"""
|
||||
if not artifact_models:
|
||||
return []
|
||||
|
||||
async with self.Session() as session:
|
||||
session.add_all(artifact_models)
|
||||
await session.commit()
|
||||
|
||||
# Refresh all artifacts to get their created_at and modified_at values
|
||||
for artifact in artifact_models:
|
||||
await session.refresh(artifact)
|
||||
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifact_models]
|
||||
|
||||
@db_operation("get_artifacts_for_task_v2")
|
||||
async def get_artifacts_for_task_v2(
|
||||
self,
|
||||
task_v2_id: str,
|
||||
organization_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
) -> list[Artifact]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(ArtifactModel)
|
||||
.filter_by(observer_cruise_id=task_v2_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
if artifact_types:
|
||||
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
query = query.order_by(ArtifactModel.created_at)
|
||||
if artifacts := (await session.scalars(query)).all():
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
else:
|
||||
return []
|
||||
|
||||
@db_operation("get_artifacts_for_task_step")
|
||||
async def get_artifacts_for_task_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[Artifact]:
|
||||
async with self.Session() as session:
|
||||
if artifacts := (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(ArtifactModel.created_at)
|
||||
)
|
||||
).all():
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
else:
|
||||
return []
|
||||
|
||||
@db_operation("get_artifacts_for_run")
|
||||
async def get_artifacts_for_run(
|
||||
self,
|
||||
run_id: str,
|
||||
organization_id: str,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
group_by_type: bool = False,
|
||||
sort_by: str = "created_at",
|
||||
) -> dict[ArtifactType, list[Artifact]] | list[Artifact]:
|
||||
"""Return artifacts associated with a run.
|
||||
|
||||
Args:
|
||||
run_id: The ID of the run to get artifacts for
|
||||
organization_id: The ID of the organization that owns the run
|
||||
artifact_types: Optional list of artifact types to filter by
|
||||
group_by_type: If True, returns a dictionary mapping artifact types to lists of artifacts.
|
||||
If False, returns a flat list of artifacts. Defaults to False.
|
||||
sort_by: Field to sort artifacts by. Must be one of: 'created_at', 'step_id', 'task_id'.
|
||||
Defaults to 'created_at'.
|
||||
|
||||
Returns:
|
||||
If group_by_type is True, returns a dictionary mapping artifact types to lists of artifacts.
|
||||
If group_by_type is False, returns a list of artifacts sorted by the specified field.
|
||||
|
||||
Raises:
|
||||
ValueError: If sort_by is not one of the allowed values
|
||||
"""
|
||||
allowed_sort_fields = {"created_at", "step_id", "task_id"}
|
||||
if sort_by not in allowed_sort_fields:
|
||||
raise ValueError(f"sort_by must be one of {allowed_sort_fields}")
|
||||
run = await self.get_run(run_id, organization_id=organization_id)
|
||||
|
||||
async with self.Session() as session:
|
||||
query = select(ArtifactModel).filter_by(organization_id=organization_id)
|
||||
|
||||
if run:
|
||||
# Workflow run — filter by workflow_run_id
|
||||
query = query.filter_by(workflow_run_id=run.workflow_run_id)
|
||||
elif run_id.startswith("tsk_"):
|
||||
# Task run — get_run only handles workflow runs,
|
||||
# so fall back to filtering by task_id for task-based artifacts
|
||||
query = query.filter_by(task_id=run_id)
|
||||
else:
|
||||
return []
|
||||
|
||||
if artifact_types:
|
||||
query = query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == "created_at":
|
||||
query = query.order_by(ArtifactModel.created_at)
|
||||
elif sort_by == "step_id":
|
||||
query = query.order_by(ArtifactModel.step_id, ArtifactModel.created_at)
|
||||
elif sort_by == "task_id":
|
||||
query = query.order_by(ArtifactModel.task_id, ArtifactModel.created_at)
|
||||
|
||||
# Execute query and convert to Artifact objects
|
||||
artifacts = [
|
||||
convert_to_artifact(artifact, self.debug_enabled) for artifact in (await session.scalars(query)).all()
|
||||
]
|
||||
|
||||
# Group artifacts by type if requested
|
||||
if group_by_type:
|
||||
result: dict[ArtifactType, list[Artifact]] = {}
|
||||
for artifact in artifacts:
|
||||
if artifact.artifact_type not in result:
|
||||
result[artifact.artifact_type] = []
|
||||
result[artifact.artifact_type].append(artifact)
|
||||
return result
|
||||
|
||||
return artifacts
|
||||
|
||||
@db_operation("get_artifact_by_id")
|
||||
async def get_artifact_by_id(
|
||||
self,
|
||||
artifact_id: str,
|
||||
organization_id: str,
|
||||
) -> Artifact | None:
|
||||
async with self.Session() as session:
|
||||
if artifact := (
|
||||
await session.scalars(
|
||||
select(ArtifactModel).filter_by(artifact_id=artifact_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
else:
|
||||
return None
|
||||
|
||||
@db_operation("get_artifact_by_id_no_org")
|
||||
async def get_artifact_by_id_no_org(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> Artifact | None:
|
||||
"""Fetch an artifact by ID without an organization filter.
|
||||
|
||||
Only use this when the caller has already verified authorization through
|
||||
an out-of-band mechanism (e.g. a valid HMAC-signed URL).
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
if artifact := (await session.scalars(select(ArtifactModel).filter_by(artifact_id=artifact_id))).first():
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
else:
|
||||
return None
|
||||
|
||||
@db_operation("get_artifacts_by_ids")
|
||||
async def get_artifacts_by_ids(
|
||||
self,
|
||||
artifact_ids: list[str],
|
||||
organization_id: str,
|
||||
) -> list[Artifact]:
|
||||
if not artifact_ids:
|
||||
return []
|
||||
async with self.Session() as session:
|
||||
artifacts = (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter(ArtifactModel.artifact_id.in_(artifact_ids))
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).all()
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
|
||||
@db_operation("get_artifacts_by_entity_id")
|
||||
async def get_artifacts_by_entity_id(
|
||||
self,
|
||||
*,
|
||||
organization_id: str | None,
|
||||
artifact_type: ArtifactType | None = None,
|
||||
task_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
thought_id: str | None = None,
|
||||
task_v2_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Artifact]:
|
||||
async with self.Session() as session:
|
||||
# Build base query
|
||||
query = select(ArtifactModel)
|
||||
|
||||
if artifact_type is not None:
|
||||
query = query.filter_by(artifact_type=artifact_type)
|
||||
if task_id is not None:
|
||||
query = query.filter_by(task_id=task_id)
|
||||
if step_id is not None:
|
||||
query = query.filter_by(step_id=step_id)
|
||||
if workflow_run_id is not None:
|
||||
query = query.filter_by(workflow_run_id=workflow_run_id)
|
||||
if workflow_run_block_id is not None:
|
||||
query = query.filter_by(workflow_run_block_id=workflow_run_block_id)
|
||||
if thought_id is not None:
|
||||
query = query.filter_by(observer_thought_id=thought_id)
|
||||
if task_v2_id is not None:
|
||||
query = query.filter_by(observer_cruise_id=task_v2_id)
|
||||
# Handle backward compatibility where old artifact rows were stored with organization_id NULL
|
||||
if organization_id is not None:
|
||||
query = query.filter(
|
||||
or_(ArtifactModel.organization_id == organization_id, ArtifactModel.organization_id.is_(None))
|
||||
)
|
||||
|
||||
query = query.order_by(ArtifactModel.created_at.desc())
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
artifacts = (await session.scalars(query)).all()
|
||||
LOG.debug("Artifacts fetched", count=len(artifacts))
|
||||
return [convert_to_artifact(a, self.debug_enabled) for a in artifacts]
|
||||
|
||||
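For illustration, a minimal usage sketch of the entity-scoped lookup above. The `db` handle, the IDs, and the `ArtifactType.SCREENSHOT_LLM` member are assumptions for the example; the backward-compatibility filter means rows stored before organization scoping (organization_id NULL) are still returned.

async def latest_step_screenshots(db, organization_id: str, step_id: str):
    # Newest-first, capped at five rows; legacy NULL-org rows are included.
    return await db.get_artifacts_by_entity_id(
        organization_id=organization_id,
        artifact_type=ArtifactType.SCREENSHOT_LLM,  # assumed enum member
        step_id=step_id,
        limit=5,
    )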
@db_operation("get_artifact_by_entity_id")
|
||||
async def get_artifact_by_entity_id(
|
||||
self,
|
||||
*,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
thought_id: str | None = None,
|
||||
task_v2_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
artifacts = await self.get_artifacts_by_entity_id(
|
||||
organization_id=organization_id,
|
||||
artifact_type=artifact_type,
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
thought_id=thought_id,
|
||||
task_v2_id=task_v2_id,
|
||||
limit=1,
|
||||
)
|
||||
return artifacts[0] if artifacts else None
|
||||
|
||||
@db_operation("get_artifact")
|
||||
async def get_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
async with self.Session() as session:
|
||||
artifact = (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(artifact_type=artifact_type)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("get_artifact_for_run")
|
||||
async def get_artifact_for_run(
|
||||
self,
|
||||
run_id: str,
|
||||
artifact_type: ArtifactType,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
async with self.Session() as session:
|
||||
artifact = (
|
||||
await session.scalars(
|
||||
select(ArtifactModel)
|
||||
.filter(ArtifactModel.run_id == run_id)
|
||||
.filter(ArtifactModel.artifact_type == artifact_type)
|
||||
.filter(ArtifactModel.organization_id == organization_id)
|
||||
.order_by(ArtifactModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
if artifact:
|
||||
return convert_to_artifact(artifact, self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("get_latest_artifact")
|
||||
async def get_latest_artifact(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> Artifact | None:
|
||||
artifacts = await self.get_latest_n_artifacts(
|
||||
task_id=task_id,
|
||||
step_id=step_id,
|
||||
artifact_types=artifact_types,
|
||||
organization_id=organization_id,
|
||||
n=1,
|
||||
)
|
||||
if artifacts:
|
||||
return artifacts[0]
|
||||
return None
|
||||
|
||||
@db_operation("get_latest_n_artifacts")
|
||||
async def get_latest_n_artifacts(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str | None = None,
|
||||
artifact_types: list[ArtifactType] | None = None,
|
||||
organization_id: str | None = None,
|
||||
n: int = 1,
|
||||
) -> list[Artifact] | None:
|
||||
async with self.Session() as session:
|
||||
artifact_query = select(ArtifactModel).filter_by(task_id=task_id)
|
||||
if organization_id:
|
||||
artifact_query = artifact_query.filter_by(organization_id=organization_id)
|
||||
if step_id:
|
||||
artifact_query = artifact_query.filter_by(step_id=step_id)
|
||||
if artifact_types:
|
||||
artifact_query = artifact_query.filter(ArtifactModel.artifact_type.in_(artifact_types))
|
||||
|
||||
artifacts = (await session.scalars(artifact_query.order_by(ArtifactModel.created_at.desc()))).fetchmany(n)
|
||||
if artifacts:
|
||||
return [convert_to_artifact(artifact, self.debug_enabled) for artifact in artifacts]
|
||||
return None
|
||||
|
||||
@db_operation("delete_task_artifacts")
|
||||
async def delete_task_artifacts(self, organization_id: str, task_id: str) -> None:
|
||||
async with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(ArtifactModel).where(
|
||||
and_(
|
||||
ArtifactModel.organization_id == organization_id,
|
||||
ArtifactModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("delete_task_v2_artifacts")
|
||||
async def delete_task_v2_artifacts(self, task_v2_id: str, organization_id: str | None = None) -> None:
|
||||
async with self.Session() as session:
|
||||
stmt = delete(ArtifactModel).where(
|
||||
and_(
|
||||
ArtifactModel.observer_cruise_id == task_v2_id,
|
||||
ArtifactModel.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("update_action_screenshot_artifact_id")
|
||||
async def update_action_screenshot_artifact_id(
|
||||
self, *, organization_id: str, action_id: str, screenshot_artifact_id: str
|
||||
) -> None:
|
||||
async with self.Session() as session:
|
||||
await session.execute(
|
||||
update(ActionModel)
|
||||
.where(ActionModel.action_id == action_id, ActionModel.organization_id == organization_id)
|
||||
.values(screenshot_artifact_id=screenshot_artifact_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
|
@ -1,12 +0,0 @@
"""Shared dependencies for mixin modules.

This module is the single import point for base classes and utilities that
sibling mixin modules need (e.g. ``BaseAlchemyDB``, ``read_retry``).
Centralising the import here means every mixin can do
``from .base import BaseAlchemyDB, read_retry`` instead of reaching up to
the parent package, keeping coupling explicit and consistent.
"""

from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB, read_retry

__all__ = ["BaseAlchemyDB", "read_retry"]
@ -1,518 +0,0 @@
from __future__ import annotations

import uuid
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import asc, case, select

from skyvern.exceptions import BrowserProfileNotFound
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
    BrowserProfileModel,
    DebugSessionModel,
    PersistentBrowserSessionModel,
    WorkflowRunModel,
)
from skyvern.forge.sdk.db.utils import convert_to_workflow_run, serialize_proxy_location
from skyvern.forge.sdk.schemas.browser_profiles import BrowserProfile
from skyvern.forge.sdk.schemas.debug_sessions import DebugSession
from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
    Extensions,
    PersistentBrowserSession,
    PersistentBrowserType,
)
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory

LOG = structlog.get_logger()


class BrowserSessionsMixin:
    """Database operations for browser profiles and persistent browser sessions.

    .. deprecated::
        This mixin is part of the legacy database layer. New code should use the
        repository classes in ``skyvern.forge.sdk.db.repositories`` instead.

    Cross-mixin migrations already completed:
    - ``get_last_workflow_run_for_browser_session`` → ``WorkflowRunsRepository``
      (queries workflow runs as the primary entity; the browser session is just a filter).
    """

    Session: _SessionFactory
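As a migration sketch tied to the deprecation note above (illustrative, not from the diff): callers move from the flat mixin delegates to typed repository attributes. The `browser_sessions` attribute name is an assumption by analogy with the `tasks` and `scripts` attributes used elsewhere in this commit; only `WorkflowRunsRepository` is confirmed by the docstring.

async def fetch_session(app, session_id: str, org_id: str):
    # Legacy delegate (grandfathered only):
    #   await app.DATABASE.get_persistent_browser_session(session_id, organization_id=org_id)
    # Preferred typed-repository form; `browser_sessions` is an assumed attribute name:
    return await app.DATABASE.browser_sessions.get_persistent_browser_session(
        session_id, organization_id=org_id
    )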
@db_operation("get_last_workflow_run_for_browser_session")
|
||||
async def get_last_workflow_run_for_browser_session(
|
||||
self,
|
||||
browser_session_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> WorkflowRun | None:
|
||||
# Deprecated: moved to WorkflowRunsRepository.get_last_workflow_run_for_browser_session
|
||||
# (skyvern/forge/sdk/db/repositories/workflow_runs.py). The primary entity is the
|
||||
# workflow run, not the browser session. This copy remains for legacy mixin compatibility.
|
||||
async with self.Session() as session:
|
||||
# check if there's a queued run
|
||||
query = select(WorkflowRunModel).filter_by(browser_session_id=browser_session_id)
|
||||
if organization_id:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
|
||||
queue_query = query.filter_by(status=WorkflowRunStatus.queued)
|
||||
queue_query = queue_query.order_by(WorkflowRunModel.modified_at.desc())
|
||||
workflow_run = (await session.scalars(queue_query)).first()
|
||||
if workflow_run:
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
|
||||
# check if there's a running run
|
||||
running_query = query.filter_by(status=WorkflowRunStatus.running)
|
||||
running_query = running_query.filter(WorkflowRunModel.started_at.isnot(None))
|
||||
running_query = running_query.order_by(WorkflowRunModel.started_at.desc())
|
||||
workflow_run = (await session.scalars(running_query)).first()
|
||||
if workflow_run:
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
return None
|
||||
|
||||
@db_operation("create_browser_profile")
|
||||
async def create_browser_profile(
|
||||
self,
|
||||
organization_id: str,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
) -> BrowserProfile:
|
||||
async with self.Session() as session:
|
||||
browser_profile = BrowserProfileModel(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
session.add(browser_profile)
|
||||
await session.commit()
|
||||
await session.refresh(browser_profile)
|
||||
return BrowserProfile.model_validate(browser_profile)
|
||||
|
||||
@db_operation("get_browser_profile")
|
||||
async def get_browser_profile(
|
||||
self,
|
||||
profile_id: str,
|
||||
organization_id: str,
|
||||
include_deleted: bool = False,
|
||||
) -> BrowserProfile | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(BrowserProfileModel)
|
||||
.filter_by(browser_profile_id=profile_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
if not include_deleted:
|
||||
query = query.filter(BrowserProfileModel.deleted_at.is_(None))
|
||||
|
||||
browser_profile = (await session.scalars(query)).first()
|
||||
if not browser_profile:
|
||||
return None
|
||||
return BrowserProfile.model_validate(browser_profile)
|
||||
|
||||
@db_operation("list_browser_profiles")
|
||||
async def list_browser_profiles(
|
||||
self,
|
||||
organization_id: str,
|
||||
include_deleted: bool = False,
|
||||
) -> list[BrowserProfile]:
|
||||
async with self.Session() as session:
|
||||
query = select(BrowserProfileModel).filter_by(organization_id=organization_id)
|
||||
if not include_deleted:
|
||||
query = query.filter(BrowserProfileModel.deleted_at.is_(None))
|
||||
browser_profiles = await session.scalars(query.order_by(asc(BrowserProfileModel.created_at)))
|
||||
return [BrowserProfile.model_validate(profile) for profile in browser_profiles.all()]
|
||||
|
||||
@db_operation("delete_browser_profile")
|
||||
async def delete_browser_profile(
|
||||
self,
|
||||
profile_id: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(BrowserProfileModel)
|
||||
.filter_by(browser_profile_id=profile_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(BrowserProfileModel.deleted_at.is_(None))
|
||||
)
|
||||
browser_profile = (await session.scalars(query)).first()
|
||||
if not browser_profile:
|
||||
raise BrowserProfileNotFound(profile_id=profile_id, organization_id=organization_id)
|
||||
browser_profile.deleted_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
|
||||
@db_operation("get_active_persistent_browser_sessions")
|
||||
async def get_active_persistent_browser_sessions(
|
||||
self,
|
||||
organization_id: str,
|
||||
active_hours: int = 24,
|
||||
) -> list[PersistentBrowserSession]:
|
||||
"""Get all active persistent browser sessions for an organization."""
|
||||
async with self.Session() as session:
|
||||
result = await session.execute(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
.filter_by(completed_at=None)
|
||||
.filter(PersistentBrowserSessionModel.created_at > datetime.utcnow() - timedelta(hours=active_hours))
|
||||
)
|
||||
sessions = result.scalars().all()
|
||||
return [PersistentBrowserSession.model_validate(session) for session in sessions]
|
||||
|
||||
@db_operation("get_persistent_browser_sessions_history")
|
||||
async def get_persistent_browser_sessions_history(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
lookback_hours: int = 24 * 7,
|
||||
) -> list[PersistentBrowserSession]:
|
||||
"""Get persistent browser sessions history for an organization."""
|
||||
async with self.Session() as session:
|
||||
open_first = case(
|
||||
(
|
||||
PersistentBrowserSessionModel.status == "running",
|
||||
0, # open
|
||||
),
|
||||
else_=1, # not open
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
.filter(
|
||||
PersistentBrowserSessionModel.created_at > (datetime.utcnow() - timedelta(hours=lookback_hours))
|
||||
)
|
||||
.order_by(
|
||||
open_first.asc(), # open sessions first
|
||||
PersistentBrowserSessionModel.created_at.desc(), # then newest within each group
|
||||
)
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
sessions = result.scalars().all()
|
||||
return [PersistentBrowserSession.model_validate(session) for session in sessions]
|
||||
|
||||
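The `open_first` CASE expression above pins sessions whose status is "running" to the top of the history regardless of age; rows then sort newest-first within each group. A self-contained sketch of the equivalent in-memory sort key, for illustration only:

from dataclasses import dataclass
from datetime import datetime


@dataclass
class Row:
    status: str
    created_at: datetime


def history_sort_key(row: Row) -> tuple[int, float]:
    # (0, ...) sorts running sessions first; negating the timestamp puts
    # newer rows first within each group under an ascending sort.
    return (0 if row.status == "running" else 1, -row.created_at.timestamp())


rows = [Row("completed", datetime(2025, 1, 2)), Row("running", datetime(2025, 1, 1))]
assert sorted(rows, key=history_sort_key)[0].status == "running"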
    @read_retry()
    @db_operation("get_persistent_browser_session_by_runnable_id", log_errors=False)
    async def get_persistent_browser_session_by_runnable_id(
        self, runnable_id: str, organization_id: str | None = None
    ) -> PersistentBrowserSession | None:
        """Get a specific persistent browser session."""
        async with self.Session() as session:
            query = (
                select(PersistentBrowserSessionModel)
                .filter_by(runnable_id=runnable_id)
                .filter_by(deleted_at=None)
                .filter_by(completed_at=None)
            )
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            persistent_browser_session = (await session.scalars(query)).first()
            if persistent_browser_session:
                return PersistentBrowserSession.model_validate(persistent_browser_session)
            return None

    @db_operation("get_persistent_browser_session")
    async def get_persistent_browser_session(
        self,
        session_id: str,
        organization_id: str | None = None,
    ) -> PersistentBrowserSession | None:
        """Get a specific persistent browser session."""
        async with self.Session() as session:
            persistent_browser_session = (
                await session.scalars(
                    select(PersistentBrowserSessionModel)
                    .filter_by(persistent_browser_session_id=session_id)
                    .filter_by(organization_id=organization_id)
                    .filter_by(deleted_at=None)
                )
            ).first()
            if persistent_browser_session:
                return PersistentBrowserSession.model_validate(persistent_browser_session)
            return None

    @db_operation("create_persistent_browser_session")
    async def create_persistent_browser_session(
        self,
        organization_id: str,
        runnable_type: str | None = None,
        runnable_id: str | None = None,
        timeout_minutes: int | None = None,
        proxy_location: ProxyLocationInput = ProxyLocation.RESIDENTIAL,
        extensions: list[Extensions] | None = None,
        browser_type: PersistentBrowserType | None = None,
        browser_profile_id: str | None = None,
    ) -> PersistentBrowserSession:
        """Create a new persistent browser session."""
        extensions_str: list[str] | None = (
            [extension.value for extension in extensions] if extensions is not None else None
        )
        async with self.Session() as session:
            browser_session = PersistentBrowserSessionModel(
                organization_id=organization_id,
                runnable_type=runnable_type,
                runnable_id=runnable_id,
                timeout_minutes=timeout_minutes,
                proxy_location=serialize_proxy_location(proxy_location),
                extensions=extensions_str,
                browser_type=browser_type.value if browser_type else None,
                browser_profile_id=browser_profile_id,
            )
            session.add(browser_session)
            await session.commit()
            await session.refresh(browser_session)
            return PersistentBrowserSession.model_validate(browser_session)
@db_operation("update_persistent_browser_session")
|
||||
async def update_persistent_browser_session(
|
||||
self,
|
||||
browser_session_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
timeout_minutes: int | None = None,
|
||||
organization_id: str | None = None,
|
||||
completed_at: datetime | None = None,
|
||||
started_at: datetime | None = None,
|
||||
) -> PersistentBrowserSession:
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=browser_session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if not persistent_browser_session:
|
||||
raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")
|
||||
|
||||
if status:
|
||||
persistent_browser_session.status = status
|
||||
if timeout_minutes:
|
||||
persistent_browser_session.timeout_minutes = timeout_minutes
|
||||
if completed_at:
|
||||
persistent_browser_session.completed_at = completed_at
|
||||
if started_at:
|
||||
persistent_browser_session.started_at = started_at
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
|
||||
@db_operation("set_persistent_browser_session_browser_address")
|
||||
async def set_persistent_browser_session_browser_address(
|
||||
self,
|
||||
browser_session_id: str,
|
||||
browser_address: str | None,
|
||||
ip_address: str | None,
|
||||
ecs_task_arn: str | None,
|
||||
organization_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set the browser address for a persistent browser session."""
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=browser_session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if persistent_browser_session:
|
||||
if browser_address:
|
||||
persistent_browser_session.browser_address = browser_address
|
||||
# once the address is set, the session is started
|
||||
persistent_browser_session.started_at = datetime.utcnow()
|
||||
if ip_address:
|
||||
persistent_browser_session.ip_address = ip_address
|
||||
if ecs_task_arn:
|
||||
persistent_browser_session.ecs_task_arn = ecs_task_arn
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
else:
|
||||
raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")
|
||||
|
||||
@db_operation("update_persistent_browser_session_compute_cost")
|
||||
async def update_persistent_browser_session_compute_cost(
|
||||
self,
|
||||
session_id: str,
|
||||
organization_id: str,
|
||||
instance_type: str,
|
||||
vcpu_millicores: int,
|
||||
memory_mb: int,
|
||||
duration_ms: int,
|
||||
compute_cost: float,
|
||||
) -> None:
|
||||
"""Update the compute cost fields for a persistent browser session"""
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if persistent_browser_session:
|
||||
persistent_browser_session.instance_type = instance_type
|
||||
persistent_browser_session.vcpu_millicores = vcpu_millicores
|
||||
persistent_browser_session.memory_mb = memory_mb
|
||||
persistent_browser_session.duration_ms = duration_ms
|
||||
persistent_browser_session.compute_cost = compute_cost
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
else:
|
||||
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
|
||||
|
||||
@db_operation("mark_persistent_browser_session_deleted")
|
||||
async def mark_persistent_browser_session_deleted(self, session_id: str, organization_id: str) -> None:
|
||||
"""Mark a persistent browser session as deleted."""
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if persistent_browser_session:
|
||||
persistent_browser_session.deleted_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
else:
|
||||
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
|
||||
|
||||
@db_operation("occupy_persistent_browser_session")
|
||||
async def occupy_persistent_browser_session(
|
||||
self, session_id: str, runnable_type: str, runnable_id: str, organization_id: str
|
||||
) -> None:
|
||||
"""Occupy a specific persistent browser session."""
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if persistent_browser_session:
|
||||
persistent_browser_session.runnable_type = runnable_type
|
||||
persistent_browser_session.runnable_id = runnable_id
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
else:
|
||||
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
|
||||
|
||||
@db_operation("release_persistent_browser_session")
|
||||
async def release_persistent_browser_session(
|
||||
self,
|
||||
session_id: str,
|
||||
organization_id: str,
|
||||
) -> PersistentBrowserSession:
|
||||
"""Release a specific persistent browser session."""
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if persistent_browser_session:
|
||||
persistent_browser_session.runnable_type = None
|
||||
persistent_browser_session.runnable_id = None
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
else:
|
||||
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
|
||||
|
||||
@db_operation("close_persistent_browser_session")
|
||||
async def close_persistent_browser_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession:
|
||||
"""Close a specific persistent browser session."""
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if persistent_browser_session:
|
||||
if persistent_browser_session.completed_at:
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
persistent_browser_session.completed_at = datetime.utcnow()
|
||||
persistent_browser_session.status = "completed"
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
|
||||
|
||||
@db_operation("get_all_active_persistent_browser_sessions")
|
||||
async def get_all_active_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
|
||||
"""Get all active persistent browser sessions across all organizations."""
|
||||
async with self.Session() as session:
|
||||
result = await session.execute(select(PersistentBrowserSessionModel).filter_by(deleted_at=None))
|
||||
return result.scalars().all()
|
||||
|
||||
@db_operation("archive_browser_session_address")
|
||||
async def archive_browser_session_address(self, session_id: str, organization_id: str) -> None:
|
||||
"""Suffix browser_address with a unique tag so the unique constraint
|
||||
no longer blocks new sessions that reuse the same local address."""
|
||||
async with self.Session() as session:
|
||||
row = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
|
||||
if not row or not row.browser_address:
|
||||
return
|
||||
if "::closed::" in row.browser_address:
|
||||
return
|
||||
|
||||
row.browser_address = f"{row.browser_address}::closed::{uuid.uuid4().hex}"
|
||||
await session.commit()
|
||||
|
||||
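A runnable sketch of the tagging scheme in `archive_browser_session_address`, for illustration: the suffix keeps the historical row while freeing the unique `browser_address` value for reuse, and the guard makes the operation idempotent.

import uuid


def archive_address(address: str) -> str:
    # Already-archived addresses pass through unchanged (idempotent).
    if "::closed::" in address:
        return address
    return f"{address}::closed::{uuid.uuid4().hex}"


assert archive_address("localhost:9222").startswith("localhost:9222::closed::")
assert archive_address("x::closed::abc") == "x::closed::abc"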
@db_operation("get_uncompleted_persistent_browser_sessions")
|
||||
async def get_uncompleted_persistent_browser_sessions(self) -> list[PersistentBrowserSessionModel]:
|
||||
"""Get all browser sessions that have not been completed or deleted."""
|
||||
async with self.Session() as session:
|
||||
result = await session.execute(
|
||||
select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@db_operation("get_debug_session_by_browser_session_id")
|
||||
async def get_debug_session_by_browser_session_id(
|
||||
self,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
) -> DebugSession | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(DebugSessionModel)
|
||||
.filter_by(browser_session_id=browser_session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
model = (await session.scalars(query)).first()
|
||||
return DebugSession.model_validate(model) if model else None
|
||||
|
|
@ -1,218 +0,0 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

from sqlalchemy import select

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import CredentialModel, OrganizationBitwardenCollectionModel
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType, CredentialVaultType
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory


class CredentialsMixin:
    """Database operations for credential and Bitwarden collection management."""

    Session: _SessionFactory
@db_operation("create_credential")
|
||||
async def create_credential(
|
||||
self,
|
||||
organization_id: str,
|
||||
name: str,
|
||||
vault_type: CredentialVaultType,
|
||||
item_id: str,
|
||||
credential_type: CredentialType,
|
||||
username: str | None,
|
||||
totp_type: str,
|
||||
card_last4: str | None,
|
||||
card_brand: str | None,
|
||||
totp_identifier: str | None = None,
|
||||
secret_label: str | None = None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = CredentialModel(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=credential_type,
|
||||
username=username,
|
||||
totp_type=totp_type,
|
||||
totp_identifier=totp_identifier,
|
||||
card_last4=card_last4,
|
||||
card_brand=card_brand,
|
||||
secret_label=secret_label,
|
||||
)
|
||||
session.add(credential)
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return Credential.model_validate(credential)
|
||||
|
||||
@db_operation("get_credential")
|
||||
async def get_credential(self, credential_id: str, organization_id: str) -> Credential | None:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
)
|
||||
).first()
|
||||
if credential:
|
||||
return Credential.model_validate(credential)
|
||||
return None
|
||||
|
||||
@db_operation("get_credentials")
|
||||
async def get_credentials(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
vault_type: str | None = None,
|
||||
) -> list[Credential]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(CredentialModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
)
|
||||
if vault_type is not None:
|
||||
query = query.filter(CredentialModel.vault_type == vault_type)
|
||||
credentials = (
|
||||
await session.scalars(
|
||||
query.order_by(CredentialModel.created_at.desc()).offset((page - 1) * page_size).limit(page_size)
|
||||
)
|
||||
).all()
|
||||
return [Credential.model_validate(credential) for credential in credentials]
|
||||
|
||||
@db_operation("update_credential")
|
||||
async def update_credential(
|
||||
self,
|
||||
credential_id: str,
|
||||
organization_id: str,
|
||||
name: str | None = None,
|
||||
browser_profile_id: str | None = None,
|
||||
tested_url: str | None = None,
|
||||
user_context: str | None = None,
|
||||
save_browser_session_intent: bool | None = None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
)
|
||||
).first()
|
||||
if not credential:
|
||||
raise NotFoundError(f"Credential {credential_id} not found")
|
||||
if name is not None:
|
||||
credential.name = name
|
||||
if browser_profile_id is not None:
|
||||
credential.browser_profile_id = browser_profile_id
|
||||
if tested_url is not None:
|
||||
credential.tested_url = tested_url
|
||||
if user_context is not None:
|
||||
credential.user_context = user_context
|
||||
if save_browser_session_intent is not None:
|
||||
credential.save_browser_session_intent = save_browser_session_intent
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return Credential.model_validate(credential)
|
||||
|
||||
@db_operation("update_credential_vault_data")
|
||||
async def update_credential_vault_data(
|
||||
self,
|
||||
credential_id: str,
|
||||
organization_id: str,
|
||||
item_id: str,
|
||||
name: str,
|
||||
credential_type: CredentialType,
|
||||
username: str | None = None,
|
||||
totp_type: str = "none",
|
||||
totp_identifier: str | None = None,
|
||||
card_last4: str | None = None,
|
||||
card_brand: str | None = None,
|
||||
secret_label: str | None = None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(CredentialModel.deleted_at.is_(None))
|
||||
.with_for_update()
|
||||
)
|
||||
).first()
|
||||
if not credential:
|
||||
raise NotFoundError(f"Credential {credential_id} not found")
|
||||
credential.item_id = item_id
|
||||
credential.name = name
|
||||
credential.credential_type = credential_type
|
||||
credential.username = username
|
||||
credential.totp_type = totp_type
|
||||
credential.totp_identifier = totp_identifier
|
||||
credential.card_last4 = card_last4
|
||||
credential.card_brand = card_brand
|
||||
credential.secret_label = secret_label
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return Credential.model_validate(credential)
|
||||
|
||||
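Note the `with_for_update()` in `update_credential_vault_data`: it emits SELECT ... FOR UPDATE, so concurrent vault syncs serialize on the credential row instead of interleaving their read-modify-write. A minimal sketch of the same locking pattern, reusing `select` and `CredentialModel` from the imports above (illustrative only):

async def lock_credential_row(session, credential_id: str):
    # A second transaction calling this for the same row blocks here
    # until the first one commits or rolls back.
    return (
        await session.scalars(
            select(CredentialModel).filter_by(credential_id=credential_id).with_for_update()
        )
    ).first()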
@db_operation("delete_credential")
|
||||
async def delete_credential(self, credential_id: str, organization_id: str) -> None:
|
||||
async with self.Session() as session:
|
||||
credential = (
|
||||
await session.scalars(
|
||||
select(CredentialModel)
|
||||
.filter_by(credential_id=credential_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if not credential:
|
||||
raise NotFoundError(f"Credential {credential_id} not found")
|
||||
credential.deleted_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return None
|
||||
|
||||
@db_operation("create_organization_bitwarden_collection")
|
||||
async def create_organization_bitwarden_collection(
|
||||
self,
|
||||
organization_id: str,
|
||||
collection_id: str,
|
||||
) -> OrganizationBitwardenCollection:
|
||||
async with self.Session() as session:
|
||||
organization_bitwarden_collection = OrganizationBitwardenCollectionModel(
|
||||
organization_id=organization_id, collection_id=collection_id
|
||||
)
|
||||
session.add(organization_bitwarden_collection)
|
||||
await session.commit()
|
||||
await session.refresh(organization_bitwarden_collection)
|
||||
return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection)
|
||||
|
||||
@db_operation("get_organization_bitwarden_collection")
|
||||
async def get_organization_bitwarden_collection(
|
||||
self,
|
||||
organization_id: str,
|
||||
) -> OrganizationBitwardenCollection | None:
|
||||
async with self.Session() as session:
|
||||
organization_bitwarden_collection = (
|
||||
await session.scalars(
|
||||
select(OrganizationBitwardenCollectionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if organization_bitwarden_collection:
|
||||
return OrganizationBitwardenCollection.model_validate(organization_bitwarden_collection)
|
||||
return None
|
||||
|
|
@ -1,259 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from sqlalchemy import select

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import (
    BlockRunModel,
    DebugSessionModel,
    WorkflowRunModel,
)
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession, DebugSessionRun
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory


class DebugMixin:
    """Database operations for debug sessions and block runs."""

    Session: _SessionFactory
@db_operation("get_debug_session")
|
||||
async def get_debug_session(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
user_id: str,
|
||||
workflow_permanent_id: str,
|
||||
) -> DebugSession | None:
|
||||
async with self.Session() as session:
|
||||
debug_session = (
|
||||
await session.scalars(
|
||||
select(DebugSessionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(workflow_permanent_id=workflow_permanent_id)
|
||||
.filter_by(user_id=user_id)
|
||||
.filter_by(deleted_at=None)
|
||||
.filter_by(status="created")
|
||||
.order_by(DebugSessionModel.created_at.desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not debug_session:
|
||||
return None
|
||||
|
||||
return DebugSession.model_validate(debug_session)
|
||||
|
||||
@db_operation("get_latest_block_run")
|
||||
async def get_latest_block_run(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
user_id: str,
|
||||
block_label: str,
|
||||
) -> BlockRun | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(BlockRunModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(user_id=user_id)
|
||||
.filter_by(block_label=block_label)
|
||||
.order_by(BlockRunModel.created_at.desc())
|
||||
)
|
||||
|
||||
model = (await session.scalars(query)).first()
|
||||
|
||||
return BlockRun.model_validate(model) if model else None
|
||||
|
||||
@db_operation("get_latest_completed_block_run")
|
||||
async def get_latest_completed_block_run(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
user_id: str,
|
||||
block_label: str,
|
||||
workflow_permanent_id: str,
|
||||
) -> BlockRun | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(BlockRunModel)
|
||||
.join(WorkflowRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
|
||||
.filter(BlockRunModel.organization_id == organization_id)
|
||||
.filter(BlockRunModel.user_id == user_id)
|
||||
.filter(BlockRunModel.block_label == block_label)
|
||||
.filter(WorkflowRunModel.status == WorkflowRunStatus.completed)
|
||||
.filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.order_by(BlockRunModel.created_at.desc())
|
||||
)
|
||||
|
||||
model = (await session.scalars(query)).first()
|
||||
|
||||
return BlockRun.model_validate(model) if model else None
|
||||
|
||||
@db_operation("create_block_run")
|
||||
async def create_block_run(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
user_id: str,
|
||||
block_label: str,
|
||||
output_parameter_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> None:
|
||||
async with self.Session() as session:
|
||||
block_run = BlockRunModel(
|
||||
organization_id=organization_id,
|
||||
user_id=user_id,
|
||||
block_label=block_label,
|
||||
output_parameter_id=output_parameter_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
session.add(block_run)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@db_operation("get_latest_debug_session_for_user")
|
||||
async def get_latest_debug_session_for_user(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
user_id: str,
|
||||
workflow_permanent_id: str,
|
||||
) -> DebugSession | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(DebugSessionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
.filter_by(status="created")
|
||||
.filter_by(user_id=user_id)
|
||||
.filter_by(workflow_permanent_id=workflow_permanent_id)
|
||||
.order_by(DebugSessionModel.created_at.desc())
|
||||
)
|
||||
|
||||
model = (await session.scalars(query)).first()
|
||||
|
||||
return DebugSession.model_validate(model) if model else None
|
||||
|
||||
@db_operation("get_debug_session_by_id")
|
||||
async def get_debug_session_by_id(
|
||||
self,
|
||||
debug_session_id: str,
|
||||
organization_id: str,
|
||||
) -> DebugSession | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(DebugSessionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
.filter_by(debug_session_id=debug_session_id)
|
||||
)
|
||||
|
||||
model = (await session.scalars(query)).first()
|
||||
|
||||
return DebugSession.model_validate(model) if model else None
|
||||
|
||||
@db_operation("get_workflow_runs_by_debug_session_id")
|
||||
async def get_workflow_runs_by_debug_session_id(
|
||||
self,
|
||||
debug_session_id: str,
|
||||
organization_id: str,
|
||||
) -> list[DebugSessionRun]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(WorkflowRunModel, BlockRunModel)
|
||||
.join(BlockRunModel, BlockRunModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
|
||||
.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
.filter(WorkflowRunModel.debug_session_id == debug_session_id)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
||||
results = (await session.execute(query)).all()
|
||||
|
||||
debug_session_runs = []
|
||||
for workflow_run, block_run in results:
|
||||
debug_session_runs.append(
|
||||
DebugSessionRun(
|
||||
ai_fallback=workflow_run.ai_fallback,
|
||||
block_label=block_run.block_label,
|
||||
browser_session_id=workflow_run.browser_session_id,
|
||||
code_gen=workflow_run.code_gen,
|
||||
debug_session_id=workflow_run.debug_session_id,
|
||||
failure_reason=workflow_run.failure_reason,
|
||||
output_parameter_id=block_run.output_parameter_id,
|
||||
run_with=workflow_run.run_with,
|
||||
script_run_id=workflow_run.script_run.get("script_run_id") if workflow_run.script_run else None,
|
||||
status=workflow_run.status,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_permanent_id=workflow_run.workflow_permanent_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
created_at=workflow_run.created_at,
|
||||
queued_at=workflow_run.queued_at,
|
||||
started_at=workflow_run.started_at,
|
||||
finished_at=workflow_run.finished_at,
|
||||
)
|
||||
)
|
||||
|
||||
return debug_session_runs
|
||||
|
||||
@db_operation("complete_debug_sessions")
|
||||
async def complete_debug_sessions(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
user_id: str | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
) -> list[DebugSession]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(DebugSessionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
.filter_by(status="created")
|
||||
)
|
||||
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if workflow_permanent_id:
|
||||
query = query.filter_by(workflow_permanent_id=workflow_permanent_id)
|
||||
|
||||
models = (await session.scalars(query)).all()
|
||||
|
||||
for model in models:
|
||||
model.status = "completed"
|
||||
|
||||
debug_sessions = [DebugSession.model_validate(model) for model in models]
|
||||
|
||||
await session.commit()
|
||||
|
||||
return debug_sessions
|
||||
|
||||
@db_operation("create_debug_session")
|
||||
async def create_debug_session(
|
||||
self,
|
||||
*,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
user_id: str,
|
||||
workflow_permanent_id: str,
|
||||
vnc_streaming_supported: bool,
|
||||
) -> DebugSession:
|
||||
async with self.Session() as session:
|
||||
debug_session = DebugSessionModel(
|
||||
organization_id=organization_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
user_id=user_id,
|
||||
browser_session_id=browser_session_id,
|
||||
vnc_streaming_supported=vnc_streaming_supported,
|
||||
status="created",
|
||||
)
|
||||
|
||||
session.add(debug_session)
|
||||
await session.commit()
|
||||
await session.refresh(debug_session)
|
||||
|
||||
return DebugSession.model_validate(debug_session)
|
||||
|
|
@ -1,368 +0,0 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

from sqlalchemy import func, or_, select, update

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import FolderModel, WorkflowModel
from skyvern.forge.sdk.db.utils import convert_to_workflow
from skyvern.forge.sdk.workflow.models.workflow import Workflow

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory


class FoldersMixin:
    """Database operations for folder management."""

    Session: _SessionFactory
    debug_enabled: bool

    async def get_workflow_by_permanent_id(
        self,
        workflow_permanent_id: str,
        organization_id: str | None = None,
        version: int | None = None,
        ignore_version: int | None = None,
        filter_deleted: bool = True,
    ) -> Workflow | None:
        raise NotImplementedError
@db_operation("create_folder")
|
||||
async def create_folder(
|
||||
self,
|
||||
organization_id: str,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
) -> FolderModel:
|
||||
"""Create a new folder."""
|
||||
async with self.Session() as session:
|
||||
folder = FolderModel(
|
||||
organization_id=organization_id,
|
||||
title=title,
|
||||
description=description,
|
||||
)
|
||||
session.add(folder)
|
||||
await session.commit()
|
||||
await session.refresh(folder)
|
||||
return folder
|
||||
|
||||
@db_operation("get_folders")
|
||||
async def get_folders(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
search_query: str | None = None,
|
||||
) -> list[FolderModel]:
|
||||
"""Get all folders for an organization with pagination and optional search."""
|
||||
async with self.Session() as session:
|
||||
stmt = (
|
||||
select(FolderModel).filter_by(organization_id=organization_id).filter(FolderModel.deleted_at.is_(None))
|
||||
)
|
||||
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
stmt = stmt.filter(
|
||||
or_(
|
||||
FolderModel.title.ilike(search_pattern),
|
||||
FolderModel.description.ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(FolderModel.modified_at.desc())
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@db_operation("get_folder")
|
||||
async def get_folder(
|
||||
self,
|
||||
folder_id: str,
|
||||
organization_id: str,
|
||||
) -> FolderModel | None:
|
||||
"""Get a folder by ID."""
|
||||
async with self.Session() as session:
|
||||
stmt = (
|
||||
select(FolderModel)
|
||||
.filter_by(folder_id=folder_id, organization_id=organization_id)
|
||||
.filter(FolderModel.deleted_at.is_(None))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@db_operation("update_folder")
|
||||
async def update_folder(
|
||||
self,
|
||||
folder_id: str,
|
||||
organization_id: str,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> FolderModel | None:
|
||||
"""Update a folder's title or description."""
|
||||
async with self.Session() as session:
|
||||
stmt = (
|
||||
select(FolderModel)
|
||||
.filter_by(folder_id=folder_id, organization_id=organization_id)
|
||||
.filter(FolderModel.deleted_at.is_(None))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
folder = result.scalar_one_or_none()
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
if title is not None:
|
||||
folder.title = title
|
||||
if description is not None:
|
||||
folder.description = description
|
||||
|
||||
folder.modified_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await session.refresh(folder)
|
||||
return folder
|
||||
|
||||
@db_operation("get_workflow_permanent_ids_in_folder")
|
||||
async def get_workflow_permanent_ids_in_folder(
|
||||
self,
|
||||
folder_id: str,
|
||||
organization_id: str,
|
||||
) -> list[str]:
|
||||
"""Get workflow permanent IDs (latest versions only) in a folder."""
|
||||
async with self.Session() as session:
|
||||
# Subquery to get the latest version for each workflow
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Get workflow_permanent_ids where the latest version is in this folder
|
||||
stmt = (
|
||||
select(WorkflowModel.workflow_permanent_id)
|
||||
.join(
|
||||
subquery,
|
||||
(WorkflowModel.organization_id == subquery.c.organization_id)
|
||||
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
.where(WorkflowModel.folder_id == folder_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
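The latest-version-per-workflow pattern above recurs throughout this file: group by `workflow_permanent_id`, take `max(version)`, then join back so only the newest row of each workflow survives. A standalone sketch of the subquery, reusing `WorkflowModel`, `func`, and `select` from the imports above (illustrative only):

def latest_versions_subquery(organization_id: str):
    # One row per (organization_id, workflow_permanent_id), carrying the
    # highest non-deleted version; joining back on version == max_version
    # keeps only the latest revision of each workflow.
    return (
        select(
            WorkflowModel.organization_id,
            WorkflowModel.workflow_permanent_id,
            func.max(WorkflowModel.version).label("max_version"),
        )
        .where(WorkflowModel.organization_id == organization_id)
        .where(WorkflowModel.deleted_at.is_(None))
        .group_by(WorkflowModel.organization_id, WorkflowModel.workflow_permanent_id)
        .subquery()
    )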
@db_operation("soft_delete_folder")
|
||||
async def soft_delete_folder(
|
||||
self,
|
||||
folder_id: str,
|
||||
organization_id: str,
|
||||
delete_workflows: bool = False,
|
||||
) -> bool:
|
||||
"""Soft delete a folder. Optionally delete all workflows in the folder."""
|
||||
async with self.Session() as session:
|
||||
# Check if folder exists
|
||||
folder_stmt = (
|
||||
select(FolderModel)
|
||||
.filter_by(folder_id=folder_id, organization_id=organization_id)
|
||||
.filter(FolderModel.deleted_at.is_(None))
|
||||
)
|
||||
folder_result = await session.execute(folder_stmt)
|
||||
folder = folder_result.scalar_one_or_none()
|
||||
if not folder:
|
||||
return False
|
||||
|
||||
# If delete_workflows is True, delete all workflows in the folder
|
||||
if delete_workflows:
|
||||
# Get workflow permanent IDs in the folder (inline logic)
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
workflow_permanent_ids_stmt = (
|
||||
select(WorkflowModel.workflow_permanent_id)
|
||||
.join(
|
||||
subquery,
|
||||
(WorkflowModel.organization_id == subquery.c.organization_id)
|
||||
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
.where(WorkflowModel.folder_id == folder_id)
|
||||
)
|
||||
result = await session.execute(workflow_permanent_ids_stmt)
|
||||
workflow_permanent_ids = list(result.scalars().all())
|
||||
|
||||
# Soft delete all workflows with these permanent IDs in a single bulk update
|
||||
if workflow_permanent_ids:
|
||||
update_workflows_query = (
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.utcnow())
|
||||
)
|
||||
await session.execute(update_workflows_query)
|
||||
else:
|
||||
# Just remove folder_id from all workflows in this folder
|
||||
update_workflows_query = (
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.folder_id == folder_id)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.values(folder_id=None, modified_at=datetime.utcnow())
|
||||
)
|
||||
await session.execute(update_workflows_query)
|
||||
|
||||
# Soft delete the folder
|
||||
folder.deleted_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@db_operation("get_folder_workflow_count")
|
||||
async def get_folder_workflow_count(
|
||||
self,
|
||||
folder_id: str,
|
||||
organization_id: str,
|
||||
) -> int:
|
||||
"""Get the count of workflows (latest versions only) in a folder."""
|
||||
async with self.Session() as session:
|
||||
# Subquery to get the latest version for each workflow (same pattern as get_workflows_by_organization_id)
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Count workflows where the latest version is in this folder
|
||||
stmt = (
|
||||
select(func.count(WorkflowModel.workflow_permanent_id))
|
||||
.join(
|
||||
subquery,
|
||||
(WorkflowModel.organization_id == subquery.c.organization_id)
|
||||
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
.where(WorkflowModel.folder_id == folder_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one()
|
||||
|
||||
@db_operation("get_folder_workflow_counts_batch")
|
||||
async def get_folder_workflow_counts_batch(
|
||||
self,
|
||||
folder_ids: list[str],
|
||||
organization_id: str,
|
||||
) -> dict[str, int]:
|
||||
"""Get workflow counts for multiple folders in a single query."""
|
||||
async with self.Session() as session:
|
||||
# Subquery to get the latest version for each workflow
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Count workflows grouped by folder_id
|
||||
stmt = (
|
||||
select(
|
||||
WorkflowModel.folder_id,
|
||||
func.count(WorkflowModel.workflow_permanent_id).label("count"),
|
||||
)
|
||||
.join(
|
||||
subquery,
|
||||
(WorkflowModel.organization_id == subquery.c.organization_id)
|
||||
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
.where(WorkflowModel.folder_id.in_(folder_ids))
|
||||
.group_by(WorkflowModel.folder_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to dict; folders with no workflows will be absent from the result
|
||||
return {row.folder_id: row.count for row in rows}
|
||||
|
||||
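Because the batch variant omits folders that contain no workflows, callers should default missing keys to zero. A usage sketch (the `db` handle and IDs are assumptions for illustration):

async def print_folder_counts(db, folder_ids: list[str], organization_id: str) -> None:
    counts = await db.get_folder_workflow_counts_batch(folder_ids, organization_id)
    for folder_id in folder_ids:
        # A folder absent from the result simply has no workflows yet.
        print(folder_id, counts.get(folder_id, 0))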
@db_operation("update_workflow_folder")
|
||||
async def update_workflow_folder(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
folder_id: str | None,
|
||||
) -> Workflow | None:
|
||||
"""Update folder assignment for the latest version of a workflow."""
|
||||
# Get the latest version of the workflow
|
||||
latest_workflow = await self.get_workflow_by_permanent_id(
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not latest_workflow:
|
||||
return None
|
||||
|
||||
async with self.Session() as session:
|
||||
# Validate folder exists in-org if folder_id is provided
|
||||
if folder_id:
|
||||
stmt = (
|
||||
select(FolderModel.folder_id)
|
||||
.where(FolderModel.folder_id == folder_id)
|
||||
.where(FolderModel.organization_id == organization_id)
|
||||
.where(FolderModel.deleted_at.is_(None))
|
||||
)
|
||||
if (await session.scalar(stmt)) is None:
|
||||
raise ValueError(f"Folder {folder_id} not found")
|
||||
|
||||
workflow_model = await session.get(WorkflowModel, latest_workflow.workflow_id)
|
||||
if workflow_model:
|
||||
workflow_model.folder_id = folder_id
|
||||
workflow_model.modified_at = datetime.utcnow()
|
||||
|
||||
# Update folder's modified_at in the same transaction
|
||||
if folder_id:
|
||||
folder_model = await session.get(FolderModel, folder_id)
|
||||
if folder_model:
|
||||
folder_model.modified_at = datetime.utcnow()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(workflow_model)
|
||||
|
||||
return convert_to_workflow(workflow_model, self.debug_enabled)
|
||||
return None
|
||||
|
|
@@ -1,574 +0,0 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any

from sqlalchemy import and_, delete, select

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
    TaskV2Model,
    ThoughtModel,
    WorkflowRunBlockModel,
)
from skyvern.forge.sdk.db.utils import (
    convert_to_task_v2,
    convert_to_workflow_run_block,
    serialize_proxy_location,
)
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
from skyvern.schemas.runs import ProxyLocationInput, RunEngine
from skyvern.schemas.workflows import BlockStatus, BlockType

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory


class ObserverMixin:
    """Database operations for observer tasks (TaskV2), thoughts, and workflow run blocks."""

    Session: _SessionFactory
    debug_enabled: bool

    # Cross-mixin method stubs (provided by TasksMixin at runtime)
    async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
        raise NotImplementedError

    async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
        raise NotImplementedError

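    # These stubs only satisfy type checkers for the cross-mixin calls below; at
    # runtime the concrete implementations come from TasksMixin when the mixins
    # are composed into AgentDB (the composition order is an assumption here).
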
    @read_retry()
    @db_operation("get_task_v2", log_errors=False)
    async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
        async with self.Session() as session:
            if task_v2 := (
                await session.scalars(
                    select(TaskV2Model)
                    .filter_by(observer_cruise_id=task_v2_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first():
                return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
            return None

    @db_operation("delete_thoughts")
    async def delete_thoughts(self, task_v2_id: str, organization_id: str | None = None) -> None:
        async with self.Session() as session:
            stmt = delete(ThoughtModel).where(
                and_(
                    ThoughtModel.observer_cruise_id == task_v2_id,
                    ThoughtModel.organization_id == organization_id,
                )
            )
            await session.execute(stmt)
            await session.commit()

    @db_operation("get_task_v2_by_workflow_run_id")
    async def get_task_v2_by_workflow_run_id(
        self,
        workflow_run_id: str,
        organization_id: str | None = None,
    ) -> TaskV2 | None:
        async with self.Session() as session:
            if task_v2 := (
                await session.scalars(
                    select(TaskV2Model)
                    .filter_by(organization_id=organization_id)
                    .filter_by(workflow_run_id=workflow_run_id)
                )
            ).first():
                return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
            return None

    @db_operation("get_thought")
    async def get_thought(self, thought_id: str, organization_id: str | None = None) -> Thought | None:
        async with self.Session() as session:
            if thought := (
                await session.scalars(
                    select(ThoughtModel)
                    .filter_by(observer_thought_id=thought_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first():
                return Thought.model_validate(thought)
            return None

    @db_operation("get_thoughts")
    async def get_thoughts(
        self,
        *,
        task_v2_id: str,
        thought_types: list[ThoughtType],
        organization_id: str,
    ) -> list[Thought]:
        async with self.Session() as session:
            query = (
                select(ThoughtModel)
                .filter_by(observer_cruise_id=task_v2_id)
                .filter_by(organization_id=organization_id)
                .order_by(ThoughtModel.created_at)
            )
            if thought_types:
                query = query.filter(ThoughtModel.observer_thought_type.in_(thought_types))
            thoughts = (await session.scalars(query)).all()
            return [Thought.model_validate(thought) for thought in thoughts]

    @db_operation("create_task_v2")
    async def create_task_v2(
        self,
        workflow_run_id: str | None = None,
        workflow_id: str | None = None,
        workflow_permanent_id: str | None = None,
        prompt: str | None = None,
        url: str | None = None,
        organization_id: str | None = None,
        proxy_location: ProxyLocationInput = None,
        totp_identifier: str | None = None,
        totp_verification_url: str | None = None,
        webhook_callback_url: str | None = None,
        extracted_information_schema: dict | list | str | None = None,
        error_code_mapping: dict | None = None,
        model: dict[str, Any] | None = None,
        max_screenshot_scrolling_times: int | None = None,
        extra_http_headers: dict[str, str] | None = None,
        browser_address: str | None = None,
        run_with: str | None = None,
    ) -> TaskV2:
        async with self.Session() as session:
            new_task_v2 = TaskV2Model(
                workflow_run_id=workflow_run_id,
                workflow_id=workflow_id,
                workflow_permanent_id=workflow_permanent_id,
                prompt=prompt,
                url=url,
                proxy_location=serialize_proxy_location(proxy_location),
                totp_identifier=totp_identifier,
                totp_verification_url=totp_verification_url,
                webhook_callback_url=webhook_callback_url,
                extracted_information_schema=extracted_information_schema,
                error_code_mapping=error_code_mapping,
                organization_id=organization_id,
                model=model,
                max_screenshot_scrolling_times=max_screenshot_scrolling_times,
                extra_http_headers=extra_http_headers,
                browser_address=browser_address,
                run_with=run_with,
            )
            session.add(new_task_v2)
            await session.commit()
            await session.refresh(new_task_v2)
            return convert_to_task_v2(new_task_v2, debug_enabled=self.debug_enabled)

    @db_operation("create_thought")
    async def create_thought(
        self,
        task_v2_id: str,
        workflow_run_id: str | None = None,
        workflow_id: str | None = None,
        workflow_permanent_id: str | None = None,
        workflow_run_block_id: str | None = None,
        user_input: str | None = None,
        observation: str | None = None,
        thought: str | None = None,
        answer: str | None = None,
        thought_scenario: str | None = None,
        thought_type: str = ThoughtType.plan,
        output: dict[str, Any] | None = None,
        input_token_count: int | None = None,
        output_token_count: int | None = None,
        reasoning_token_count: int | None = None,
        cached_token_count: int | None = None,
        thought_cost: float | None = None,
        organization_id: str | None = None,
    ) -> Thought:
        async with self.Session() as session:
            new_thought = ThoughtModel(
                observer_cruise_id=task_v2_id,
                workflow_run_id=workflow_run_id,
                workflow_id=workflow_id,
                workflow_permanent_id=workflow_permanent_id,
                workflow_run_block_id=workflow_run_block_id,
                user_input=user_input,
                observation=observation,
                thought=thought,
                answer=answer,
                observer_thought_scenario=thought_scenario,
                observer_thought_type=thought_type,
                output=output,
                input_token_count=input_token_count,
                output_token_count=output_token_count,
                reasoning_token_count=reasoning_token_count,
                cached_token_count=cached_token_count,
                thought_cost=thought_cost,
                organization_id=organization_id,
            )
            session.add(new_thought)
            await session.commit()
            await session.refresh(new_thought)
            return Thought.model_validate(new_thought)

    @db_operation("update_thought")
    async def update_thought(
        self,
        thought_id: str,
        workflow_run_block_id: str | None = None,
        workflow_run_id: str | None = None,
        workflow_id: str | None = None,
        workflow_permanent_id: str | None = None,
        observation: str | None = None,
        thought: str | None = None,
        answer: str | None = None,
        output: dict[str, Any] | None = None,
        input_token_count: int | None = None,
        output_token_count: int | None = None,
        reasoning_token_count: int | None = None,
        cached_token_count: int | None = None,
        thought_cost: float | None = None,
        organization_id: str | None = None,
    ) -> Thought:
        async with self.Session() as session:
            thought_obj = (
                await session.scalars(
                    select(ThoughtModel)
                    .filter_by(observer_thought_id=thought_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            if thought_obj:
                if workflow_run_block_id:
                    thought_obj.workflow_run_block_id = workflow_run_block_id
                if workflow_run_id:
                    thought_obj.workflow_run_id = workflow_run_id
                if workflow_id:
                    thought_obj.workflow_id = workflow_id
                if workflow_permanent_id:
                    thought_obj.workflow_permanent_id = workflow_permanent_id
                if observation:
                    thought_obj.observation = observation
                if thought:
                    thought_obj.thought = thought
                if answer:
                    thought_obj.answer = answer
                if output:
                    thought_obj.output = output
                if input_token_count:
                    thought_obj.input_token_count = input_token_count
                if output_token_count:
                    thought_obj.output_token_count = output_token_count
                if reasoning_token_count:
                    thought_obj.reasoning_token_count = reasoning_token_count
                if cached_token_count:
                    thought_obj.cached_token_count = cached_token_count
                if thought_cost:
                    thought_obj.thought_cost = thought_cost
                await session.commit()
                await session.refresh(thought_obj)
                return Thought.model_validate(thought_obj)
raise NotFoundError(f"Thought {thought_id}")
|
||||

    @db_operation("update_task_v2")
    async def update_task_v2(
        self,
        task_v2_id: str,
        status: TaskV2Status | None = None,
        workflow_run_id: str | None = None,
        workflow_id: str | None = None,
        workflow_permanent_id: str | None = None,
        url: str | None = None,
        prompt: str | None = None,
        summary: str | None = None,
        output: dict[str, Any] | None = None,
        organization_id: str | None = None,
        webhook_failure_reason: str | None = None,
        failure_category: list[dict[str, Any]] | None = None,
    ) -> TaskV2:
        async with self.Session() as session:
            task_v2 = (
                await session.scalars(
                    select(TaskV2Model)
                    .filter_by(observer_cruise_id=task_v2_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            if task_v2:
                if status:
                    task_v2.status = status
                    if status == TaskV2Status.queued and task_v2.queued_at is None:
                        task_v2.queued_at = datetime.utcnow()
                    if status == TaskV2Status.running and task_v2.started_at is None:
                        task_v2.started_at = datetime.utcnow()
                    if status.is_final() and task_v2.finished_at is None:
                        task_v2.finished_at = datetime.utcnow()
                if workflow_run_id:
                    task_v2.workflow_run_id = workflow_run_id
                if workflow_id:
                    task_v2.workflow_id = workflow_id
                if workflow_permanent_id:
                    task_v2.workflow_permanent_id = workflow_permanent_id
                if url:
                    task_v2.url = url
                if prompt:
                    task_v2.prompt = prompt
                if summary:
                    task_v2.summary = summary
                if output:
                    task_v2.output = output
                if webhook_failure_reason is not None:
                    task_v2.webhook_failure_reason = webhook_failure_reason
                if failure_category is not None:
                    task_v2.failure_category = failure_category
                await session.commit()
                await session.refresh(task_v2)
                return convert_to_task_v2(task_v2, debug_enabled=self.debug_enabled)
            raise NotFoundError(f"TaskV2 {task_v2_id} not found")

    @db_operation("create_workflow_run_block")
    async def create_workflow_run_block(
        self,
        workflow_run_id: str,
        parent_workflow_run_block_id: str | None = None,
        organization_id: str | None = None,
        task_id: str | None = None,
        label: str | None = None,
        block_type: BlockType | None = None,
        status: BlockStatus = BlockStatus.running,
        output: dict | list | str | None = None,
        continue_on_failure: bool = False,
        engine: RunEngine | None = None,
        current_value: str | None = None,
        current_index: int | None = None,
    ) -> WorkflowRunBlock:
        async with self.Session() as session:
            new_workflow_run_block = WorkflowRunBlockModel(
                workflow_run_id=workflow_run_id,
                parent_workflow_run_block_id=parent_workflow_run_block_id,
                organization_id=organization_id,
                task_id=task_id,
                label=label,
                block_type=block_type,
                status=status,
                output=output,
                continue_on_failure=continue_on_failure,
                engine=engine,
                current_value=current_value,
                current_index=current_index,
            )
            session.add(new_workflow_run_block)
            await session.commit()
            await session.refresh(new_workflow_run_block)

            task = None
            if task_id:
                task = await self.get_task(task_id, organization_id=organization_id)
            return convert_to_workflow_run_block(new_workflow_run_block, task=task)

    @db_operation("delete_workflow_run_blocks")
    async def delete_workflow_run_blocks(self, workflow_run_id: str, organization_id: str | None = None) -> None:
        async with self.Session() as session:
            stmt = delete(WorkflowRunBlockModel).where(
                and_(
                    WorkflowRunBlockModel.workflow_run_id == workflow_run_id,
                    WorkflowRunBlockModel.organization_id == organization_id,
                )
            )
            await session.execute(stmt)
            await session.commit()

    @db_operation("update_workflow_run_block")
    async def update_workflow_run_block(
        self,
        workflow_run_block_id: str,
        organization_id: str | None = None,
        status: BlockStatus | None = None,
        output: dict | list | str | None = None,
        failure_reason: str | None = None,
        task_id: str | None = None,
        loop_values: list | None = None,
        current_value: str | None = None,
        current_index: int | None = None,
        recipients: list[str] | None = None,
        attachments: list[str] | None = None,
        subject: str | None = None,
        body: str | None = None,
        prompt: str | None = None,
        wait_sec: int | None = None,
        description: str | None = None,
        block_workflow_run_id: str | None = None,
        engine: str | None = None,
        # HTTP request block parameters
        http_request_method: str | None = None,
        http_request_url: str | None = None,
        http_request_headers: dict[str, str] | None = None,
        http_request_body: dict[str, Any] | None = None,
        http_request_parameters: dict[str, Any] | None = None,
        http_request_timeout: int | None = None,
        http_request_follow_redirects: bool | None = None,
        ai_fallback_triggered: bool | None = None,
        # block-level error codes (e.g. ["FILE_PARSER_ERROR"])
        error_codes: list[str] | None = None,
        # human interaction block
        instructions: str | None = None,
        positive_descriptor: str | None = None,
        negative_descriptor: str | None = None,
        # conditional block
        executed_branch_id: str | None = None,
        executed_branch_expression: str | None = None,
        executed_branch_result: bool | None = None,
        executed_branch_next_block: str | None = None,
    ) -> WorkflowRunBlock:
        async with self.Session() as session:
            workflow_run_block = (
                await session.scalars(
                    select(WorkflowRunBlockModel)
                    .filter_by(workflow_run_block_id=workflow_run_block_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            if workflow_run_block:
                if status:
                    workflow_run_block.status = status
                if output:
                    workflow_run_block.output = output
                if task_id:
                    workflow_run_block.task_id = task_id
                if failure_reason:
                    workflow_run_block.failure_reason = failure_reason
                # Use `is not None` instead of truthiness checks so that falsy
                # values like current_index=0, empty loop_values=[], or
                # current_value="" are correctly persisted. Without this,
                # the first loop iteration (index 0) loses its metadata.
                if loop_values is not None:
                    workflow_run_block.loop_values = loop_values
                if current_value is not None:
                    workflow_run_block.current_value = current_value
                if current_index is not None:
                    workflow_run_block.current_index = current_index
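                # e.g. update_workflow_run_block(block_id, current_index=0) must
                # persist 0; a plain `if current_index:` would silently skip it.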
                if recipients:
                    workflow_run_block.recipients = recipients
                if attachments:
                    workflow_run_block.attachments = attachments
                if subject:
                    workflow_run_block.subject = subject
                if body:
                    workflow_run_block.body = body
                if prompt:
                    workflow_run_block.prompt = prompt
                if wait_sec:
                    workflow_run_block.wait_sec = wait_sec
                if description:
                    workflow_run_block.description = description
                if block_workflow_run_id:
                    workflow_run_block.block_workflow_run_id = block_workflow_run_id
                if engine:
                    workflow_run_block.engine = engine
                # HTTP request block fields
                if http_request_method:
                    workflow_run_block.http_request_method = http_request_method
                if http_request_url:
                    workflow_run_block.http_request_url = http_request_url
                if http_request_headers:
                    workflow_run_block.http_request_headers = http_request_headers
                if http_request_body:
                    workflow_run_block.http_request_body = http_request_body
                if http_request_parameters:
                    workflow_run_block.http_request_parameters = http_request_parameters
                if http_request_timeout:
                    workflow_run_block.http_request_timeout = http_request_timeout
                if http_request_follow_redirects is not None:
                    workflow_run_block.http_request_follow_redirects = http_request_follow_redirects
                if ai_fallback_triggered is not None:
                    workflow_run_block.script_run = {"ai_fallback_triggered": ai_fallback_triggered}
                if error_codes is not None:
                    workflow_run_block.error_codes = error_codes
                # human interaction block fields
                if instructions:
                    workflow_run_block.instructions = instructions
                if positive_descriptor:
                    workflow_run_block.positive_descriptor = positive_descriptor
                if negative_descriptor:
                    workflow_run_block.negative_descriptor = negative_descriptor
                # conditional block fields
                if executed_branch_id:
                    workflow_run_block.executed_branch_id = executed_branch_id
                if executed_branch_expression is not None:
                    workflow_run_block.executed_branch_expression = executed_branch_expression
                if executed_branch_result is not None:
                    workflow_run_block.executed_branch_result = executed_branch_result
                if executed_branch_next_block is not None:
                    workflow_run_block.executed_branch_next_block = executed_branch_next_block
                await session.commit()
                await session.refresh(workflow_run_block)
            else:
                raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
            task = None
            task_id = workflow_run_block.task_id
            if task_id:
                task = await self.get_task(task_id, organization_id=workflow_run_block.organization_id)
            return convert_to_workflow_run_block(workflow_run_block, task=task)

    @db_operation("get_workflow_run_block")
    async def get_workflow_run_block(
        self,
        workflow_run_block_id: str,
        organization_id: str | None = None,
    ) -> WorkflowRunBlock:
        async with self.Session() as session:
            workflow_run_block = (
                await session.scalars(
                    select(WorkflowRunBlockModel)
                    .filter_by(workflow_run_block_id=workflow_run_block_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            if workflow_run_block:
                task = None
                task_id = workflow_run_block.task_id
                if task_id:
                    task = await self.get_task(task_id, organization_id=organization_id)
                return convert_to_workflow_run_block(workflow_run_block, task=task)
            raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")

    @db_operation("get_workflow_run_block_by_task_id")
    async def get_workflow_run_block_by_task_id(
        self,
        task_id: str,
        organization_id: str | None = None,
    ) -> WorkflowRunBlock:
        async with self.Session() as session:
            workflow_run_block = (
                await session.scalars(
                    select(WorkflowRunBlockModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
                )
            ).first()
            if workflow_run_block:
                task = None
                task_id = workflow_run_block.task_id
                if task_id:
                    task = await self.get_task(task_id, organization_id=organization_id)
                return convert_to_workflow_run_block(workflow_run_block, task=task)
            raise NotFoundError(f"WorkflowRunBlock not found by {task_id}")

    @db_operation("get_workflow_run_blocks")
    async def get_workflow_run_blocks(
        self,
        workflow_run_id: str,
        organization_id: str | None = None,
    ) -> list[WorkflowRunBlock]:
        async with self.Session() as session:
            workflow_run_blocks = (
                await session.scalars(
                    select(WorkflowRunBlockModel)
                    .filter_by(workflow_run_id=workflow_run_id)
                    .filter_by(organization_id=organization_id)
                    .order_by(WorkflowRunBlockModel.created_at.desc())
                )
            ).all()
            tasks = await self.get_tasks_by_workflow_run_id(workflow_run_id)
            tasks_dict = {task.task_id: task for task in tasks}
            return [
                convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id))
                for workflow_run_block in workflow_run_blocks
            ]
@@ -1,380 +0,0 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Literal, overload

from sqlalchemy import select, update

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
    OrganizationAuthTokenModel,
    OrganizationModel,
    TaskModel,
    WorkflowRunModel,
)
from skyvern.forge.sdk.db.utils import (
    convert_to_organization,
    convert_to_organization_auth_token,
)
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.schemas.organizations import (
    AzureClientSecretCredential,
    AzureOrganizationAuthToken,
    BitwardenCredential,
    BitwardenOrganizationAuthToken,
    Organization,
    OrganizationAuthToken,
)
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory


class OrganizationsMixin:
    """Database operations for organization and auth-token management."""

    Session: _SessionFactory

    @read_retry()
    @db_operation("get_active_verification_requests", log_errors=False)
    async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
        """Return active 2FA verification requests for an organization.

        Queries both tasks and workflow runs where waiting_for_verification_code=True.
        Used to provide initial state when a WebSocket notification client connects.
        """
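        # Each item mirrors the dicts built below, e.g. (values hypothetical):
        #   {"task_id": "tsk_123", "workflow_run_id": None,
        #    "verification_code_identifier": "user@example.com",
        #    "verification_code_polling_started_at": "2025-01-01T00:00:00"}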
        results: list[dict] = []
        async with self.Session() as session:
            # Tasks waiting for verification (exclude finalized tasks)
            finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
            task_rows = (
                await session.scalars(
                    select(TaskModel)
                    .filter_by(organization_id=organization_id)
                    .filter_by(waiting_for_verification_code=True)
                    .filter_by(workflow_run_id=None)
                    .filter(TaskModel.status.not_in(finalized_task_statuses))
                    .filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
                )
            ).all()
            for t in task_rows:
                results.append(
                    {
                        "task_id": t.task_id,
                        "workflow_run_id": None,
                        "verification_code_identifier": t.verification_code_identifier,
                        "verification_code_polling_started_at": (
                            t.verification_code_polling_started_at.isoformat()
                            if t.verification_code_polling_started_at
                            else None
                        ),
                    }
                )
            # Workflow runs waiting for verification (exclude finalized runs)
            finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
            wr_rows = (
                await session.scalars(
                    select(WorkflowRunModel)
                    .filter_by(organization_id=organization_id)
                    .filter_by(waiting_for_verification_code=True)
                    .filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
                    .filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
                )
            ).all()
            for wr in wr_rows:
                results.append(
                    {
                        "task_id": None,
                        "workflow_run_id": wr.workflow_run_id,
                        "verification_code_identifier": wr.verification_code_identifier,
                        "verification_code_polling_started_at": (
                            wr.verification_code_polling_started_at.isoformat()
                            if wr.verification_code_polling_started_at
                            else None
                        ),
                    }
                )
        return results

    @db_operation("get_all_organizations")
    async def get_all_organizations(self) -> list[Organization]:
        async with self.Session() as session:
            organizations = (await session.scalars(select(OrganizationModel))).all()
            return [convert_to_organization(organization) for organization in organizations]

    @db_operation("get_organization")
    async def get_organization(self, organization_id: str) -> Organization | None:
        async with self.Session() as session:
            if organization := (
                await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
            ).first():
                return convert_to_organization(organization)
            else:
                return None

    @db_operation("get_organization_by_domain")
    async def get_organization_by_domain(self, domain: str) -> Organization | None:
        async with self.Session() as session:
            if organization := (await session.scalars(select(OrganizationModel).filter_by(domain=domain))).first():
                return convert_to_organization(organization)
            return None

    @db_operation("create_organization")
    async def create_organization(
        self,
        organization_name: str,
        webhook_callback_url: str | None = None,
        max_steps_per_run: int | None = None,
        max_retries_per_step: int | None = None,
        domain: str | None = None,
        organization_id: str | None = None,
    ) -> Organization:
        async with self.Session() as session:
            org = OrganizationModel(
                organization_id=organization_id,
                organization_name=organization_name,
                webhook_callback_url=webhook_callback_url,
                max_steps_per_run=max_steps_per_run,
                max_retries_per_step=max_retries_per_step,
                domain=domain,
            )
            session.add(org)
            await session.commit()
            await session.refresh(org)

        return convert_to_organization(org)

    @db_operation("update_organization")
    async def update_organization(
        self,
        organization_id: str,
        organization_name: str | None = None,
        webhook_callback_url: str | None = None,
        max_steps_per_run: int | None = None,
        max_retries_per_step: int | None = None,
    ) -> Organization:
        async with self.Session() as session:
            organization = (
                await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
            ).first()
            if not organization:
                raise NotFoundError(f"Organization {organization_id} not found")
            if organization_name:
                organization.organization_name = organization_name
            if webhook_callback_url:
                organization.webhook_callback_url = webhook_callback_url
            if max_steps_per_run:
                organization.max_steps_per_run = max_steps_per_run
            if max_retries_per_step:
                organization.max_retries_per_step = max_retries_per_step
            await session.commit()
            await session.refresh(organization)
            return Organization.model_validate(organization)

    @overload
    async def get_valid_org_auth_token(
        self,
        organization_id: str,
        token_type: Literal["api", "onepassword_service_account", "custom_credential_service"],
    ) -> OrganizationAuthToken | None: ...

    @overload
    async def get_valid_org_auth_token(  # type: ignore
        self,
        organization_id: str,
        token_type: Literal["azure_client_secret_credential"],
    ) -> AzureOrganizationAuthToken | None: ...

    @overload
    async def get_valid_org_auth_token(  # type: ignore
        self,
        organization_id: str,
        token_type: Literal["bitwarden_credential"],
    ) -> BitwardenOrganizationAuthToken | None: ...

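    # The overloads above narrow the return type by the token_type literal, e.g.
    # (hypothetical caller):
    #   token = await db.get_valid_org_auth_token(org_id, "azure_client_secret_credential")
    #   # type checkers infer AzureOrganizationAuthToken | None here
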
@db_operation("get_valid_org_auth_token")
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[
|
||||
"api",
|
||||
"onepassword_service_account",
|
||||
"azure_client_secret_credential",
|
||||
"bitwarden_credential",
|
||||
"custom_credential_service",
|
||||
],
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken | None:
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return await convert_to_organization_auth_token(token, token_type)
|
||||
else:
|
||||
return None
|
||||
|
||||
@db_operation("get_valid_org_auth_tokens")
|
||||
async def get_valid_org_auth_tokens(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> list[OrganizationAuthToken]:
|
||||
async with self.Session() as session:
|
||||
tokens = (
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
||||
|
||||
@db_operation("validate_org_auth_token")
|
||||
async def validate_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
valid: bool | None = True,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken | None:
|
||||
encrypted_token = ""
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
)
|
||||
if encrypted_token:
|
||||
query = query.filter_by(encrypted_token=encrypted_token)
|
||||
else:
|
||||
query = query.filter_by(token=token)
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return await convert_to_organization_auth_token(token_obj, token_type)
|
||||
else:
|
||||
return None
|
||||
|
||||
@db_operation("create_org_auth_token")
|
||||
async def create_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str | AzureClientSecretCredential | BitwardenCredential,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
|
||||
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
if not isinstance(token, AzureClientSecretCredential):
|
||||
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
elif token_type is OrganizationAuthTokenType.bitwarden_credential:
|
||||
if not isinstance(token, BitwardenCredential):
|
||||
raise TypeError("Expected BitwardenCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
else:
|
||||
if not isinstance(token, str):
|
||||
raise TypeError("Expected str token for this token_type")
|
||||
plaintext_token = token
|
||||
|
||||
encrypted_token = ""
|
||||
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
||||
plaintext_token = ""
|
||||
|
||||
async with self.Session() as session:
|
||||
auth_token = OrganizationAuthTokenModel(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
token=plaintext_token,
|
||||
encrypted_token=encrypted_token,
|
||||
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
|
||||
)
|
||||
session.add(auth_token)
|
||||
await session.commit()
|
||||
await session.refresh(auth_token)
|
||||
|
||||
return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
||||
@db_operation("invalidate_org_auth_tokens")
|
||||
async def invalidate_org_auth_tokens(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> None:
|
||||
"""Invalidate all existing tokens of a specific type for an organization."""
|
||||
async with self.Session() as session:
|
||||
await session.execute(
|
||||
update(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.values(valid=False)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("replace_org_auth_token")
|
||||
async def replace_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str | AzureClientSecretCredential | BitwardenCredential,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
|
||||
"""Atomically invalidate existing tokens and create a new one in a single transaction."""
|
||||
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
if not isinstance(token, AzureClientSecretCredential):
|
||||
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
elif token_type is OrganizationAuthTokenType.bitwarden_credential:
|
||||
if not isinstance(token, BitwardenCredential):
|
||||
raise TypeError("Expected BitwardenCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
else:
|
||||
if not isinstance(token, str):
|
||||
raise TypeError("Expected str token for this token_type")
|
||||
plaintext_token = token
|
||||
|
||||
encrypted_token = ""
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
||||
plaintext_token = ""
|
||||
|
||||
async with self.Session() as session:
|
||||
# Invalidate existing tokens
|
||||
await session.execute(
|
||||
update(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(valid=True)
|
||||
.values(valid=False)
|
||||
)
|
||||
# Create new token
|
||||
auth_token = OrganizationAuthTokenModel(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
token=plaintext_token,
|
||||
encrypted_token=encrypted_token,
|
||||
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
|
||||
)
|
||||
session.add(auth_token)
|
||||
await session.commit()
|
||||
await session.refresh(auth_token)
|
||||
|
||||
return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
|
@@ -1,155 +0,0 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING

from sqlalchemy import and_, asc, select

from skyvern.config import settings
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import TOTPCodeModel
from skyvern.forge.sdk.schemas.totp_codes import OTPType, TOTPCode

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory


class OTPMixin:
    """Database operations for OTP/TOTP management."""

    Session: _SessionFactory

    @db_operation("get_otp_codes")
    async def get_otp_codes(
        self,
        organization_id: str,
        totp_identifier: str,
        valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
        otp_type: OTPType | None = None,
        workflow_run_id: str | None = None,
        limit: int | None = None,
    ) -> list[TOTPCode]:
        """
        1. filter by:
            - organization_id
            - totp_identifier
            - workflow_run_id (optional)
        2. make sure created_at is within the valid lifespan
        3. sort by task_id/workflow_id/workflow_run_id nullslast and created_at desc
        4. apply an optional limit at the DB layer
        """
        all_null = and_(
            TOTPCodeModel.task_id.is_(None),
            TOTPCodeModel.workflow_id.is_(None),
            TOTPCodeModel.workflow_run_id.is_(None),
        )
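        # asc(all_null) sorts False (0) before True (1), so codes tied to a
        # task/workflow/workflow_run rank ahead of fully unscoped codes; this is
        # the "nullslast" behavior described in the docstring.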
        async with self.Session() as session:
            query = (
                select(TOTPCodeModel)
                .filter_by(organization_id=organization_id)
                .filter_by(totp_identifier=totp_identifier)
                .filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
            )
            if otp_type:
                query = query.filter(TOTPCodeModel.otp_type == otp_type)
            if workflow_run_id is not None:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            query = query.order_by(asc(all_null), TOTPCodeModel.created_at.desc())
            if limit is not None:
                query = query.limit(limit)
            totp_codes = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(code) for code in totp_codes]

    @db_operation("get_otp_codes_by_run")
    async def get_otp_codes_by_run(
        self,
        organization_id: str,
        task_id: str | None = None,
        workflow_run_id: str | None = None,
        valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
        limit: int = 1,
    ) -> list[TOTPCode]:
        """Get OTP codes matching a specific task or workflow run (no totp_identifier required).

        Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
        The user submits codes manually via the UI, and this method finds them by run context.
        """
        if not workflow_run_id and not task_id:
            return []
        async with self.Session() as session:
            query = (
                select(TOTPCodeModel)
                .filter_by(organization_id=organization_id)
                .filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
            )
            if workflow_run_id:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            elif task_id:
                query = query.filter(TOTPCodeModel.task_id == task_id)
            query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
            results = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(r) for r in results]

    @db_operation("get_recent_otp_codes")
    async def get_recent_otp_codes(
        self,
        organization_id: str,
        limit: int = 50,
        valid_lifespan_minutes: int | None = None,
        otp_type: OTPType | None = None,
        workflow_run_id: str | None = None,
        totp_identifier: str | None = None,
    ) -> list[TOTPCode]:
        """
        Return recent OTP codes for an organization ordered by newest first, with optional
        workflow_run_id filtering.
        """
        async with self.Session() as session:
            query = select(TOTPCodeModel).filter_by(organization_id=organization_id)

            if valid_lifespan_minutes is not None:
                query = query.filter(
                    TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes)
                )

            if otp_type:
                query = query.filter(TOTPCodeModel.otp_type == otp_type)
            if workflow_run_id is not None:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            if totp_identifier:
                query = query.filter(TOTPCodeModel.totp_identifier == totp_identifier)
            query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
            totp_codes = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(totp_code) for totp_code in totp_codes]

    @db_operation("create_otp_code")
    async def create_otp_code(
        self,
        organization_id: str,
        totp_identifier: str,
        content: str,
        code: str,
        otp_type: OTPType,
        task_id: str | None = None,
        workflow_id: str | None = None,
        workflow_run_id: str | None = None,
        source: str | None = None,
        expired_at: datetime | None = None,
    ) -> TOTPCode:
        async with self.Session() as session:
            new_totp_code = TOTPCodeModel(
                organization_id=organization_id,
                totp_identifier=totp_identifier,
                content=content,
                code=code,
                task_id=task_id,
                workflow_id=workflow_id,
                workflow_run_id=workflow_run_id,
                source=source,
                expired_at=expired_at,
                otp_type=otp_type,
            )
            session.add(new_totp_code)
            await session.commit()
            await session.refresh(new_totp_code)
            return TOTPCode.model_validate(new_totp_code)
@@ -1,654 +0,0 @@
from __future__ import annotations

import asyncio
from datetime import datetime
from typing import TYPE_CHECKING, Any

import structlog
from sqlalchemy import func, or_, select, text, update

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import ScheduleLimitExceededError
from skyvern.forge.sdk.db.models import (
    WorkflowModel,
    WorkflowRunModel,
    WorkflowScheduleModel,
)
from skyvern.forge.sdk.db.utils import convert_to_workflow_schedule
from skyvern.forge.sdk.schemas.workflow_schedules import OrganizationScheduleItem, WorkflowSchedule
from skyvern.forge.sdk.workflow.schedules import compute_next_run

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncEngine

    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory

from skyvern.forge.sdk.db._sentinels import _UNSET

LOG = structlog.get_logger()


class SchedulesMixin:
    """Database operations for workflow schedules.

    .. deprecated::
        This mixin is part of the legacy database layer. New code should use the
        repository classes in ``skyvern.forge.sdk.db.repositories`` instead.

    Cross-mixin migrations already completed:
    - ``soft_delete_workflow_and_schedules_by_permanent_id`` → ``WorkflowsRepository``
      (operates on workflows as the primary entity, schedules are a side-effect).
    """

    Session: _SessionFactory
    engine: AsyncEngine
    debug_enabled: bool
    _sqlite_schedule_lock: asyncio.Lock | None

    @db_operation("create_workflow_schedule")
    async def create_workflow_schedule(
        self,
        organization_id: str,
        workflow_permanent_id: str,
        cron_expression: str,
        timezone: str,
        enabled: bool,
        parameters: dict[str, Any] | None = None,
        temporal_schedule_id: str | None = None,
        name: str | None = None,
        description: str | None = None,
    ) -> WorkflowSchedule:
        async with self.Session() as session:
            workflow_schedule = WorkflowScheduleModel(
                organization_id=organization_id,
                workflow_permanent_id=workflow_permanent_id,
                cron_expression=cron_expression,
                timezone=timezone,
                enabled=enabled,
                parameters=parameters,
                temporal_schedule_id=temporal_schedule_id,
                name=name,
                description=description,
            )
            session.add(workflow_schedule)
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("create_workflow_schedule_with_limit")
    async def create_workflow_schedule_with_limit(
        self,
        organization_id: str,
        workflow_permanent_id: str,
        max_schedules: int | None,
        cron_expression: str,
        timezone: str,
        enabled: bool,
        parameters: dict[str, Any] | None = None,
        name: str | None = None,
        description: str | None = None,
    ) -> tuple[WorkflowSchedule, int]:
        """Create a schedule atomically with limit enforcement.

        On PostgreSQL, uses an advisory lock to serialize concurrent creates for
        the same workflow, preventing TOCTOU races on the schedule count.

        On SQLite, uses an asyncio.Lock (set on AgentDB.__init__) since SQLite
        is single-writer and has no advisory lock support.

        Returns (created_schedule, count_before_insert).
        Raises ScheduleLimitExceededError if count >= max_schedules.
        """
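        # Caller sketch (hypothetical):
        #   try:
        #       schedule, prior_count = await db.create_workflow_schedule_with_limit(...)
        #   except ScheduleLimitExceededError:
        #       ...  # surface the limit to the user instead of creating the schedule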
        # SQLite: serialize via Python lock (no advisory locks available).
        # The lock is held across the count-check + insert to prevent TOCTOU.
        sqlite_lock = getattr(self, "_sqlite_schedule_lock", None)
        if sqlite_lock is not None:
            async with sqlite_lock:
                return await self._create_schedule_with_limit_inner(
                    organization_id,
                    workflow_permanent_id,
                    max_schedules,
                    cron_expression,
                    timezone,
                    enabled,
                    parameters,
                    name,
                    description,
                    use_advisory_lock=False,
                )
        return await self._create_schedule_with_limit_inner(
            organization_id,
            workflow_permanent_id,
            max_schedules,
            cron_expression,
            timezone,
            enabled,
            parameters,
            name,
            description,
            use_advisory_lock=True,
        )

    # Intentionally not decorated with @db_operation — errors are caught by the
    # outer create_workflow_schedule_with_limit which owns the operation name.
    async def _create_schedule_with_limit_inner(
        self,
        organization_id: str,
        workflow_permanent_id: str,
        max_schedules: int | None,
        cron_expression: str,
        timezone: str,
        enabled: bool,
        parameters: dict[str, Any] | None,
        name: str | None,
        description: str | None,
        *,
        use_advisory_lock: bool,
    ) -> tuple[WorkflowSchedule, int]:
        async with self.Session() as session:
            if use_advisory_lock:
                lock_key = f"schedule:{organization_id}:{workflow_permanent_id}"
                await session.execute(
                    text("SELECT pg_advisory_xact_lock(hashtext(:key))"),
                    {"key": lock_key},
                )

            count = (
                await session.execute(
                    select(func.count()).where(
                        WorkflowScheduleModel.organization_id == organization_id,
                        WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id,
                        WorkflowScheduleModel.deleted_at.is_(None),
                    )
                )
            ).scalar_one()

            if max_schedules is not None and count >= max_schedules:
                raise ScheduleLimitExceededError(
                    organization_id=organization_id,
                    workflow_permanent_id=workflow_permanent_id,
                    current_count=count,
                    max_allowed=max_schedules,
                )

            workflow_schedule = WorkflowScheduleModel(
                organization_id=organization_id,
                workflow_permanent_id=workflow_permanent_id,
                cron_expression=cron_expression,
                timezone=timezone,
                enabled=enabled,
                parameters=parameters,
                name=name,
                description=description,
            )
            session.add(workflow_schedule)
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled), count

    @db_operation("set_temporal_schedule_id")
    async def set_temporal_schedule_id(
        self,
        workflow_schedule_id: str,
        organization_id: str,
        temporal_schedule_id: str,
    ) -> WorkflowSchedule | None:
        async with self.Session() as session:
            workflow_schedule = (
                await session.scalars(
                    select(WorkflowScheduleModel).filter_by(
                        workflow_schedule_id=workflow_schedule_id,
                        organization_id=organization_id,
                        deleted_at=None,
                    )
                )
            ).first()

            if not workflow_schedule:
                return None

            workflow_schedule.temporal_schedule_id = temporal_schedule_id
            workflow_schedule.modified_at = datetime.utcnow()
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("update_workflow_schedule")
    async def update_workflow_schedule(
        self,
        workflow_schedule_id: str,
        organization_id: str,
        cron_expression: str,
        timezone: str,
        enabled: bool,
        parameters: dict[str, Any] | None = None,
        temporal_schedule_id: str | None | object = _UNSET,
        name: str | None | object = _UNSET,
        description: str | None | object = _UNSET,
    ) -> WorkflowSchedule | None:
        async with self.Session() as session:
            workflow_schedule = (
                await session.scalars(
                    select(WorkflowScheduleModel).filter_by(
                        workflow_schedule_id=workflow_schedule_id,
                        organization_id=organization_id,
                        deleted_at=None,
                    )
                )
            ).first()

            if not workflow_schedule:
                return None

            workflow_schedule.cron_expression = cron_expression
            workflow_schedule.timezone = timezone
            workflow_schedule.enabled = enabled
            workflow_schedule.parameters = parameters
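            # _UNSET sentinel: distinguishes "argument not provided" (leave the
            # column untouched) from an explicit None (clear the column).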
            if temporal_schedule_id is not _UNSET:
                workflow_schedule.temporal_schedule_id = temporal_schedule_id
            if name is not _UNSET:
                workflow_schedule.name = name
            if description is not _UNSET:
                workflow_schedule.description = description
            workflow_schedule.modified_at = datetime.utcnow()
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("get_workflow_schedule_by_id")
    async def get_workflow_schedule_by_id(
        self,
        workflow_schedule_id: str,
        organization_id: str,
    ) -> WorkflowSchedule | None:
        async with self.Session() as session:
            workflow_schedule = (
                await session.scalars(
                    select(WorkflowScheduleModel).filter_by(
                        workflow_schedule_id=workflow_schedule_id,
                        organization_id=organization_id,
                        deleted_at=None,
                    )
                )
            ).first()
            if not workflow_schedule:
                return None
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("get_workflow_schedules")
    async def get_workflow_schedules(
        self,
        workflow_permanent_id: str,
        organization_id: str,
    ) -> list[WorkflowSchedule]:
        async with self.Session() as session:
            rows = (
                await session.scalars(
                    select(WorkflowScheduleModel).filter_by(
                        workflow_permanent_id=workflow_permanent_id,
                        organization_id=organization_id,
                        deleted_at=None,
                    )
                )
            ).all()
            return [convert_to_workflow_schedule(r, self.debug_enabled) for r in rows]

    @db_operation("get_all_enabled_schedules")
    async def get_all_enabled_schedules(
        self,
        organization_id: str | None = None,
    ) -> list[WorkflowSchedule]:
        """Fetch all enabled, non-deleted schedules, optionally filtered by org."""
        async with self.Session() as session:
            stmt = select(WorkflowScheduleModel).where(
                WorkflowScheduleModel.enabled.is_(True),
                WorkflowScheduleModel.deleted_at.is_(None),
            )
            if organization_id:
                stmt = stmt.where(WorkflowScheduleModel.organization_id == organization_id)
            rows = (await session.scalars(stmt)).all()
            return [convert_to_workflow_schedule(r, self.debug_enabled) for r in rows]

    @db_operation("has_schedule_fired_since")
    async def has_schedule_fired_since(
        self,
        workflow_schedule_id: str,
        since: datetime,
    ) -> bool:
        """Check if a workflow_run exists for the given schedule since a timestamp."""
        from sqlalchemy import exists as sa_exists

        async with self.Session() as session:
            row = (
                await session.execute(
                    select(
                        sa_exists().where(
                            WorkflowRunModel.workflow_schedule_id == workflow_schedule_id,
                            WorkflowRunModel.created_at >= since,
                        )
                    )
                )
            ).scalar()
            return bool(row)

    @db_operation("update_workflow_schedule_enabled")
    async def update_workflow_schedule_enabled(
        self,
        workflow_schedule_id: str,
        organization_id: str,
        enabled: bool,
    ) -> WorkflowSchedule | None:
        async with self.Session() as session:
            workflow_schedule = (
                await session.scalars(
                    select(WorkflowScheduleModel).filter_by(
                        workflow_schedule_id=workflow_schedule_id,
                        organization_id=organization_id,
                        deleted_at=None,
                    )
                )
            ).first()
            if not workflow_schedule:
                return None
            workflow_schedule.enabled = enabled
            workflow_schedule.modified_at = datetime.utcnow()
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("delete_workflow_schedule")
    async def delete_workflow_schedule(
        self,
        workflow_schedule_id: str,
        organization_id: str,
    ) -> WorkflowSchedule | None:
        async with self.Session() as session:
            workflow_schedule = (
                await session.scalars(
                    select(WorkflowScheduleModel).filter_by(
                        workflow_schedule_id=workflow_schedule_id,
                        organization_id=organization_id,
                        deleted_at=None,
                    )
                )
            ).first()
            if not workflow_schedule:
                return None

            workflow_schedule.deleted_at = datetime.utcnow()
            workflow_schedule.modified_at = datetime.utcnow()
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("restore_workflow_schedule")
    async def restore_workflow_schedule(
        self,
        workflow_schedule_id: str,
        organization_id: str,
    ) -> WorkflowSchedule | None:
        async with self.Session() as session:
            workflow_schedule = (
                await session.scalars(
                    select(WorkflowScheduleModel)
                    .filter_by(
                        workflow_schedule_id=workflow_schedule_id,
                        organization_id=organization_id,
                    )
                    .filter(WorkflowScheduleModel.deleted_at.isnot(None))
                )
            ).first()
            if not workflow_schedule:
                return None

            workflow_schedule.deleted_at = None
            workflow_schedule.modified_at = datetime.utcnow()
            await session.commit()
            await session.refresh(workflow_schedule)
            return convert_to_workflow_schedule(workflow_schedule, self.debug_enabled)

    @db_operation("count_workflow_schedules")
    async def count_workflow_schedules(
        self,
        organization_id: str,
        workflow_permanent_id: str,
    ) -> int:
        async with self.Session() as session:
            result = await session.execute(
                select(func.count()).where(
                    WorkflowScheduleModel.organization_id == organization_id,
                    WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id,
                    WorkflowScheduleModel.deleted_at.is_(None),
                )
            )
            return result.scalar_one()

    @db_operation("list_organization_schedules")
    async def list_organization_schedules(
        self,
        organization_id: str,
        page: int = 1,
        page_size: int = 10,
        enabled_filter: bool | None = None,
        search: str | None = None,
    ) -> tuple[list[OrganizationScheduleItem], int]:
        """
        List all schedules for an organization, joined with workflow titles.
        Returns (schedules, total_count).
        """
        if page < 1:
            raise ValueError(f"Page must be greater than 0, got {page}")
        db_page = page - 1
        async with self.Session() as session:
            # Subquery to get the latest version title per workflow_permanent_id
            latest_version_sq = (
                select(
                    WorkflowModel.workflow_permanent_id,
                    func.max(WorkflowModel.version).label("max_version"),
                )
                .where(WorkflowModel.organization_id == organization_id)
                .where(WorkflowModel.deleted_at.is_(None))
                .group_by(WorkflowModel.workflow_permanent_id)
                .subquery()
            )
workflow_title_sq = (
|
||||
select(
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
WorkflowModel.title,
|
||||
)
|
||||
.join(
|
||||
latest_version_sq,
|
||||
(WorkflowModel.workflow_permanent_id == latest_version_sq.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == latest_version_sq.c.max_version),
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Base query: schedules joined with workflow titles
|
||||
base_filter = (
|
||||
select(WorkflowScheduleModel, workflow_title_sq.c.title.label("workflow_title"))
|
||||
.outerjoin(
|
||||
workflow_title_sq,
|
||||
WorkflowScheduleModel.workflow_permanent_id == workflow_title_sq.c.workflow_permanent_id,
|
||||
)
|
||||
.where(WorkflowScheduleModel.organization_id == organization_id)
|
||||
.where(WorkflowScheduleModel.deleted_at.is_(None))
|
||||
)
|
||||
|
||||
if enabled_filter is not None:
|
||||
base_filter = base_filter.where(WorkflowScheduleModel.enabled == enabled_filter)
|
||||
|
||||
if search:
|
||||
base_filter = base_filter.where(
|
||||
or_(
|
||||
workflow_title_sq.c.title.icontains(search, autoescape=True),
|
||||
WorkflowScheduleModel.name.icontains(search, autoescape=True),
|
||||
)
|
||||
)
|
||||
|
||||
# Count query
|
||||
count_query = select(func.count()).select_from(base_filter.subquery())
|
||||
total_count = (await session.execute(count_query)).scalar_one()
|
||||
|
||||
# Data query with pagination
|
||||
data_query = (
|
||||
base_filter.order_by(WorkflowScheduleModel.created_at.desc())
|
||||
.limit(page_size)
|
||||
.offset(db_page * page_size)
|
||||
)
|
||||
rows = (await session.execute(data_query)).all()
|
||||
|
||||
# Materialize row data while session is open
|
||||
raw_schedules = []
|
||||
for row in rows:
|
||||
schedule_model = row[0]
|
||||
raw_schedules.append(
|
||||
(
|
||||
schedule_model.workflow_schedule_id,
|
||||
schedule_model.organization_id,
|
||||
schedule_model.workflow_permanent_id,
|
||||
row[1] or "Untitled Workflow",
|
||||
schedule_model.cron_expression,
|
||||
schedule_model.timezone,
|
||||
schedule_model.enabled,
|
||||
schedule_model.parameters,
|
||||
schedule_model.name,
|
||||
schedule_model.description,
|
||||
schedule_model.created_at,
|
||||
schedule_model.modified_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Compute next_run outside session scope (pure CPU, no DB needed)
|
||||
schedules: list[OrganizationScheduleItem] = []
|
||||
for (
|
||||
ws_id,
|
||||
org_id,
|
||||
wpid,
|
||||
title,
|
||||
cron_expr,
|
||||
tz,
|
||||
enabled,
|
||||
params,
|
||||
name,
|
||||
description,
|
||||
created,
|
||||
modified,
|
||||
) in raw_schedules:
|
||||
next_run = None
|
||||
if enabled:
|
||||
try:
|
||||
next_run = compute_next_run(cron_expr, tz)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to compute next_run for schedule",
|
||||
workflow_schedule_id=ws_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
schedules.append(
|
||||
OrganizationScheduleItem(
|
||||
workflow_schedule_id=ws_id,
|
||||
organization_id=org_id,
|
||||
workflow_permanent_id=wpid,
|
||||
workflow_title=title,
|
||||
cron_expression=cron_expr,
|
||||
timezone=tz,
|
||||
enabled=enabled,
|
||||
parameters=params,
|
||||
name=name,
|
||||
description=description,
|
||||
next_run=next_run,
|
||||
created_at=created,
|
||||
modified_at=modified,
|
||||
)
|
||||
)
|
||||
|
||||
return schedules, total_count
|
||||
|
||||
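    # Note on the title join in list_organization_schedules: latest_version_sq
    # is the classic greatest-n-per-group pattern, roughly this SQL (table and
    # column names illustrative; the real DDL lives in the models module):
    #
    #     SELECT w.workflow_permanent_id, w.title
    #     FROM workflows w
    #     JOIN (
    #         SELECT workflow_permanent_id, MAX(version) AS max_version
    #         FROM workflows
    #         GROUP BY workflow_permanent_id
    #     ) latest ON w.workflow_permanent_id = latest.workflow_permanent_id
    #             AND w.version = latest.max_version
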
@db_operation("soft_delete_workflow_and_schedules_by_permanent_id")
|
||||
async def soft_delete_workflow_and_schedules_by_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Soft-delete a workflow and its active schedules in a single DB transaction.
|
||||
|
||||
.. deprecated::
|
||||
Moved to ``WorkflowsRepository.soft_delete_workflow_and_schedules_by_permanent_id``
|
||||
(skyvern/forge/sdk/db/repositories/workflows.py). The primary entity is the
|
||||
workflow, not the schedule, so it belongs in the workflows repository.
|
||||
This copy remains for backward compatibility with the legacy mixin layer.
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
select_query = (
|
||||
select(WorkflowScheduleModel.workflow_schedule_id)
|
||||
.where(WorkflowScheduleModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowScheduleModel.deleted_at.is_(None))
|
||||
)
|
||||
if organization_id is not None:
|
||||
select_query = select_query.where(WorkflowScheduleModel.organization_id == organization_id)
|
||||
result = await session.execute(select_query)
|
||||
schedule_ids = list(result.scalars().all())
|
||||
|
||||
deleted_at = datetime.utcnow()
|
||||
if schedule_ids:
|
||||
update_schedules_query = (
|
||||
update(WorkflowScheduleModel)
|
||||
.where(WorkflowScheduleModel.workflow_schedule_id.in_(schedule_ids))
|
||||
.values(deleted_at=deleted_at)
|
||||
)
|
||||
await session.execute(update_schedules_query)
|
||||
|
||||
update_workflow_query = (
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
)
|
||||
if organization_id is not None:
|
||||
update_workflow_query = update_workflow_query.filter_by(organization_id=organization_id)
|
||||
await session.execute(update_workflow_query.values(deleted_at=deleted_at))
|
||||
await session.commit()
|
||||
return schedule_ids
|
||||
|
||||
@db_operation("soft_delete_orphaned_schedules")
|
||||
async def soft_delete_orphaned_schedules(self, limit: int = 500) -> list[tuple[str, str]]:
|
||||
"""Soft-delete orphaned schedules and return their identities.
|
||||
|
||||
Uses a single UPDATE ... RETURNING statement so orphan detection and
|
||||
soft-deletion happen atomically in one DB round-trip.
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
active_workflow_exists = (
|
||||
select(WorkflowModel.workflow_permanent_id)
|
||||
.where(WorkflowModel.workflow_permanent_id == WorkflowScheduleModel.workflow_permanent_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.correlate(WorkflowScheduleModel)
|
||||
.exists()
|
||||
)
|
||||
orphaned_schedules = (
|
||||
select(
|
||||
WorkflowScheduleModel.workflow_schedule_id.label("workflow_schedule_id"),
|
||||
WorkflowScheduleModel.workflow_permanent_id.label("workflow_permanent_id"),
|
||||
)
|
||||
.where(WorkflowScheduleModel.deleted_at.is_(None))
|
||||
.where(~active_workflow_exists)
|
||||
.limit(limit)
|
||||
.cte("orphaned_schedules")
|
||||
)
|
||||
update_query = (
|
||||
update(WorkflowScheduleModel)
|
||||
.where(
|
||||
WorkflowScheduleModel.workflow_schedule_id.in_(select(orphaned_schedules.c.workflow_schedule_id))
|
||||
)
|
||||
.where(WorkflowScheduleModel.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.utcnow())
|
||||
.returning(
|
||||
WorkflowScheduleModel.workflow_schedule_id,
|
||||
WorkflowScheduleModel.workflow_permanent_id,
|
||||
)
|
||||
)
|
||||
result = await session.execute(update_query)
|
||||
await session.commit()
|
||||
return [(row[0], row[1]) for row in result.all()]
|
||||
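
    # The statement above compiles to roughly the following (PostgreSQL
    # dialect, illustrative):
    #
    #     WITH orphaned_schedules AS (
    #         SELECT workflow_schedule_id, workflow_permanent_id
    #         FROM workflow_schedules ws
    #         WHERE ws.deleted_at IS NULL
    #           AND NOT EXISTS (
    #               SELECT 1 FROM workflows w
    #               WHERE w.workflow_permanent_id = ws.workflow_permanent_id
    #                 AND w.deleted_at IS NULL
    #           )
    #         LIMIT :limit
    #     )
    #     UPDATE workflow_schedules
    #     SET deleted_at = :deleted_at
    #     WHERE workflow_schedule_id IN (
    #         SELECT workflow_schedule_id FROM orphaned_schedules
    #     ) AND deleted_at IS NULL
    #     RETURNING workflow_schedule_id, workflow_permanent_id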
File diff suppressed because it is too large
@ -1,935 +0,0 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Sequence

import structlog
from sqlalchemy import and_, delete, func, select, update

from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
    ActionModel,
    StepModel,
    TaskModel,
    TaskRunModel,
    WorkflowRunModel,
)
from skyvern.forge.sdk.db.utils import convert_to_step, convert_to_task, hydrate_action, serialize_proxy_location
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
from skyvern.schemas.runs import ProxyLocationInput, RunStatus, RunType
from skyvern.schemas.steps import AgentStepOutput
from skyvern.webeye.actions.actions import Action

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory

LOG = structlog.get_logger()


class TasksMixin:
    Session: _SessionFactory
    debug_enabled: bool

    @db_operation("create_task")
    async def create_task(
        self,
        url: str,
        title: str | None,
        navigation_goal: str | None,
        data_extraction_goal: str | None,
        navigation_payload: dict[str, Any] | list | str | None,
        status: str = "created",
        complete_criterion: str | None = None,
        terminate_criterion: str | None = None,
        webhook_callback_url: str | None = None,
        totp_verification_url: str | None = None,
        totp_identifier: str | None = None,
        organization_id: str | None = None,
        proxy_location: ProxyLocationInput = None,
        extracted_information_schema: dict[str, Any] | list | str | None = None,
        workflow_run_id: str | None = None,
        order: int | None = None,
        retry: int | None = None,
        max_steps_per_run: int | None = None,
        error_code_mapping: dict[str, str] | None = None,
        task_type: str = TaskType.general,
        application: str | None = None,
        include_action_history_in_verification: bool | None = None,
        model: dict[str, Any] | None = None,
        max_screenshot_scrolling_times: int | None = None,
        extra_http_headers: dict[str, str] | None = None,
        browser_session_id: str | None = None,
        browser_address: str | None = None,
        download_timeout: float | None = None,
    ) -> Task:
        # Sanitize text fields to remove NUL bytes and control characters
        # that PostgreSQL cannot store in text columns
        def _sanitize(v: str | None) -> str | None:
            return sanitize_postgres_text(v) if isinstance(v, str) else v

        navigation_goal = _sanitize(navigation_goal)
        data_extraction_goal = _sanitize(data_extraction_goal)
        title = _sanitize(title)
        url = sanitize_postgres_text(url)
        complete_criterion = _sanitize(complete_criterion)
        terminate_criterion = _sanitize(terminate_criterion)

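        # Example of what sanitization guards against (assumed helper behavior):
        # a NUL byte smuggled into a goal string, e.g. "find\x00this", would make
        # the INSERT fail with an "invalid byte sequence" error on PostgreSQL,
        # so sanitize_postgres_text is expected to strip such bytes first.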
        async with self.Session() as session:
            new_task = TaskModel(
                status=status,
                task_type=task_type,
                url=url,
                title=title,
                webhook_callback_url=webhook_callback_url,
                totp_verification_url=totp_verification_url,
                totp_identifier=totp_identifier,
                navigation_goal=navigation_goal,
                complete_criterion=complete_criterion,
                terminate_criterion=terminate_criterion,
                data_extraction_goal=data_extraction_goal,
                navigation_payload=navigation_payload,
                organization_id=organization_id,
                proxy_location=serialize_proxy_location(proxy_location),
                extracted_information_schema=extracted_information_schema,
                workflow_run_id=workflow_run_id,
                order=order,
                retry=retry,
                max_steps_per_run=max_steps_per_run,
                error_code_mapping=error_code_mapping,
                application=application,
                include_action_history_in_verification=include_action_history_in_verification,
                model=model,
                max_screenshot_scrolling_times=max_screenshot_scrolling_times,
                extra_http_headers=extra_http_headers,
                browser_session_id=browser_session_id,
                browser_address=browser_address,
                download_timeout=download_timeout,
            )
            session.add(new_task)
            await session.commit()
            await session.refresh(new_task)
            return convert_to_task(new_task, self.debug_enabled)

    @db_operation("create_step")
    async def create_step(
        self,
        task_id: str,
        order: int,
        retry_index: int,
        organization_id: str | None = None,
        status: StepStatus = StepStatus.created,
        created_by: str | None = None,
    ) -> Step:
        async with self.Session() as session:
            new_step = StepModel(
                task_id=task_id,
                order=order,
                retry_index=retry_index,
                status=status,
                organization_id=organization_id,
                created_by=created_by,
            )
            session.add(new_step)
            await session.commit()
            await session.refresh(new_step)
            return convert_to_step(new_step, debug_enabled=self.debug_enabled)

    @read_retry()
    @db_operation("get_task", log_errors=False)
    async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
        """Get a task by its id"""
        async with self.Session() as session:
            query = select(TaskModel).filter_by(task_id=task_id)
            if organization_id is not None:
                query = query.filter_by(organization_id=organization_id)
            if task_obj := (await session.scalars(query)).first():
                return convert_to_task(task_obj, self.debug_enabled)
            else:
                LOG.info(
                    "Task not found",
                    task_id=task_id,
                    organization_id=organization_id,
                )
                return None

    @db_operation("get_tasks_by_ids")
    async def get_tasks_by_ids(
        self,
        task_ids: list[str],
        organization_id: str,
    ) -> list[Task]:
        async with self.Session() as session:
            tasks = (
                await session.scalars(
                    select(TaskModel).filter(TaskModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id)
                )
            ).all()
            return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]

    @db_operation("get_step")
    async def get_step(self, step_id: str, organization_id: str | None = None) -> Step | None:
        async with self.Session() as session:
            if step := (
                await session.scalars(
                    select(StepModel).filter_by(step_id=step_id).filter_by(organization_id=organization_id)
                )
            ).first():
                return convert_to_step(step, debug_enabled=self.debug_enabled)
            else:
                return None

    @db_operation("get_task_steps")
    async def get_task_steps(self, task_id: str, organization_id: str) -> list[Step]:
        async with self.Session() as session:
            if steps := (
                await session.scalars(
                    select(StepModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                    .order_by(StepModel.order)
                    .order_by(StepModel.retry_index)
                )
            ).all():
                return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
            else:
                return []

    @db_operation("get_steps_by_task_ids")
    async def get_steps_by_task_ids(self, task_ids: list[str], organization_id: str | None = None) -> list[Step]:
        async with self.Session() as session:
            steps = (
                await session.scalars(
                    select(StepModel).filter(StepModel.task_id.in_(task_ids)).filter_by(organization_id=organization_id)
                )
            ).all()
            return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]

    @db_operation("get_total_unique_step_order_count_by_task_ids")
    async def get_total_unique_step_order_count_by_task_ids(
        self,
        *,
        task_ids: list[str],
        organization_id: str,
    ) -> int:
        """
        Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids.
        Equivalent SQL: SELECT count(DISTINCT (s.task_id, s.order)) FROM steps s WHERE s.task_id IN :task_ids
        """
        async with self.Session() as session:
            subq = (
                select(StepModel.task_id, StepModel.order)
                .where(StepModel.task_id.in_(task_ids))
                .where(StepModel.organization_id == organization_id)
                .distinct()
                .subquery()
            )
            query = select(func.count()).select_from(subq)
            return (await session.execute(query)).scalar()

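    # Design note on the count above: materializing the DISTINCT (task_id,
    # order) pairs in a subquery and counting its rows sidesteps dialect
    # quirks around COUNT(DISTINCT (a, b)) on composite values; the result is
    # the same either way.
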
@db_operation("get_task_step_models")
|
||||
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
|
||||
async with self.Session() as session:
|
||||
return (
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order)
|
||||
.order_by(StepModel.retry_index)
|
||||
)
|
||||
).all()
|
||||
|
||||
@db_operation("get_task_step_count")
|
||||
async def get_task_step_count(self, task_id: str, organization_id: str | None = None) -> int:
|
||||
async with self.Session() as session:
|
||||
result = await session.scalar(
|
||||
select(func.count(StepModel.step_id))
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
return result or 0
|
||||
|
||||
@db_operation("get_task_actions")
|
||||
async def get_task_actions(self, task_id: str, organization_id: str | None = None) -> list[Action]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(ActionModel)
|
||||
.filter(ActionModel.organization_id == organization_id)
|
||||
.filter(ActionModel.task_id == task_id)
|
||||
.order_by(ActionModel.created_at)
|
||||
)
|
||||
|
||||
actions = (await session.scalars(query)).all()
|
||||
return [Action.model_validate(action) for action in actions]
|
||||
|
||||
@db_operation("get_task_actions_hydrated")
|
||||
async def get_task_actions_hydrated(self, task_id: str, organization_id: str | None = None) -> list[Action]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(ActionModel)
|
||||
.filter(ActionModel.organization_id == organization_id)
|
||||
.filter(ActionModel.task_id == task_id)
|
||||
.order_by(ActionModel.created_at)
|
||||
)
|
||||
|
||||
actions = (await session.scalars(query)).all()
|
||||
return [hydrate_action(action) for action in actions]
|
||||
|
||||
@db_operation("get_tasks_actions")
|
||||
async def get_tasks_actions(self, task_ids: list[str], organization_id: str | None = None) -> list[Action]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(ActionModel)
|
||||
.filter(ActionModel.organization_id == organization_id)
|
||||
.filter(ActionModel.task_id.in_(task_ids))
|
||||
.order_by(ActionModel.created_at.desc())
|
||||
)
|
||||
actions = (await session.scalars(query)).all()
|
||||
return [hydrate_action(action) for action in actions]
|
||||
|
||||
@db_operation("get_action_count_for_step")
|
||||
async def get_action_count_for_step(self, step_id: str, task_id: str, organization_id: str) -> int:
|
||||
"""Get count of actions for a step. Uses composite index for efficiency."""
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(func.count())
|
||||
.select_from(ActionModel)
|
||||
.where(ActionModel.organization_id == organization_id)
|
||||
.where(ActionModel.task_id == task_id)
|
||||
.where(ActionModel.step_id == step_id)
|
||||
)
|
||||
result = await session.scalar(query)
|
||||
return result or 0
|
||||
|
||||
@db_operation("get_first_step")
|
||||
async def get_first_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
|
||||
async with self.Session() as session:
|
||||
if step := (
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.order_by(StepModel.order.asc())
|
||||
.order_by(StepModel.retry_index.asc())
|
||||
)
|
||||
).first():
|
||||
return convert_to_step(step, debug_enabled=self.debug_enabled)
|
||||
else:
|
||||
LOG.info(
|
||||
"Latest step not found",
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
@db_operation("get_latest_step")
|
||||
async def get_latest_step(self, task_id: str, organization_id: str | None = None) -> Step | None:
|
||||
async with self.Session() as session:
|
||||
if step := (
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(StepModel.status != StepStatus.canceled)
|
||||
.order_by(StepModel.order.desc())
|
||||
.order_by(StepModel.retry_index.desc())
|
||||
)
|
||||
).first():
|
||||
return convert_to_step(step, debug_enabled=self.debug_enabled)
|
||||
else:
|
||||
LOG.info(
|
||||
"Latest step not found",
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
@db_operation("update_step")
|
||||
async def update_step(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
status: StepStatus | None = None,
|
||||
output: AgentStepOutput | None = None,
|
||||
is_last: bool | None = None,
|
||||
retry_index: int | None = None,
|
||||
organization_id: str | None = None,
|
||||
incremental_cost: float | None = None,
|
||||
incremental_input_tokens: int | None = None,
|
||||
incremental_output_tokens: int | None = None,
|
||||
incremental_reasoning_tokens: int | None = None,
|
||||
incremental_cached_tokens: int | None = None,
|
||||
created_by: str | None = None,
|
||||
) -> Step:
|
||||
async with self.Session() as session:
|
||||
if step := (
|
||||
await session.scalars(
|
||||
select(StepModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.filter_by(step_id=step_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
if status is not None:
|
||||
step.status = status
|
||||
|
||||
if status.is_terminal() and step.finished_at is None:
|
||||
step.finished_at = datetime.utcnow()
|
||||
if output is not None:
|
||||
step.output = output.model_dump(exclude_none=True)
|
||||
if is_last is not None:
|
||||
step.is_last = is_last
|
||||
if retry_index is not None:
|
||||
step.retry_index = retry_index
|
||||
if incremental_cost is not None:
|
||||
step.step_cost = incremental_cost + float(step.step_cost or 0)
|
||||
if incremental_input_tokens is not None:
|
||||
step.input_token_count = incremental_input_tokens + (step.input_token_count or 0)
|
||||
if incremental_output_tokens is not None:
|
||||
step.output_token_count = incremental_output_tokens + (step.output_token_count or 0)
|
||||
if incremental_reasoning_tokens is not None:
|
||||
step.reasoning_token_count = incremental_reasoning_tokens + (step.reasoning_token_count or 0)
|
||||
if incremental_cached_tokens is not None:
|
||||
step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0)
|
||||
if created_by is not None:
|
||||
step.created_by = created_by
|
||||
|
||||
await session.commit()
|
||||
updated_step = await self.get_step(step_id, organization_id)
|
||||
if not updated_step:
|
||||
raise NotFoundError("Step not found")
|
||||
return updated_step
|
||||
else:
|
||||
raise NotFoundError("Step not found")
|
||||
|
||||
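    # Accumulation semantics of update_step: the incremental_* arguments are
    # added to the stored totals, not assigned. For example, two successive
    # calls with incremental_input_tokens=1000 leave step.input_token_count
    # at 2000.
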
@db_operation("clear_task_failure_reason")
|
||||
async def clear_task_failure_reason(self, organization_id: str, task_id: str) -> Task:
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
task.failure_reason = None
|
||||
await session.commit()
|
||||
await session.refresh(task)
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
|
||||
@db_operation("update_task")
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus | None = None,
|
||||
extracted_information: dict[str, Any] | list | str | None = None,
|
||||
webhook_failure_reason: str | None = None,
|
||||
failure_reason: str | None = None,
|
||||
errors: list[dict[str, Any]] | None = None,
|
||||
max_steps_per_run: int | None = None,
|
||||
organization_id: str | None = None,
|
||||
failure_category: list[dict[str, Any]] | None = None,
|
||||
) -> Task:
|
||||
if (
|
||||
status is None
|
||||
and extracted_information is None
|
||||
and failure_reason is None
|
||||
and errors is None
|
||||
and max_steps_per_run is None
|
||||
and webhook_failure_reason is None
|
||||
and failure_category is None
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of status, extracted_information, or failure_reason must be provided to update the task"
|
||||
)
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if status == TaskStatus.queued and task.queued_at is None:
|
||||
task.queued_at = datetime.utcnow()
|
||||
if status == TaskStatus.running and task.started_at is None:
|
||||
task.started_at = datetime.utcnow()
|
||||
if status.is_final() and task.finished_at is None:
|
||||
task.finished_at = datetime.utcnow()
|
||||
if extracted_information is not None:
|
||||
task.extracted_information = extracted_information
|
||||
if failure_reason is not None:
|
||||
task.failure_reason = failure_reason
|
||||
if errors is not None:
|
||||
task.errors = (task.errors or []) + errors
|
||||
if max_steps_per_run is not None:
|
||||
task.max_steps_per_run = max_steps_per_run
|
||||
if webhook_failure_reason is not None:
|
||||
task.webhook_failure_reason = webhook_failure_reason
|
||||
if failure_category is not None:
|
||||
task.failure_category = failure_category
|
||||
await session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
return updated_task
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
|
||||
@db_operation("update_task_2fa_state")
|
||||
async def update_task_2fa_state(
|
||||
self,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
waiting_for_verification_code: bool,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> Task:
|
||||
"""Update task 2FA verification code waiting state."""
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
task.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
task.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
task.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if not waiting_for_verification_code:
|
||||
# Clear identifiers when no longer waiting
|
||||
task.verification_code_identifier = None
|
||||
task.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
return updated_task
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
|
||||
@db_operation("bulk_update_tasks")
|
||||
async def bulk_update_tasks(
|
||||
self,
|
||||
task_ids: list[str],
|
||||
status: TaskStatus | None = None,
|
||||
failure_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Bulk update tasks by their IDs.
|
||||
|
||||
Args:
|
||||
task_ids: List of task IDs to update
|
||||
status: Optional status to set for all tasks
|
||||
failure_reason: Optional failure reason to set for all tasks
|
||||
"""
|
||||
if not task_ids:
|
||||
return
|
||||
|
||||
async with self.Session() as session:
|
||||
update_values = {}
|
||||
if status:
|
||||
update_values["status"] = status.value
|
||||
if failure_reason:
|
||||
update_values["failure_reason"] = failure_reason
|
||||
|
||||
if update_values:
|
||||
update_stmt = update(TaskModel).where(TaskModel.task_id.in_(task_ids)).values(**update_values)
|
||||
await session.execute(update_stmt)
|
||||
await session.commit()
|
||||
|
||||
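    # Usage sketch (hypothetical caller): fail a batch of stuck tasks with one
    # UPDATE instead of N round-trips, e.g.
    #
    #     await db.bulk_update_tasks(
    #         task_ids=stuck_task_ids,
    #         status=TaskStatus.failed,
    #         failure_reason="timed out waiting for a browser session",
    #     )
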
@db_operation("get_tasks")
|
||||
async def get_tasks(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
task_status: list[TaskStatus] | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
only_standalone_tasks: bool = False,
|
||||
application: str | None = None,
|
||||
order_by_column: OrderBy = OrderBy.created_at,
|
||||
order: SortDirection = SortDirection.desc,
|
||||
) -> list[Task]:
|
||||
"""
|
||||
Get all tasks.
|
||||
:param page: Starts at 1
|
||||
:param page_size:
|
||||
:param task_status:
|
||||
:param workflow_run_id:
|
||||
:param only_standalone_tasks:
|
||||
:param order_by_column:
|
||||
:param order:
|
||||
:return:
|
||||
"""
|
||||
if page < 1:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
|
||||
async with self.Session() as session:
|
||||
db_page = page - 1 # offset logic is 0 based
|
||||
query = (
|
||||
select(TaskModel, WorkflowRunModel.workflow_permanent_id)
|
||||
.join(WorkflowRunModel, TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id, isouter=True)
|
||||
.filter(TaskModel.organization_id == organization_id)
|
||||
)
|
||||
if task_status:
|
||||
query = query.filter(TaskModel.status.in_(task_status))
|
||||
if workflow_run_id:
|
||||
query = query.filter(TaskModel.workflow_run_id == workflow_run_id)
|
||||
if only_standalone_tasks:
|
||||
query = query.filter(TaskModel.workflow_run_id.is_(None))
|
||||
if application:
|
||||
query = query.filter(TaskModel.application == application)
|
||||
order_by_col = getattr(TaskModel, order_by_column)
|
||||
query = (
|
||||
query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc())
|
||||
.limit(page_size)
|
||||
.offset(db_page * page_size)
|
||||
)
|
||||
|
||||
results = (await session.execute(query)).all()
|
||||
|
||||
return [
|
||||
convert_to_task(task, debug_enabled=self.debug_enabled, workflow_permanent_id=workflow_permanent_id)
|
||||
for task, workflow_permanent_id in results
|
||||
]
|
||||
|
||||
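    # Pagination math for get_tasks: page is 1-based at the API surface and is
    # converted to a 0-based offset, so page=3 with page_size=10 compiles to
    # LIMIT 10 OFFSET 20.
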
@db_operation("get_tasks_count")
|
||||
async def get_tasks_count(
|
||||
self,
|
||||
organization_id: str,
|
||||
task_status: list[TaskStatus] | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
only_standalone_tasks: bool = False,
|
||||
application: str | None = None,
|
||||
) -> int:
|
||||
async with self.Session() as session:
|
||||
count_query = (
|
||||
select(func.count()).select_from(TaskModel).filter(TaskModel.organization_id == organization_id)
|
||||
)
|
||||
if task_status:
|
||||
count_query = count_query.filter(TaskModel.status.in_(task_status))
|
||||
if workflow_run_id:
|
||||
count_query = count_query.filter(TaskModel.workflow_run_id == workflow_run_id)
|
||||
if only_standalone_tasks:
|
||||
count_query = count_query.filter(TaskModel.workflow_run_id.is_(None))
|
||||
if application:
|
||||
count_query = count_query.filter(TaskModel.application == application)
|
||||
return (await session.execute(count_query)).scalar_one()
|
||||
|
||||
@db_operation("get_running_tasks_info_globally")
|
||||
async def get_running_tasks_info_globally(
|
||||
self,
|
||||
stale_threshold_hours: int = 24,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Get information about running tasks across all organizations.
|
||||
Used by cleanup service to determine if cleanup should be skipped.
|
||||
|
||||
Args:
|
||||
stale_threshold_hours: Tasks not updated for this many hours are considered stale.
|
||||
|
||||
Returns:
|
||||
Tuple of (active_task_count, stale_task_count).
|
||||
Active tasks are those updated within the threshold.
|
||||
Stale tasks are those not updated within the threshold but still in running status.
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
running_statuses = [TaskStatus.created, TaskStatus.queued, TaskStatus.running]
|
||||
stale_cutoff = datetime.utcnow() - timedelta(hours=stale_threshold_hours)
|
||||
|
||||
# Count active tasks (recently updated)
|
||||
active_query = (
|
||||
select(func.count())
|
||||
.select_from(TaskModel)
|
||||
.filter(TaskModel.status.in_(running_statuses))
|
||||
.filter(TaskModel.modified_at >= stale_cutoff)
|
||||
)
|
||||
active_count = (await session.execute(active_query)).scalar_one()
|
||||
|
||||
# Count stale tasks (not updated for a long time)
|
||||
stale_query = (
|
||||
select(func.count())
|
||||
.select_from(TaskModel)
|
||||
.filter(TaskModel.status.in_(running_statuses))
|
||||
.filter(TaskModel.modified_at < stale_cutoff)
|
||||
)
|
||||
stale_count = (await session.execute(stale_query)).scalar_one()
|
||||
|
||||
return (active_count, stale_count)
|
||||
|
||||
@db_operation("get_latest_task_by_workflow_id")
|
||||
async def get_latest_task_by_workflow_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_id: str,
|
||||
before: datetime | None = None,
|
||||
) -> Task | None:
|
||||
async with self.Session() as session:
|
||||
query = select(TaskModel).filter_by(organization_id=organization_id).filter_by(workflow_id=workflow_id)
|
||||
if before:
|
||||
query = query.filter(TaskModel.created_at < before)
|
||||
task = (await session.scalars(query.order_by(TaskModel.created_at.desc()))).first()
|
||||
if task:
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("get_last_task_for_workflow_run")
|
||||
async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return convert_to_task(task, debug_enabled=self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("get_tasks_by_workflow_run_id")
|
||||
async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
|
||||
async with self.Session() as session:
|
||||
tasks = (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
|
||||
)
|
||||
).all()
|
||||
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
|
||||
|
||||
@db_operation("delete_task_steps")
|
||||
async def delete_task_steps(self, organization_id: str, task_id: str) -> None:
|
||||
async with self.Session() as session:
|
||||
# delete artifacts by filtering organization_id and task_id
|
||||
stmt = delete(StepModel).where(
|
||||
and_(
|
||||
StepModel.organization_id == organization_id,
|
||||
StepModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("get_previous_actions_for_task")
|
||||
async def get_previous_actions_for_task(self, task_id: str) -> list[Action]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(ActionModel)
|
||||
.filter_by(task_id=task_id)
|
||||
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
|
||||
)
|
||||
actions = (await session.scalars(query)).all()
|
||||
return [Action.model_validate(action) for action in actions]
|
||||
|
||||
@db_operation("delete_task_actions")
|
||||
async def delete_task_actions(self, organization_id: str, task_id: str) -> None:
|
||||
async with self.Session() as session:
|
||||
# delete actions by filtering organization_id and task_id
|
||||
stmt = delete(ActionModel).where(
|
||||
and_(
|
||||
ActionModel.organization_id == organization_id,
|
||||
ActionModel.task_id == task_id,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def sync_task_run_status(
|
||||
self,
|
||||
organization_id: str,
|
||||
run_id: str,
|
||||
status: str,
|
||||
started_at: datetime | None = None,
|
||||
finished_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Best-effort write-through: propagate status from source table to task_runs.
|
||||
|
||||
Does NOT raise if the task_runs row is missing (race at creation time).
|
||||
"""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
vals: dict[str, Any] = {"status": status}
|
||||
if started_at is not None:
|
||||
vals["started_at"] = started_at
|
||||
if finished_at is not None:
|
||||
vals["finished_at"] = finished_at
|
||||
stmt = (
|
||||
update(TaskRunModel)
|
||||
.where(TaskRunModel.run_id == run_id)
|
||||
.where(TaskRunModel.organization_id == organization_id)
|
||||
.values(**vals)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Best-effort task_run status sync failed",
|
||||
run_id=run_id,
|
||||
organization_id=organization_id,
|
||||
status=status,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
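    # Callers can treat sync_task_run_status as fire-and-forget: failures are
    # logged and swallowed, so a missing task_runs row never breaks the source
    # table update that triggered the sync.
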
@db_operation("create_task_run")
|
||||
async def create_task_run(
|
||||
self,
|
||||
task_run_type: RunType,
|
||||
organization_id: str,
|
||||
run_id: str,
|
||||
title: str | None = None,
|
||||
url: str | None = None,
|
||||
url_hash: str | None = None,
|
||||
status: RunStatus | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
parent_workflow_run_id: str | None = None,
|
||||
debug_session_id: str | None = None,
|
||||
# script_run, started_at, finished_at are intentionally omitted here —
|
||||
# they are set via update_task_run() after the run starts/finishes (PRs 2-5).
|
||||
) -> Run:
|
||||
searchable_text = " ".join(filter(None, [title, url]))
|
||||
async with self.Session() as session:
|
||||
task_run = TaskRunModel(
|
||||
task_run_type=task_run_type,
|
||||
organization_id=organization_id,
|
||||
run_id=run_id,
|
||||
title=title,
|
||||
url=url,
|
||||
url_hash=url_hash,
|
||||
status=status,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
parent_workflow_run_id=parent_workflow_run_id,
|
||||
debug_session_id=debug_session_id,
|
||||
searchable_text=searchable_text or None,
|
||||
)
|
||||
session.add(task_run)
|
||||
await session.commit()
|
||||
await session.refresh(task_run)
|
||||
return Run.model_validate(task_run)
|
||||
|
||||
@db_operation("update_task_run")
|
||||
async def update_task_run(
|
||||
self,
|
||||
organization_id: str,
|
||||
run_id: str,
|
||||
title: str | None = None,
|
||||
url: str | None = None,
|
||||
url_hash: str | None = None,
|
||||
status: str | None = None,
|
||||
started_at: datetime | None = None,
|
||||
finished_at: datetime | None = None,
|
||||
) -> None:
|
||||
async with self.Session() as session:
|
||||
task_run = (
|
||||
await session.scalars(
|
||||
select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if not task_run:
|
||||
raise NotFoundError(f"TaskRun {run_id} not found")
|
||||
|
||||
if title is not None:
|
||||
task_run.title = title
|
||||
if url is not None:
|
||||
task_run.url = url
|
||||
if url_hash is not None:
|
||||
task_run.url_hash = url_hash
|
||||
if status is not None:
|
||||
task_run.status = status
|
||||
if started_at is not None:
|
||||
task_run.started_at = started_at
|
||||
if finished_at is not None:
|
||||
task_run.finished_at = finished_at
|
||||
|
||||
# Recompute searchable_text when title or url changes
|
||||
if title is not None or url is not None:
|
||||
task_run.searchable_text = " ".join(filter(None, [task_run.title, task_run.url])) or None
|
||||
|
||||
await session.commit()
|
||||
|
||||
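    # Example of the searchable_text recompute above: updating only the title
    # of a run whose url is "https://example.com" rebuilds the field from the
    # merged row state ("New title https://example.com"), not from the call
    # arguments alone.
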
@db_operation("update_job_run_compute_cost")
|
||||
async def update_job_run_compute_cost(
|
||||
self,
|
||||
organization_id: str,
|
||||
run_id: str,
|
||||
instance_type: str | None = None,
|
||||
vcpu_millicores: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
duration_ms: int | None = None,
|
||||
compute_cost: float | None = None,
|
||||
) -> None:
|
||||
"""Update compute cost metrics for a job run."""
|
||||
async with self.Session() as session:
|
||||
task_run = (
|
||||
await session.scalars(
|
||||
select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if not task_run:
|
||||
LOG.warning(
|
||||
"TaskRun not found for compute cost update",
|
||||
run_id=run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
|
||||
if instance_type is not None:
|
||||
task_run.instance_type = instance_type
|
||||
if vcpu_millicores is not None:
|
||||
task_run.vcpu_millicores = vcpu_millicores
|
||||
if memory_mb is not None:
|
||||
task_run.memory_mb = memory_mb
|
||||
if duration_ms is not None:
|
||||
task_run.duration_ms = duration_ms
|
||||
if compute_cost is not None:
|
||||
task_run.compute_cost = compute_cost
|
||||
await session.commit()
|
||||
|
||||
@db_operation("cache_task_run")
|
||||
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
|
||||
async with self.Session() as session:
|
||||
task_run = (
|
||||
await session.scalars(
|
||||
select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
|
||||
)
|
||||
).first()
|
||||
if task_run:
|
||||
task_run.cached = True
|
||||
await session.commit()
|
||||
await session.refresh(task_run)
|
||||
return Run.model_validate(task_run)
|
||||
raise NotFoundError(f"Run {run_id} not found")
|
||||
|
||||
@db_operation("get_cached_task_run")
|
||||
async def get_cached_task_run(
|
||||
self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
|
||||
) -> Run | None:
|
||||
async with self.Session() as session:
|
||||
query = select(TaskRunModel)
|
||||
if task_run_type:
|
||||
query = query.filter_by(task_run_type=task_run_type)
|
||||
if url_hash:
|
||||
query = query.filter_by(url_hash=url_hash)
|
||||
if organization_id:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
|
||||
task_run = (await session.scalars(query)).first()
|
||||
return Run.model_validate(task_run) if task_run else None
|
||||
|
||||
@db_operation("get_run")
|
||||
async def get_run(
|
||||
self,
|
||||
run_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> Run | None:
|
||||
async with self.Session() as session:
|
||||
query = select(TaskRunModel).filter_by(run_id=run_id)
|
||||
if organization_id:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
task_run = (await session.scalars(query)).first()
|
||||
return Run.model_validate(task_run) if task_run else None
|
||||
|
|
@ -1,566 +0,0 @@
from __future__ import annotations

import json
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any

import structlog
from sqlalchemy import select

from skyvern.config import settings
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
    ActionModel,
    AISuggestionModel,
    AWSSecretParameterModel,
    AzureVaultCredentialParameterModel,
    Base,
    BitwardenCreditCardDataParameterModel,
    BitwardenLoginCredentialParameterModel,
    BitwardenSensitiveInformationParameterModel,
    CredentialParameterModel,
    OnePasswordCredentialParameterModel,
    OutputParameterModel,
    TaskGenerationModel,
    TaskModel,
    WorkflowCopilotChatMessageModel,
    WorkflowCopilotChatModel,
    WorkflowParameterModel,
)
from skyvern.forge.sdk.db.utils import (
    convert_to_aws_secret_parameter,
    convert_to_output_parameter,
    convert_to_workflow_copilot_chat_message,
    convert_to_workflow_parameter,
    hydrate_action,
)
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.schemas.workflow_copilot import (
    WorkflowCopilotChat,
    WorkflowCopilotChatMessage,
    WorkflowCopilotChatSender,
)
from skyvern.forge.sdk.workflow.models.parameter import (
    PARAMETER_TYPE,
    AWSSecretParameter,
    AzureVaultCredentialParameter,
    BitwardenCreditCardDataParameter,
    BitwardenLoginCredentialParameter,
    BitwardenSensitiveInformationParameter,
    ContextParameter,
    CredentialParameter,
    OnePasswordCredentialParameter,
    OutputParameter,
    WorkflowParameter,
    WorkflowParameterType,
)
from skyvern.webeye.actions.actions import Action

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory

from skyvern.forge.sdk.db._sentinels import _UNSET

LOG = structlog.get_logger()


class WorkflowParametersMixin:
    """Database operations for workflow parameters, copilot chat, task generation, actions, and runs."""

    Session: _SessionFactory
    debug_enabled: bool

@db_operation("create_workflow_parameter")
|
||||
async def create_workflow_parameter(
|
||||
self,
|
||||
workflow_id: str,
|
||||
workflow_parameter_type: WorkflowParameterType,
|
||||
key: str,
|
||||
default_value: Any,
|
||||
description: str | None = None,
|
||||
) -> WorkflowParameter:
|
||||
async with self.Session() as session:
|
||||
if default_value is None:
|
||||
pass
|
||||
elif workflow_parameter_type == WorkflowParameterType.JSON:
|
||||
default_value = json.dumps(default_value)
|
||||
else:
|
||||
default_value = str(default_value)
|
||||
workflow_parameter = WorkflowParameterModel(
|
||||
workflow_id=workflow_id,
|
||||
workflow_parameter_type=workflow_parameter_type,
|
||||
key=key,
|
||||
default_value=default_value,
|
||||
description=description,
|
||||
)
|
||||
session.add(workflow_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_parameter)
|
||||
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
|
||||
|
||||
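    # Serialization behavior of default_value above: JSON-typed parameters are
    # stored as a JSON string and everything else via str(), so a default of
    # {"retries": 3} persists as '{"retries": 3}' while a numeric default of
    # 1.5 persists as "1.5".
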
@db_operation("create_aws_secret_parameter")
|
||||
async def create_aws_secret_parameter(
|
||||
self,
|
||||
workflow_id: str,
|
||||
key: str,
|
||||
aws_key: str,
|
||||
description: str | None = None,
|
||||
) -> AWSSecretParameter:
|
||||
async with self.Session() as session:
|
||||
aws_secret_parameter = AWSSecretParameterModel(
|
||||
workflow_id=workflow_id,
|
||||
key=key,
|
||||
aws_key=aws_key,
|
||||
description=description,
|
||||
)
|
||||
session.add(aws_secret_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(aws_secret_parameter)
|
||||
return convert_to_aws_secret_parameter(aws_secret_parameter)
|
||||
|
||||
@db_operation("create_output_parameter")
|
||||
async def create_output_parameter(
|
||||
self,
|
||||
workflow_id: str,
|
||||
key: str,
|
||||
description: str | None = None,
|
||||
) -> OutputParameter:
|
||||
async with self.Session() as session:
|
||||
output_parameter = OutputParameterModel(
|
||||
key=key,
|
||||
description=description,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
session.add(output_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(output_parameter)
|
||||
return convert_to_output_parameter(output_parameter)
|
||||
|
||||
@staticmethod
|
||||
def _convert_parameter_to_model(parameter: PARAMETER_TYPE) -> Base:
|
||||
"""Convert a parameter object to its corresponding SQLAlchemy model."""
|
||||
if isinstance(parameter, WorkflowParameter):
|
||||
if parameter.default_value is None:
|
||||
default_value = None
|
||||
elif parameter.workflow_parameter_type == WorkflowParameterType.JSON:
|
||||
default_value = json.dumps(parameter.default_value)
|
||||
else:
|
||||
default_value = str(parameter.default_value)
|
||||
return WorkflowParameterModel(
|
||||
workflow_parameter_id=parameter.workflow_parameter_id,
|
||||
workflow_parameter_type=parameter.workflow_parameter_type.value,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
workflow_id=parameter.workflow_id,
|
||||
default_value=default_value,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, OutputParameter):
|
||||
return OutputParameterModel(
|
||||
output_parameter_id=parameter.output_parameter_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
workflow_id=parameter.workflow_id,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, AWSSecretParameter):
|
||||
return AWSSecretParameterModel(
|
||||
aws_secret_parameter_id=parameter.aws_secret_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
aws_key=parameter.aws_key,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, BitwardenLoginCredentialParameter):
|
||||
return BitwardenLoginCredentialParameterModel(
|
||||
bitwarden_login_credential_parameter_id=parameter.bitwarden_login_credential_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
|
||||
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
|
||||
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
|
||||
bitwarden_collection_id=parameter.bitwarden_collection_id,
|
||||
bitwarden_item_id=parameter.bitwarden_item_id,
|
||||
url_parameter_key=parameter.url_parameter_key,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, BitwardenSensitiveInformationParameter):
|
||||
return BitwardenSensitiveInformationParameterModel(
|
||||
bitwarden_sensitive_information_parameter_id=parameter.bitwarden_sensitive_information_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
|
||||
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
|
||||
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
|
||||
bitwarden_collection_id=parameter.bitwarden_collection_id,
|
||||
bitwarden_identity_key=parameter.bitwarden_identity_key,
|
||||
bitwarden_identity_fields=parameter.bitwarden_identity_fields,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, BitwardenCreditCardDataParameter):
|
||||
return BitwardenCreditCardDataParameterModel(
|
||||
bitwarden_credit_card_data_parameter_id=parameter.bitwarden_credit_card_data_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
|
||||
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
|
||||
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
|
||||
bitwarden_collection_id=parameter.bitwarden_collection_id,
|
||||
bitwarden_item_id=parameter.bitwarden_item_id,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, CredentialParameter):
|
||||
return CredentialParameterModel(
|
||||
credential_parameter_id=parameter.credential_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
credential_id=parameter.credential_id,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, OnePasswordCredentialParameter):
|
||||
return OnePasswordCredentialParameterModel(
|
||||
onepassword_credential_parameter_id=parameter.onepassword_credential_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
vault_id=parameter.vault_id,
|
||||
item_id=parameter.item_id,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
elif isinstance(parameter, AzureVaultCredentialParameter):
|
||||
return AzureVaultCredentialParameterModel(
|
||||
azure_vault_credential_parameter_id=parameter.azure_vault_credential_parameter_id,
|
||||
workflow_id=parameter.workflow_id,
|
||||
key=parameter.key,
|
||||
description=parameter.description,
|
||||
vault_name=parameter.vault_name,
|
||||
username_key=parameter.username_key,
|
||||
password_key=parameter.password_key,
|
||||
totp_secret_key=parameter.totp_secret_key,
|
||||
deleted_at=parameter.deleted_at,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported workflow definition parameter type: {type(parameter).__name__}")
|
||||
|
||||
@db_operation("save_workflow_definition_parameters")
|
||||
async def save_workflow_definition_parameters(self, parameters: list[PARAMETER_TYPE]) -> None:
|
||||
"""Save multiple workflow definition parameters in a single transaction."""
|
||||
|
||||
# ContextParameter is not persisted
|
||||
parameters_to_save = [p for p in parameters if not isinstance(p, ContextParameter)]
|
||||
if not parameters_to_save:
|
||||
return
|
||||
|
||||
async with self.Session() as session:
|
||||
for parameter in parameters_to_save:
|
||||
model = self._convert_parameter_to_model(parameter)
|
||||
session.add(model)
|
||||
await session.commit()
|
||||
|
||||
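A hedged usage sketch for the bulk save above. The repository attribute name (`workflow_parameters`) is an assumption for illustration; `ContextParameter` instances are dropped by the method itself, so callers need not pre-filter:

# Hypothetical caller; `app.DATABASE.workflow_parameters` is an assumed repo attribute.
params = [workflow_param, aws_secret_param, context_param]  # assumed parameter instances
await app.DATABASE.workflow_parameters.save_workflow_definition_parameters(params)
# context_param is skipped; the other two are committed in one transaction.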
@db_operation("get_workflow_output_parameters")
|
||||
async def get_workflow_output_parameters(self, workflow_id: str) -> list[OutputParameter]:
|
||||
async with self.Session() as session:
|
||||
output_parameters = (
|
||||
await session.scalars(select(OutputParameterModel).filter_by(workflow_id=workflow_id))
|
||||
).all()
|
||||
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
|
||||
|
||||
@db_operation("get_workflow_output_parameters_by_ids")
|
||||
async def get_workflow_output_parameters_by_ids(self, output_parameter_ids: list[str]) -> list[OutputParameter]:
|
||||
async with self.Session() as session:
|
||||
output_parameters = (
|
||||
await session.scalars(
|
||||
select(OutputParameterModel).filter(
|
||||
OutputParameterModel.output_parameter_id.in_(output_parameter_ids)
|
||||
)
|
||||
)
|
||||
).all()
|
||||
return [convert_to_output_parameter(parameter) for parameter in output_parameters]
|
||||
|
||||
@db_operation("get_workflow_parameters")
|
||||
async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
|
||||
async with self.Session() as session:
|
||||
workflow_parameters = (
|
||||
await session.scalars(select(WorkflowParameterModel).filter_by(workflow_id=workflow_id))
|
||||
).all()
|
||||
return [convert_to_workflow_parameter(parameter) for parameter in workflow_parameters]
|
||||
|
||||
@db_operation("get_workflow_parameter")
|
||||
async def get_workflow_parameter(self, workflow_parameter_id: str) -> WorkflowParameter | None:
|
||||
async with self.Session() as session:
|
||||
if workflow_parameter := (
|
||||
await session.scalars(
|
||||
select(WorkflowParameterModel).filter_by(workflow_parameter_id=workflow_parameter_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_workflow_parameter(workflow_parameter, self.debug_enabled)
|
||||
return None
|
||||
|
||||
@db_operation("create_task_generation")
|
||||
async def create_task_generation(
|
||||
self,
|
||||
organization_id: str,
|
||||
user_prompt: str,
|
||||
user_prompt_hash: str,
|
||||
url: str | None = None,
|
||||
navigation_goal: str | None = None,
|
||||
navigation_payload: dict[str, Any] | None = None,
|
||||
data_extraction_goal: str | None = None,
|
||||
extracted_information_schema: dict[str, Any] | None = None,
|
||||
suggested_title: str | None = None,
|
||||
llm: str | None = None,
|
||||
llm_prompt: str | None = None,
|
||||
llm_response: str | None = None,
|
||||
source_task_generation_id: str | None = None,
|
||||
) -> TaskGeneration:
|
||||
async with self.Session() as session:
|
||||
new_task_generation = TaskGenerationModel(
|
||||
organization_id=organization_id,
|
||||
user_prompt=user_prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=url,
|
||||
navigation_goal=navigation_goal,
|
||||
navigation_payload=navigation_payload,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
extracted_information_schema=extracted_information_schema,
|
||||
llm=llm,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=llm_response,
|
||||
suggested_title=suggested_title,
|
||||
source_task_generation_id=source_task_generation_id,
|
||||
)
|
||||
session.add(new_task_generation)
|
||||
await session.commit()
|
||||
await session.refresh(new_task_generation)
|
||||
return TaskGeneration.model_validate(new_task_generation)
|
||||
|
||||
@db_operation("create_ai_suggestion")
|
||||
async def create_ai_suggestion(
|
||||
self,
|
||||
organization_id: str,
|
||||
ai_suggestion_type: str,
|
||||
) -> AISuggestion:
|
||||
async with self.Session() as session:
|
||||
new_ai_suggestion = AISuggestionModel(
|
||||
organization_id=organization_id,
|
||||
ai_suggestion_type=ai_suggestion_type,
|
||||
)
|
||||
session.add(new_ai_suggestion)
|
||||
await session.commit()
|
||||
await session.refresh(new_ai_suggestion)
|
||||
return AISuggestion.model_validate(new_ai_suggestion)
|
||||
|
||||
@db_operation("create_workflow_copilot_chat")
|
||||
async def create_workflow_copilot_chat(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_permanent_id: str,
|
||||
) -> WorkflowCopilotChat:
|
||||
async with self.Session() as session:
|
||||
new_chat = WorkflowCopilotChatModel(
|
||||
organization_id=organization_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
)
|
||||
session.add(new_chat)
|
||||
await session.commit()
|
||||
await session.refresh(new_chat)
|
||||
return WorkflowCopilotChat.model_validate(new_chat)
|
||||
|
||||
@db_operation("update_workflow_copilot_chat")
|
||||
async def update_workflow_copilot_chat(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_copilot_chat_id: str,
|
||||
proposed_workflow: dict | None | object = _UNSET,
|
||||
auto_accept: bool | None = None,
|
||||
) -> WorkflowCopilotChat | None:
|
||||
async with self.Session() as session:
|
||||
chat = (
|
||||
await session.scalars(
|
||||
select(WorkflowCopilotChatModel)
|
||||
.where(WorkflowCopilotChatModel.organization_id == organization_id)
|
||||
.where(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
|
||||
)
|
||||
).first()
|
||||
if not chat:
|
||||
return None
|
||||
|
||||
if proposed_workflow is not _UNSET:
|
||||
chat.proposed_workflow = proposed_workflow
|
||||
if auto_accept is not None:
|
||||
chat.auto_accept = auto_accept
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return WorkflowCopilotChat.model_validate(chat)
|
||||
|
||||
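The `_UNSET` default on `proposed_workflow` is a sentinel pattern: it lets a caller distinguish "leave the column untouched" from "explicitly clear it to None". A minimal self-contained sketch of the idiom (names are illustrative, not Skyvern APIs):

_UNSET = object()  # module-level sentinel, compared by identity

def update_chat(proposed_workflow: dict | None | object = _UNSET) -> str:
    if proposed_workflow is _UNSET:
        return "column left unchanged"
    # Reached for both a real dict and an explicit None, so None clears the value.
    return f"column set to {proposed_workflow!r}"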
@db_operation("create_workflow_copilot_chat_message")
|
||||
async def create_workflow_copilot_chat_message(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_copilot_chat_id: str,
|
||||
sender: WorkflowCopilotChatSender,
|
||||
content: str,
|
||||
global_llm_context: str | None = None,
|
||||
) -> WorkflowCopilotChatMessage:
|
||||
async with self.Session() as session:
|
||||
new_message = WorkflowCopilotChatMessageModel(
|
||||
workflow_copilot_chat_id=workflow_copilot_chat_id,
|
||||
organization_id=organization_id,
|
||||
sender=sender,
|
||||
content=content,
|
||||
global_llm_context=global_llm_context,
|
||||
)
|
||||
session.add(new_message)
|
||||
await session.commit()
|
||||
await session.refresh(new_message)
|
||||
return convert_to_workflow_copilot_chat_message(new_message, self.debug_enabled)
|
||||
|
||||
@db_operation("get_workflow_copilot_chat_messages")
|
||||
async def get_workflow_copilot_chat_messages(
|
||||
self,
|
||||
workflow_copilot_chat_id: str,
|
||||
) -> list[WorkflowCopilotChatMessage]:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(WorkflowCopilotChatMessageModel)
|
||||
.filter(WorkflowCopilotChatMessageModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
|
||||
.order_by(WorkflowCopilotChatMessageModel.workflow_copilot_chat_message_id.asc())
|
||||
)
|
||||
messages = (await session.scalars(query)).all()
|
||||
return [convert_to_workflow_copilot_chat_message(message, self.debug_enabled) for message in messages]
|
||||
|
||||
@db_operation("get_workflow_copilot_chat_by_id")
|
||||
async def get_workflow_copilot_chat_by_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_copilot_chat_id: str,
|
||||
) -> WorkflowCopilotChat | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(WorkflowCopilotChatModel)
|
||||
.filter(WorkflowCopilotChatModel.organization_id == organization_id)
|
||||
.filter(WorkflowCopilotChatModel.workflow_copilot_chat_id == workflow_copilot_chat_id)
|
||||
.order_by(WorkflowCopilotChatModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
chat = (await session.scalars(query)).first()
|
||||
if not chat:
|
||||
return None
|
||||
return WorkflowCopilotChat.model_validate(chat)
|
||||
|
||||
@db_operation("get_latest_workflow_copilot_chat")
|
||||
async def get_latest_workflow_copilot_chat(
|
||||
self,
|
||||
organization_id: str,
|
||||
workflow_permanent_id: str,
|
||||
) -> WorkflowCopilotChat | None:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(WorkflowCopilotChatModel)
|
||||
.filter(WorkflowCopilotChatModel.organization_id == organization_id)
|
||||
.filter(WorkflowCopilotChatModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.order_by(WorkflowCopilotChatModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
chat = (await session.scalars(query)).first()
|
||||
if not chat:
|
||||
return None
|
||||
return WorkflowCopilotChat.model_validate(chat)
|
||||
|
||||
@db_operation("get_task_generation_by_prompt_hash")
|
||||
async def get_task_generation_by_prompt_hash(
|
||||
self,
|
||||
user_prompt_hash: str,
|
||||
query_window_hours: int = settings.PROMPT_CACHE_WINDOW_HOURS,
|
||||
) -> TaskGeneration | None:
|
||||
before_time = datetime.utcnow() - timedelta(hours=query_window_hours)
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(TaskGenerationModel)
|
||||
.filter_by(user_prompt_hash=user_prompt_hash)
|
||||
.filter(TaskGenerationModel.llm.is_not(None))
|
||||
.filter(TaskGenerationModel.created_at > before_time)
|
||||
)
|
||||
task_generation = (await session.scalars(query)).first()
|
||||
if not task_generation:
|
||||
return None
|
||||
return TaskGeneration.model_validate(task_generation)
|
||||
|
||||
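This lookup backs a prompt cache: a generation is reused only if its hash matches, it has a recorded `llm`, and it is newer than the window. A hedged caller-side sketch, where `db` stands in for the mixin instance and the hashing scheme is an assumption; only the lookup call itself comes from this file:

import hashlib

prompt = "Extract all job postings from the careers page"
prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()  # assumed hash scheme

cached = await db.get_task_generation_by_prompt_hash(user_prompt_hash=prompt_hash)
if cached is None:
    ...  # cache miss: fall through to a fresh LLM generation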
@db_operation("create_action")
|
||||
async def create_action(self, action: Action) -> Action:
|
||||
async with self.Session() as session:
|
||||
new_action = ActionModel(
|
||||
action_type=action.action_type,
|
||||
source_action_id=action.source_action_id,
|
||||
organization_id=action.organization_id,
|
||||
workflow_run_id=action.workflow_run_id,
|
||||
task_id=action.task_id,
|
||||
step_id=action.step_id,
|
||||
step_order=action.step_order,
|
||||
action_order=action.action_order,
|
||||
status=action.status,
|
||||
reasoning=action.reasoning,
|
||||
intention=action.intention,
|
||||
response=action.response,
|
||||
element_id=action.element_id,
|
||||
skyvern_element_hash=action.skyvern_element_hash,
|
||||
skyvern_element_data=action.skyvern_element_data,
|
||||
screenshot_artifact_id=action.screenshot_artifact_id,
|
||||
action_json=action.model_dump(),
|
||||
confidence_float=action.confidence_float,
|
||||
created_by=action.created_by,
|
||||
)
|
||||
session.add(new_action)
|
||||
await session.commit()
|
||||
await session.refresh(new_action)
|
||||
return hydrate_action(new_action)
|
||||
|
||||
@db_operation("update_action_reasoning")
|
||||
async def update_action_reasoning(
|
||||
self,
|
||||
organization_id: str,
|
||||
action_id: str,
|
||||
reasoning: str,
|
||||
) -> Action:
|
||||
async with self.Session() as session:
|
||||
action = (
|
||||
await session.scalars(
|
||||
select(ActionModel).filter_by(action_id=action_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if action:
|
||||
action.reasoning = reasoning
|
||||
await session.commit()
|
||||
await session.refresh(action)
|
||||
return Action.model_validate(action)
|
||||
raise NotFoundError(f"Action {action_id}")
|
||||
|
||||
@db_operation("retrieve_action_plan")
|
||||
async def retrieve_action_plan(self, task: Task) -> list[Action]:
|
||||
async with self.Session() as session:
|
||||
subquery = (
|
||||
select(TaskModel.task_id)
|
||||
.filter(TaskModel.url == task.url)
|
||||
.filter(TaskModel.navigation_goal == task.navigation_goal)
|
||||
.filter(TaskModel.status == TaskStatus.completed)
|
||||
.order_by(TaskModel.created_at.desc())
|
||||
.limit(1)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(ActionModel)
|
||||
.filter(ActionModel.task_id == subquery.c.task_id)
|
||||
.order_by(ActionModel.step_order, ActionModel.action_order, ActionModel.created_at)
|
||||
)
|
||||
|
||||
actions = (await session.scalars(query)).all()
|
||||
return [Action.model_validate(action) for action in actions]
|
||||
|
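`retrieve_action_plan` is effectively a replay cache: the subquery pins the single most recent completed task with the same URL and navigation goal, and the outer query returns that task's actions in (step_order, action_order) execution order. A hedged sketch of the intended call site, with `db` standing in for the mixin instance:

cached_actions = await db.retrieve_action_plan(task)
if cached_actions:
    ...  # replay the recorded actions before falling back to fresh planning
else:
    ...  # no completed twin task exists; plan from scratch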
@@ -1,974 +0,0 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any

import structlog
from sqlalchemy import Text, and_, cast, exists, func, literal, literal_column, or_, select, update
from sqlalchemy.dialects.postgresql import JSONB

from skyvern.exceptions import WorkflowParameterNotFound, WorkflowRunNotFound
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.mixins.base import read_retry
from skyvern.forge.sdk.db.models import (
    TaskModel,
    TaskRunModel,
    WorkflowModel,
    WorkflowParameterModel,
    WorkflowRunBlockModel,
    WorkflowRunModel,
    WorkflowRunOutputParameterModel,
    WorkflowRunParameterModel,
)
from skyvern.forge.sdk.db.utils import (
    convert_to_task,
    convert_to_workflow_run,
    convert_to_workflow_run_output_parameter,
    convert_to_workflow_run_parameter,
    serialize_proxy_location,
)
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameter
from skyvern.forge.sdk.workflow.models.workflow import (
    WorkflowRun,
    WorkflowRunOutputParameter,
    WorkflowRunParameter,
    WorkflowRunStatus,
    WorkflowRunTriggerType,
)
from skyvern.schemas.runs import ProxyLocationInput, RunType

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncEngine

    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory

from skyvern.forge.sdk.db._sentinels import _UNSET

LOG = structlog.get_logger()


class WorkflowRunsMixin:
    """Database operations for workflow runs."""

    Session: _SessionFactory
    engine: AsyncEngine
    debug_enabled: bool

    @db_operation("get_running_workflow_runs_info_globally")
    async def get_running_workflow_runs_info_globally(
        self,
        stale_threshold_hours: int = 24,
    ) -> tuple[int, int]:
        """
        Get information about running workflow runs across all organizations.
        Used by the cleanup service to determine if cleanup should be skipped.

        Args:
            stale_threshold_hours: Workflow runs not updated for this many hours are considered stale.

        Returns:
            Tuple of (active_workflow_count, stale_workflow_count).
            Active workflows are those updated within the threshold.
            Stale workflows are those not updated within the threshold but still in running status.
        """
        async with self.Session() as session:
            running_statuses = [
                WorkflowRunStatus.created,
                WorkflowRunStatus.queued,
                WorkflowRunStatus.running,
                WorkflowRunStatus.paused,
            ]
            stale_cutoff = datetime.utcnow() - timedelta(hours=stale_threshold_hours)

            # Count active workflow runs (recently updated)
            active_query = (
                select(func.count())
                .select_from(WorkflowRunModel)
                .filter(WorkflowRunModel.status.in_(running_statuses))
                .filter(WorkflowRunModel.modified_at >= stale_cutoff)
            )
            active_count = (await session.execute(active_query)).scalar_one()

            # Count stale workflow runs (not updated for a long time)
            stale_query = (
                select(func.count())
                .select_from(WorkflowRunModel)
                .filter(WorkflowRunModel.status.in_(running_statuses))
                .filter(WorkflowRunModel.modified_at < stale_cutoff)
            )
            stale_count = (await session.execute(stale_query)).scalar_one()

            return (active_count, stale_count)
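A sketch of how a cleanup service might act on the two counts, consistent with the docstring above; `db` stands in for the mixin instance and the policy itself is illustrative:

active, stale = await db.get_running_workflow_runs_info_globally(stale_threshold_hours=24)
if active:
    ...  # genuinely active runs exist: skip cleanup this cycle
elif stale:
    ...  # only abandoned runs remain: safe to reap them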
@db_operation("create_workflow_run")
|
||||
async def create_workflow_run(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
workflow_id: str,
|
||||
organization_id: str,
|
||||
browser_session_id: str | None = None,
|
||||
browser_profile_id: str | None = None,
|
||||
proxy_location: ProxyLocationInput = None,
|
||||
webhook_callback_url: str | None = None,
|
||||
totp_verification_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
parent_workflow_run_id: str | None = None,
|
||||
max_screenshot_scrolling_times: int | None = None,
|
||||
extra_http_headers: dict[str, str] | None = None,
|
||||
browser_address: str | None = None,
|
||||
sequential_key: str | None = None,
|
||||
run_with: str | None = None,
|
||||
debug_session_id: str | None = None,
|
||||
ai_fallback: bool | None = None,
|
||||
code_gen: bool | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
trigger_type: WorkflowRunTriggerType | None = None,
|
||||
workflow_schedule_id: str | None = None,
|
||||
) -> WorkflowRun:
|
||||
async with self.Session() as session:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if workflow_run_id is not None:
|
||||
kwargs["workflow_run_id"] = workflow_run_id
|
||||
workflow_run = WorkflowRunModel(
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
workflow_id=workflow_id,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
browser_profile_id=browser_profile_id,
|
||||
proxy_location=serialize_proxy_location(proxy_location),
|
||||
status="created",
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
parent_workflow_run_id=parent_workflow_run_id,
|
||||
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
||||
extra_http_headers=extra_http_headers,
|
||||
browser_address=browser_address,
|
||||
sequential_key=sequential_key,
|
||||
run_with=run_with,
|
||||
debug_session_id=debug_session_id,
|
||||
ai_fallback=ai_fallback,
|
||||
code_gen=code_gen,
|
||||
trigger_type=trigger_type.value if trigger_type else None,
|
||||
workflow_schedule_id=workflow_schedule_id,
|
||||
**kwargs,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
|
||||
@db_operation("update_workflow_run")
|
||||
async def update_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
status: WorkflowRunStatus | None = None,
|
||||
failure_reason: str | None = None,
|
||||
webhook_failure_reason: str | None = None,
|
||||
ai_fallback_triggered: bool | None = None,
|
||||
job_id: str | None = None,
|
||||
run_with: str | None = None,
|
||||
sequential_key: str | None = None,
|
||||
ai_fallback: bool | None = None,
|
||||
depends_on_workflow_run_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
waiting_for_verification_code: bool | None = None,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
browser_profile_id: str | None | object = _UNSET,
|
||||
browser_address: str | None = None,
|
||||
extra_http_headers: dict[str, str] | None = None,
|
||||
failure_category: list[dict[str, Any]] | None = None,
|
||||
) -> WorkflowRun:
|
||||
async with self.Session() as session:
|
||||
workflow_run = (
|
||||
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
|
||||
).first()
|
||||
if workflow_run:
|
||||
if status:
|
||||
workflow_run.status = status
|
||||
if status and status == WorkflowRunStatus.queued and workflow_run.queued_at is None:
|
||||
workflow_run.queued_at = datetime.utcnow()
|
||||
if status and status == WorkflowRunStatus.running and workflow_run.started_at is None:
|
||||
workflow_run.started_at = datetime.utcnow()
|
||||
if status and status.is_final() and workflow_run.finished_at is None:
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
if failure_reason:
|
||||
workflow_run.failure_reason = failure_reason
|
||||
if webhook_failure_reason is not None:
|
||||
workflow_run.webhook_failure_reason = webhook_failure_reason
|
||||
if ai_fallback_triggered is not None:
|
||||
workflow_run.script_run = {"ai_fallback_triggered": ai_fallback_triggered}
|
||||
if job_id:
|
||||
workflow_run.job_id = job_id
|
||||
if run_with:
|
||||
workflow_run.run_with = run_with
|
||||
if sequential_key:
|
||||
workflow_run.sequential_key = sequential_key
|
||||
if ai_fallback is not None:
|
||||
workflow_run.ai_fallback = ai_fallback
|
||||
if depends_on_workflow_run_id:
|
||||
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
|
||||
if browser_session_id:
|
||||
workflow_run.browser_session_id = browser_session_id
|
||||
if browser_address:
|
||||
workflow_run.browser_address = browser_address
|
||||
if extra_http_headers:
|
||||
workflow_run.extra_http_headers = extra_http_headers
|
||||
# 2FA verification code waiting state updates
|
||||
if waiting_for_verification_code is not None:
|
||||
workflow_run.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
workflow_run.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if waiting_for_verification_code is not None and not waiting_for_verification_code:
|
||||
# Clear related fields when waiting is set to False
|
||||
workflow_run.verification_code_identifier = None
|
||||
workflow_run.verification_code_polling_started_at = None
|
||||
if browser_profile_id is not _UNSET:
|
||||
workflow_run.browser_profile_id = browser_profile_id
|
||||
if failure_category is not None:
|
||||
workflow_run.failure_category = failure_category
|
||||
await session.commit()
|
||||
await save_workflow_run_logs(workflow_run_id)
|
||||
await session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
else:
|
||||
raise WorkflowRunNotFound(workflow_run_id)
|
||||
|
||||
@db_operation("bulk_update_workflow_runs")
|
||||
async def bulk_update_workflow_runs(
|
||||
self,
|
||||
workflow_run_ids: list[str],
|
||||
status: WorkflowRunStatus | None = None,
|
||||
failure_reason: str | None = None,
|
||||
) -> None:
|
||||
"""Bulk update workflow runs by their IDs.
|
||||
|
||||
Args:
|
||||
workflow_run_ids: List of workflow run IDs to update
|
||||
status: Optional status to set for all workflow runs
|
||||
failure_reason: Optional failure reason to set for all workflow runs
|
||||
"""
|
||||
if not workflow_run_ids:
|
||||
return
|
||||
|
||||
async with self.Session() as session:
|
||||
update_values = {}
|
||||
if status:
|
||||
update_values["status"] = status.value
|
||||
if failure_reason:
|
||||
update_values["failure_reason"] = failure_reason
|
||||
|
||||
if update_values:
|
||||
update_stmt = (
|
||||
update(WorkflowRunModel)
|
||||
.where(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
|
||||
.values(**update_values)
|
||||
)
|
||||
await session.execute(update_stmt)
|
||||
await session.commit()
|
||||
|
||||
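Because the method compiles to a single UPDATE ... WHERE workflow_run_id IN (...), batch failure handling avoids one round-trip per run. Illustrative call; the IDs and the chosen status value are hypothetical:

await db.bulk_update_workflow_runs(
    workflow_run_ids=["wr_1", "wr_2", "wr_3"],  # hypothetical IDs
    status=WorkflowRunStatus.failed,            # assumed enum member, for illustration only
    failure_reason="Worker shut down mid-batch",
)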
@db_operation("clear_workflow_run_failure_reason")
|
||||
async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun:
|
||||
async with self.Session() as session:
|
||||
workflow_run = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first()
|
||||
if workflow_run:
|
||||
workflow_run.failure_reason = None
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
else:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
@db_operation("get_all_runs")
|
||||
async def get_all_runs(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
status: list[WorkflowRunStatus] | None = None,
|
||||
include_debugger_runs: bool = False,
|
||||
search_key: str | None = None,
|
||||
) -> list[WorkflowRun | Task]:
|
||||
async with self.Session() as session:
|
||||
# temporary limit to 10 pages
|
||||
if page > 10:
|
||||
return []
|
||||
|
||||
limit = page * page_size
|
||||
|
||||
workflow_run_query = (
|
||||
select(WorkflowRunModel, WorkflowModel.title)
|
||||
.join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
|
||||
.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
.filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
|
||||
)
|
||||
|
||||
if not include_debugger_runs:
|
||||
workflow_run_query = workflow_run_query.filter(WorkflowRunModel.debug_session_id.is_(None))
|
||||
|
||||
if search_key:
|
||||
key_like = f"%{search_key}%"
|
||||
# Match workflow_run_id directly
|
||||
id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like)
|
||||
# Match parameter key or description (only for non-deleted parameter definitions)
|
||||
param_key_desc_exists = exists(
|
||||
select(1)
|
||||
.select_from(WorkflowRunParameterModel)
|
||||
.join(
|
||||
WorkflowParameterModel,
|
||||
WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id,
|
||||
)
|
||||
.where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
|
||||
.where(WorkflowParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
WorkflowParameterModel.key.ilike(key_like),
|
||||
WorkflowParameterModel.description.ilike(key_like),
|
||||
)
|
||||
)
|
||||
)
|
||||
# Match run parameter value directly (searches all values regardless of parameter definition status)
|
||||
param_value_exists = exists(
|
||||
select(1)
|
||||
.select_from(WorkflowRunParameterModel)
|
||||
.where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
|
||||
.where(WorkflowRunParameterModel.value.ilike(key_like))
|
||||
)
|
||||
# Match extra HTTP headers (cast JSON to text for search, skip NULLs)
|
||||
extra_headers_match = and_(
|
||||
WorkflowRunModel.extra_http_headers.isnot(None),
|
||||
func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like),
|
||||
)
|
||||
workflow_run_query = workflow_run_query.where(
|
||||
or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match)
|
||||
)
|
||||
|
||||
if status:
|
||||
workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status))
|
||||
workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit)
|
||||
workflow_run_query_result = (await session.execute(workflow_run_query)).all()
|
||||
workflow_runs = [
|
||||
convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
|
||||
for run, title in workflow_run_query_result
|
||||
]
|
||||
|
||||
task_query = (
|
||||
select(TaskModel)
|
||||
.filter(TaskModel.organization_id == organization_id)
|
||||
.filter(TaskModel.workflow_run_id.is_(None))
|
||||
)
|
||||
if status:
|
||||
task_query = task_query.filter(TaskModel.status.in_(status))
|
||||
task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
|
||||
task_query_result = (await session.scalars(task_query)).all()
|
||||
tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]
|
||||
|
||||
runs = workflow_runs + tasks
|
||||
|
||||
runs.sort(key=lambda x: x.created_at, reverse=True)
|
||||
|
||||
lower = (page - 1) * page_size
|
||||
upper = page * page_size
|
||||
|
||||
return runs[lower:upper]
|
||||
|
||||
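Worth noting about the pagination above: both queries over-fetch page * page_size rows so that, after the workflow runs and tasks are merged and re-sorted by created_at, a plain slice still yields a correct page. The cost grows with the page number, which is why pages are capped at 10. The slice math for page 3 with page_size 10:

page, page_size = 3, 10
limit = page * page_size        # each query fetches up to 30 rows
lower = (page - 1) * page_size  # 20
upper = page * page_size        # 30: merged list is sliced to rows 20..29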
    @read_retry()
    async def get_all_runs_v2(
        self,
        organization_id: str,
        page: int = 1,
        page_size: int = 10,
        status: list[str] | None = None,
        search_key: str | None = None,
    ) -> list[dict[str, Any]]:
        async with self.Session() as session:
            effective_status = func.coalesce(WorkflowRunModel.status, TaskRunModel.status)
            query = (
                select(
                    TaskRunModel.task_run_id.label("task_run_id"),
                    TaskRunModel.run_id.label("run_id"),
                    TaskRunModel.task_run_type.label("task_run_type"),
                    effective_status.label("status"),
                    TaskRunModel.title.label("title"),
                    TaskRunModel.started_at.label("started_at"),
                    TaskRunModel.finished_at.label("finished_at"),
                    TaskRunModel.created_at.label("created_at"),
                    TaskRunModel.workflow_permanent_id.label("workflow_permanent_id"),
                    TaskRunModel.script_run.label("script_run"),
                    TaskRunModel.searchable_text.label("searchable_text"),
                )
                .select_from(TaskRunModel)
                .outerjoin(
                    WorkflowRunModel,
                    and_(
                        TaskRunModel.task_run_type == RunType.workflow_run,
                        WorkflowRunModel.workflow_run_id == TaskRunModel.run_id,
                        WorkflowRunModel.organization_id == TaskRunModel.organization_id,
                    ),
                )
                .filter(TaskRunModel.organization_id == organization_id)
                .filter(TaskRunModel.status.isnot(None))
                .filter(TaskRunModel.parent_workflow_run_id.is_(None))
                .filter(TaskRunModel.debug_session_id.is_(None))
            )

            if status:
                query = query.filter(effective_status.in_(status))

            if search_key:
                query = query.filter(TaskRunModel.searchable_text.icontains(search_key, autoescape=True))

            offset = (page - 1) * page_size
            query = query.order_by(TaskRunModel.created_at.desc()).offset(offset).limit(page_size)

            result = await session.execute(query)
            return [dict(row) for row in result.mappings().all()]

    @read_retry()
    @db_operation("get_workflow_run", log_errors=False)
    async def get_workflow_run(
        self,
        workflow_run_id: str,
        organization_id: str | None = None,
        job_id: str | None = None,
        status: WorkflowRunStatus | None = None,
    ) -> WorkflowRun | None:
        async with self.Session() as session:
            get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
            if organization_id:
                get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
            if job_id:
                get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
            if status:
                get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
            if workflow_run := (await session.scalars(get_workflow_run_query)).first():
                return convert_to_workflow_run(workflow_run)
            return None

    @db_operation("get_last_queued_workflow_run")
    async def get_last_queued_workflow_run(
        self,
        workflow_permanent_id: str,
        organization_id: str | None = None,
        sequential_key: str | None = None,
    ) -> WorkflowRun | None:
        async with self.Session() as session:
            query = select(WorkflowRunModel).filter_by(workflow_permanent_id=workflow_permanent_id)
            query = query.filter(WorkflowRunModel.browser_session_id.is_(None))
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            query = query.filter_by(status=WorkflowRunStatus.queued)
            if sequential_key:
                query = query.filter_by(sequential_key=sequential_key)
            query = query.order_by(WorkflowRunModel.modified_at.desc())
            workflow_run = (await session.scalars(query)).first()
            return convert_to_workflow_run(workflow_run) if workflow_run else None

    @db_operation("get_workflow_runs_by_ids")
    async def get_workflow_runs_by_ids(
        self,
        workflow_run_ids: list[str],
        workflow_permanent_id: str | None = None,
        organization_id: str | None = None,
    ) -> list[WorkflowRun]:
        async with self.Session() as session:
            query = select(WorkflowRunModel).filter(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
            if workflow_permanent_id:
                query = query.filter_by(workflow_permanent_id=workflow_permanent_id)
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            workflow_runs = (await session.scalars(query)).all()
            return [convert_to_workflow_run(workflow_run) for workflow_run in workflow_runs]

    @db_operation("get_last_running_workflow_run")
    async def get_last_running_workflow_run(
        self,
        workflow_permanent_id: str,
        organization_id: str | None = None,
        sequential_key: str | None = None,
    ) -> WorkflowRun | None:
        async with self.Session() as session:
            query = select(WorkflowRunModel).filter_by(workflow_permanent_id=workflow_permanent_id)
            query = query.filter(WorkflowRunModel.browser_session_id.is_(None))
            if organization_id:
                query = query.filter_by(organization_id=organization_id)
            query = query.filter_by(status=WorkflowRunStatus.running)
            if sequential_key:
                query = query.filter_by(sequential_key=sequential_key)
            query = query.filter(
                WorkflowRunModel.started_at.isnot(None)
            )  # filter out workflow runs that do not have a started_at timestamp
            query = query.order_by(WorkflowRunModel.started_at.desc())
            workflow_run = (await session.scalars(query)).first()
            return convert_to_workflow_run(workflow_run) if workflow_run else None

    @db_operation("get_workflows_depending_on")
    async def get_workflows_depending_on(
        self,
        workflow_run_id: str,
    ) -> list[WorkflowRun]:
        """
        Get all workflow runs that depend on the given workflow_run_id.

        Used to find workflows that should be signaled when a workflow completes,
        for sequential workflow dependency handling.

        Args:
            workflow_run_id: The workflow_run_id to find dependents for

        Returns:
            List of WorkflowRun objects that have depends_on_workflow_run_id set to workflow_run_id
        """
        async with self.Session() as session:
            query = select(WorkflowRunModel).filter_by(depends_on_workflow_run_id=workflow_run_id)
            workflow_runs = (await session.scalars(query)).all()
            return [convert_to_workflow_run(workflow_run) for workflow_run in workflow_runs]

    @staticmethod
    def _apply_search_key_filter(query, search_key: str | None):  # type: ignore[no-untyped-def]
        if not search_key:
            return query
        key_like = f"%{search_key}%"
        # Match workflow_run_id directly
        id_matches = WorkflowRunModel.workflow_run_id.ilike(key_like)
        # Match parameter key or description (only for non-deleted parameter definitions)
        # Use EXISTS to avoid duplicate rows and to keep pagination correct
        param_key_desc_exists = exists(
            select(1)
            .select_from(WorkflowRunParameterModel)
            .join(
                WorkflowParameterModel,
                WorkflowParameterModel.workflow_parameter_id == WorkflowRunParameterModel.workflow_parameter_id,
            )
            .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
            .where(WorkflowParameterModel.deleted_at.is_(None))
            .where(
                or_(
                    WorkflowParameterModel.key.ilike(key_like),
                    WorkflowParameterModel.description.ilike(key_like),
                )
            )
        )
        # Match run parameter value directly (searches all values regardless of parameter definition status)
        param_value_exists = exists(
            select(1)
            .select_from(WorkflowRunParameterModel)
            .where(WorkflowRunParameterModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
            .where(WorkflowRunParameterModel.value.ilike(key_like))
        )
        # Match extra HTTP headers (cast JSON to text for search, skip NULLs)
        extra_headers_match = and_(
            WorkflowRunModel.extra_http_headers.isnot(None),
            func.cast(WorkflowRunModel.extra_http_headers, Text()).ilike(key_like),
        )
        return query.where(or_(id_matches, param_key_desc_exists, param_value_exists, extra_headers_match))
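The helper deliberately wraps the parameter matches in EXISTS subqueries: a run matching several parameters still appears once in the result, which keeps LIMIT/OFFSET pagination stable. Hedged usage sketch; the organization ID is hypothetical:

query = select(WorkflowRunModel).filter_by(organization_id="org_123")
query = WorkflowRunsMixin._apply_search_key_filter(query, "invoice")
# Now matches on run ID, parameter key/description/value, or extra HTTP headers.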
    def _apply_error_code_filter(self, query, error_code: str | None):  # type: ignore[no-untyped-def]
        if not error_code:
            return query

        dialect_name = self.engine.dialect.name

        if dialect_name == "sqlite":
            # Task errors: array of objects like [{"error_code": "timeout", ...}]
            # Use json_each to iterate + json_extract to match the error_code field
            error_code_in_tasks = exists(
                select(1)
                .select_from(TaskModel)
                .where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .where(
                    exists(
                        select(1)
                        .select_from(func.json_each(TaskModel.errors))
                        .where(func.json_extract(literal_column("json_each.value"), "$.error_code") == error_code)
                    )
                )
            )
            # Block errors: flat array of strings like ["timeout", "network_error"]
            error_code_in_blocks = exists(
                select(1)
                .select_from(WorkflowRunBlockModel)
                .where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .where(
                    exists(
                        select(1)
                        .select_from(func.json_each(WorkflowRunBlockModel.error_codes))
                        .where(literal_column("json_each.value") == error_code)
                    )
                )
            )
        else:
            # PostgreSQL: native JSONB containment
            error_code_in_tasks = exists(
                select(1)
                .select_from(TaskModel)
                .where(TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .where(cast(TaskModel.errors, JSONB).contains(literal([{"error_code": error_code}], type_=JSONB)))
            )
            error_code_in_blocks = exists(
                select(1)
                .select_from(WorkflowRunBlockModel)
                .where(WorkflowRunBlockModel.workflow_run_id == WorkflowRunModel.workflow_run_id)
                .where(cast(WorkflowRunBlockModel.error_codes, JSONB).contains(literal([error_code], type_=JSONB)))
            )
        return query.where(or_(error_code_in_tasks, error_code_in_blocks))
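On PostgreSQL the filter relies on JSONB containment, which SQLAlchemy compiles to the @> operator; the SQLite branch emulates the same predicate with json_each. A standalone sketch of the containment check, mirroring the call shape used above:

from sqlalchemy import cast, literal
from sqlalchemy.dialects.postgresql import JSONB

# True when the JSON errors array contains an object with the given error_code,
# i.e. SQL: errors @> '[{"error_code": "timeout"}]'
predicate = cast(TaskModel.errors, JSONB).contains(
    literal([{"error_code": "timeout"}], type_=JSONB)
)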
@db_operation("get_workflow_runs")
|
||||
async def get_workflow_runs(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
status: list[WorkflowRunStatus] | None = None,
|
||||
ordering: tuple[str, str] | None = None,
|
||||
search_key: str | None = None,
|
||||
error_code: str | None = None,
|
||||
) -> list[WorkflowRun]:
|
||||
async with self.Session() as session:
|
||||
db_page = page - 1 # offset logic is 0 based
|
||||
|
||||
query = (
|
||||
select(WorkflowRunModel, WorkflowModel.title)
|
||||
.join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
|
||||
.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
.filter(WorkflowRunModel.parent_workflow_run_id.is_(None))
|
||||
)
|
||||
|
||||
query = self._apply_search_key_filter(query, search_key)
|
||||
query = self._apply_error_code_filter(query, error_code)
|
||||
|
||||
if status:
|
||||
query = query.filter(WorkflowRunModel.status.in_(status))
|
||||
|
||||
allowed_ordering_fields = {
|
||||
"created_at": WorkflowRunModel.created_at,
|
||||
"status": WorkflowRunModel.status,
|
||||
}
|
||||
|
||||
field, direction = ("created_at", "desc")
|
||||
|
||||
if ordering and isinstance(ordering, tuple) and len(ordering) == 2:
|
||||
req_field, req_direction = ordering
|
||||
if req_field in allowed_ordering_fields and req_direction in ("asc", "desc"):
|
||||
field, direction = req_field, req_direction
|
||||
|
||||
order_column = allowed_ordering_fields[field]
|
||||
|
||||
if direction == "asc":
|
||||
query = query.order_by(order_column.asc())
|
||||
else:
|
||||
query = query.order_by(order_column.desc())
|
||||
|
||||
query = query.limit(page_size).offset(db_page * page_size)
|
||||
|
||||
workflow_runs = (await session.execute(query)).all()
|
||||
|
||||
return [
|
||||
convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
|
||||
for run, title in workflow_runs
|
||||
]
|
||||
|
||||
@db_operation("get_workflow_runs_count")
|
||||
async def get_workflow_runs_count(
|
||||
self,
|
||||
organization_id: str,
|
||||
status: list[WorkflowRunStatus] | None = None,
|
||||
) -> int:
|
||||
async with self.Session() as session:
|
||||
count_query = (
|
||||
select(func.count())
|
||||
.select_from(WorkflowRunModel)
|
||||
.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
)
|
||||
if status:
|
||||
count_query = count_query.filter(WorkflowRunModel.status.in_(status))
|
||||
return (await session.execute(count_query)).scalar_one()
|
||||
|
||||
@db_operation("get_workflow_run_block_errors")
|
||||
async def get_workflow_run_block_errors(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[tuple[list[str], str | None]]:
|
||||
"""Return (error_codes, failure_reason) tuples for blocks with non-null error_codes."""
|
||||
async with self.Session() as session:
|
||||
query = select(WorkflowRunBlockModel.error_codes, WorkflowRunBlockModel.failure_reason).filter_by(
|
||||
workflow_run_id=workflow_run_id
|
||||
)
|
||||
if organization_id is not None:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
query = query.where(WorkflowRunBlockModel.error_codes.isnot(None))
|
||||
rows = (await session.execute(query)).all()
|
||||
return [(row.error_codes, row.failure_reason) for row in rows]
|
||||
|
||||
@db_operation("get_workflow_runs_for_workflow_permanent_id")
|
||||
async def get_workflow_runs_for_workflow_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
status: list[WorkflowRunStatus] | None = None,
|
||||
search_key: str | None = None,
|
||||
error_code: str | None = None,
|
||||
) -> list[WorkflowRun]:
|
||||
"""
|
||||
Get runs for a workflow, with optional `search_key` on run ID, parameter key/description/value,
|
||||
or extra HTTP headers.
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
db_page = page - 1 # offset logic is 0 based
|
||||
query = (
|
||||
select(WorkflowRunModel, WorkflowModel.title)
|
||||
.join(WorkflowModel, WorkflowModel.workflow_id == WorkflowRunModel.workflow_id)
|
||||
.filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
)
|
||||
query = self._apply_search_key_filter(query, search_key)
|
||||
query = self._apply_error_code_filter(query, error_code)
|
||||
if status:
|
||||
query = query.filter(WorkflowRunModel.status.in_(status))
|
||||
query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
|
||||
workflow_runs_and_titles_tuples = (await session.execute(query)).all()
|
||||
workflow_runs = [
|
||||
convert_to_workflow_run(run, workflow_title=title, debug_enabled=self.debug_enabled)
|
||||
for run, title in workflow_runs_and_titles_tuples
|
||||
]
|
||||
return workflow_runs
|
||||
|
||||
@db_operation("get_workflow_runs_by_parent_workflow_run_id")
|
||||
async def get_workflow_runs_by_parent_workflow_run_id(
|
||||
self,
|
||||
parent_workflow_run_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[WorkflowRun]:
|
||||
async with self.Session() as session:
|
||||
query = select(WorkflowRunModel).filter(WorkflowRunModel.parent_workflow_run_id == parent_workflow_run_id)
|
||||
if organization_id is not None:
|
||||
query = query.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
workflow_runs = (await session.scalars(query)).all()
|
||||
return [convert_to_workflow_run(run) for run in workflow_runs]
|
||||
|
||||
@db_operation("get_workflow_run_output_parameters")
|
||||
async def get_workflow_run_output_parameters(self, workflow_run_id: str) -> list[WorkflowRunOutputParameter]:
|
||||
async with self.Session() as session:
|
||||
workflow_run_output_parameters = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.order_by(WorkflowRunOutputParameterModel.created_at)
|
||||
)
|
||||
).all()
|
||||
return [
|
||||
convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
|
||||
for parameter in workflow_run_output_parameters
|
||||
]
|
||||
|
||||
@db_operation("get_workflow_run_output_parameter_by_id")
|
||||
async def get_workflow_run_output_parameter_by_id(
|
||||
self, workflow_run_id: str, output_parameter_id: str
|
||||
) -> WorkflowRunOutputParameter | None:
|
||||
async with self.Session() as session:
|
||||
parameter = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.filter_by(output_parameter_id=output_parameter_id)
|
||||
.order_by(WorkflowRunOutputParameterModel.created_at)
|
||||
)
|
||||
).first()
|
||||
|
||||
if parameter:
|
||||
return convert_to_workflow_run_output_parameter(parameter, self.debug_enabled)
|
||||
|
||||
return None
|
||||
|
||||
@db_operation("create_or_update_workflow_run_output_parameter")
|
||||
async def create_or_update_workflow_run_output_parameter(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
output_parameter_id: str,
|
||||
value: dict[str, Any] | list | str | None,
|
||||
) -> WorkflowRunOutputParameter:
|
||||
async with self.Session() as session:
|
||||
# check if the workflow run output parameter already exists
|
||||
# if it does, update the value
|
||||
if workflow_run_output_parameter := (
|
||||
await session.scalars(
|
||||
select(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.filter_by(output_parameter_id=output_parameter_id)
|
||||
)
|
||||
).first():
|
||||
LOG.info(
|
||||
"Updating existing workflow run output parameter",
|
||||
workflow_run_id=workflow_run_output_parameter.workflow_run_id,
|
||||
output_parameter_id=workflow_run_output_parameter.output_parameter_id,
|
||||
)
|
||||
workflow_run_output_parameter.value = value
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_output_parameter)
|
||||
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
|
||||
|
||||
# if it does not exist, create a new one
|
||||
workflow_run_output_parameter = WorkflowRunOutputParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=output_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
session.add(workflow_run_output_parameter)
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_output_parameter)
|
||||
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
|
||||
|
||||
@db_operation("update_workflow_run_output_parameter")
|
||||
async def update_workflow_run_output_parameter(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
output_parameter_id: str,
|
||||
value: dict[str, Any] | list | str | None,
|
||||
) -> WorkflowRunOutputParameter:
|
||||
async with self.Session() as session:
|
||||
workflow_run_output_parameter = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunOutputParameterModel)
|
||||
.filter_by(workflow_run_id=workflow_run_id)
|
||||
.filter_by(output_parameter_id=output_parameter_id)
|
||||
)
|
||||
).first()
|
||||
if not workflow_run_output_parameter:
|
||||
raise NotFoundError(
|
||||
f"WorkflowRunOutputParameter not found for {workflow_run_id} and {output_parameter_id}"
|
||||
)
|
||||
workflow_run_output_parameter.value = value
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run_output_parameter)
|
||||
return convert_to_workflow_run_output_parameter(workflow_run_output_parameter, self.debug_enabled)
|
||||
|
||||
@db_operation("create_workflow_run_parameter")
|
||||
async def create_workflow_run_parameter(
|
||||
self, workflow_run_id: str, workflow_parameter: WorkflowParameter, value: Any
|
||||
) -> WorkflowRunParameter:
|
||||
workflow_parameter_id = workflow_parameter.workflow_parameter_id
|
||||
async with self.Session() as session:
|
||||
workflow_run_parameter = WorkflowRunParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
session.add(workflow_run_parameter)
|
||||
await session.flush()
|
||||
converted = convert_to_workflow_run_parameter(
|
||||
workflow_run_parameter, workflow_parameter, self.debug_enabled
|
||||
)
|
||||
await session.commit()
|
||||
return converted
|
||||
|
||||
@db_operation("create_workflow_run_parameters")
|
||||
async def create_workflow_run_parameters(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
workflow_parameter_values: list[tuple[WorkflowParameter, Any]],
|
||||
) -> list[WorkflowRunParameter]:
|
||||
if not workflow_parameter_values:
|
||||
return []
|
||||
|
||||
workflow_run_parameters = [
|
||||
WorkflowRunParameterModel(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_parameter_id=workflow_parameter.workflow_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
for workflow_parameter, value in workflow_parameter_values
|
||||
]
|
||||
|
||||
async with self.Session() as session:
|
||||
session.add_all(workflow_run_parameters)
|
||||
await session.flush()
|
||||
converted = [
|
||||
convert_to_workflow_run_parameter(workflow_run_parameter, workflow_parameter, self.debug_enabled)
|
||||
for workflow_run_parameter, (workflow_parameter, _) in zip(
|
||||
workflow_run_parameters, workflow_parameter_values, strict=True
|
||||
)
|
||||
]
|
||||
await session.commit()
|
||||
return converted
|
||||
|
||||
@db_operation("get_workflow_run_parameters")
|
||||
async def get_workflow_run_parameters(
|
||||
self, workflow_run_id: str
|
||||
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
|
||||
async with self.Session() as session:
|
||||
workflow_run_parameters = (
|
||||
await session.scalars(select(WorkflowRunParameterModel).filter_by(workflow_run_id=workflow_run_id))
|
||||
).all()
|
||||
results = []
|
||||
for workflow_run_parameter in workflow_run_parameters:
|
||||
workflow_parameter = await self.get_workflow_parameter(workflow_run_parameter.workflow_parameter_id) # type: ignore[attr-defined]
|
||||
if not workflow_parameter:
|
||||
raise WorkflowParameterNotFound(workflow_parameter_id=workflow_run_parameter.workflow_parameter_id)
|
||||
results.append(
|
||||
(
|
||||
workflow_parameter,
|
||||
convert_to_workflow_run_parameter(
|
||||
workflow_run_parameter,
|
||||
workflow_parameter,
|
||||
self.debug_enabled,
|
||||
),
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
@db_operation("_get_last_workflow_run_by_filter")
|
||||
async def _get_last_workflow_run_by_filter(
|
||||
self,
|
||||
organization_id: str | None = None,
|
||||
**filters: str,
|
||||
) -> WorkflowRun | None:
|
||||
"""Get the last queued or running workflow run matching the given column filters.
|
||||
|
||||
Used for browser_session_id and browser_address sequential execution.
|
||||
"""
|
||||
async with self.Session() as session:
|
||||
query = select(WorkflowRunModel).filter_by(**filters)
|
||||
if organization_id:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
|
||||
# check if there's a queued run
|
||||
queue_query = query.filter_by(status=WorkflowRunStatus.queued)
|
||||
queue_query = queue_query.order_by(WorkflowRunModel.modified_at.desc())
|
||||
workflow_run = (await session.scalars(queue_query)).first()
|
||||
if workflow_run:
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
|
||||
# check if there's a running run
|
||||
running_query = query.filter_by(status=WorkflowRunStatus.running)
|
||||
running_query = running_query.filter(WorkflowRunModel.started_at.isnot(None))
|
||||
running_query = running_query.order_by(WorkflowRunModel.started_at.desc())
|
||||
workflow_run = (await session.scalars(running_query)).first()
|
||||
if workflow_run:
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
return None
|
||||
|
||||
@db_operation("get_last_workflow_run_for_browser_address")
|
||||
async def get_last_workflow_run_for_browser_address(
|
||||
self,
|
||||
browser_address: str,
|
||||
organization_id: str | None = None,
|
||||
) -> WorkflowRun | None:
|
||||
return await self._get_last_workflow_run_by_filter(
|
||||
organization_id=organization_id,
|
||||
browser_address=browser_address,
|
||||
)
|
||||
|
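Paired with `get_last_queued_workflow_run`, this helper supports sequential execution on a shared browser: queued runs win over running ones, so a dispatcher can tell whether an address is already claimed. Hedged sketch, where `db` and the browser address are illustrative:

blocking_run = await db.get_last_workflow_run_for_browser_address(
    browser_address="ws://10.0.0.5:9222",  # hypothetical browser address
    organization_id="org_123",
)
if blocking_run is not None:
    ...  # browser busy: queue this run behind blocking_run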
@@ -1,695 +0,0 @@
# TODO: Standardize soft-delete filtering — some methods use exclude_deleted(), others use
# inline .filter(WorkflowModel.deleted_at.is_(None)). Migrate all to exclude_deleted()
# where possible (some join/filter-order cases require inline).
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any

import structlog
from sqlalchemy import exists, func, or_, select, update

from skyvern.constants import DEFAULT_SCRIPT_RUN_ID
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db._soft_delete import exclude_deleted
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
    AWSSecretParameterModel,
    AzureVaultCredentialParameterModel,
    BitwardenCreditCardDataParameterModel,
    BitwardenLoginCredentialParameterModel,
    BitwardenSensitiveInformationParameterModel,
    CredentialParameterModel,
    FolderModel,
    OnePasswordCredentialParameterModel,
    OutputParameterModel,
    WorkflowModel,
    WorkflowParameterModel,
    WorkflowRunModel,
    WorkflowTemplateModel,
)
from skyvern.forge.sdk.db.utils import convert_to_workflow, serialize_proxy_location
from skyvern.forge.sdk.workflow.models.workflow import Workflow
from skyvern.schemas.runs import ProxyLocationInput
from skyvern.schemas.workflows import WorkflowStatus

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory

LOG = structlog.get_logger()


class WorkflowsMixin:
    """Database operations for workflow management."""

    Session: _SessionFactory
    debug_enabled: bool
@db_operation("create_workflow")
|
||||
async def create_workflow(
|
||||
self,
|
||||
title: str,
|
||||
workflow_definition: dict[str, Any],
|
||||
organization_id: str | None = None,
|
||||
description: str | None = None,
|
||||
proxy_location: ProxyLocationInput = None,
|
||||
webhook_callback_url: str | None = None,
|
||||
max_screenshot_scrolling_times: int | None = None,
|
||||
extra_http_headers: dict[str, str] | None = None,
|
||||
totp_verification_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
persist_browser_session: bool = False,
|
||||
model: dict[str, Any] | None = None,
|
||||
workflow_permanent_id: str | None = None,
|
||||
version: int | None = None,
|
||||
is_saved_task: bool = False,
|
||||
status: WorkflowStatus = WorkflowStatus.published,
|
||||
run_with: str | None = None,
|
||||
ai_fallback: bool = True,
|
||||
cache_key: str | None = None,
|
||||
adaptive_caching: bool = False,
|
||||
code_version: int | None = None,
|
||||
generate_script_on_terminal: bool = False,
|
||||
run_sequentially: bool = False,
|
||||
sequential_key: str | None = None,
|
||||
folder_id: str | None = None,
|
||||
) -> Workflow:
|
||||
async with self.Session() as session:
|
||||
workflow = WorkflowModel(
|
||||
organization_id=organization_id,
|
||||
title=title,
|
||||
description=description,
|
||||
workflow_definition=workflow_definition,
|
||||
proxy_location=serialize_proxy_location(proxy_location),
|
||||
webhook_callback_url=webhook_callback_url,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
max_screenshot_scrolling_times=max_screenshot_scrolling_times,
|
||||
extra_http_headers=extra_http_headers,
|
||||
persist_browser_session=persist_browser_session,
|
||||
model=model,
|
||||
is_saved_task=is_saved_task,
|
||||
status=status,
|
||||
run_with=run_with,
|
||||
ai_fallback=ai_fallback,
|
||||
cache_key=cache_key or DEFAULT_SCRIPT_RUN_ID,
|
||||
adaptive_caching=adaptive_caching,
|
||||
code_version=code_version,
|
||||
generate_script_on_terminal=generate_script_on_terminal,
|
||||
run_sequentially=run_sequentially,
|
||||
sequential_key=sequential_key,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
if workflow_permanent_id:
|
||||
workflow.workflow_permanent_id = workflow_permanent_id
|
||||
if version:
|
||||
workflow.version = version
|
||||
session.add(workflow)
|
||||
|
||||
# Update folder's modified_at if folder_id is provided
|
||||
if folder_id:
|
||||
# Validate folder exists and belongs to the same organization
|
||||
folder_stmt = (
|
||||
select(FolderModel)
|
||||
.where(FolderModel.folder_id == folder_id)
|
||||
.where(FolderModel.organization_id == organization_id)
|
||||
.where(FolderModel.deleted_at.is_(None))
|
||||
)
|
||||
folder_model = await session.scalar(folder_stmt)
|
||||
if not folder_model:
|
||||
raise ValueError(
|
||||
f"Folder {folder_id} not found or does not belong to organization {organization_id}"
|
||||
)
|
||||
folder_model.modified_at = datetime.utcnow()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
return convert_to_workflow(workflow, self.debug_enabled)
|
||||
|
||||
@db_operation("soft_delete_workflow_by_id")
|
||||
async def soft_delete_workflow_by_id(self, workflow_id: str, organization_id: str) -> None:
|
||||
async with self.Session() as session:
|
||||
# soft delete the workflow by setting the deleted_at field to the current time
|
||||
update_deleted_at_query = (
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.workflow_id == workflow_id)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.utcnow())
|
||||
)
|
||||
await session.execute(update_deleted_at_query)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("get_workflow")
|
||||
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
|
||||
async with self.Session() as session:
|
||||
get_workflow_query = exclude_deleted(
|
||||
select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel
|
||||
)
|
||||
if organization_id:
|
||||
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
|
||||
if workflow := (await session.scalars(get_workflow_query)).first():
|
||||
is_template = (
|
||||
await self.is_workflow_template(
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if organization_id
|
||||
else False
|
||||
)
|
||||
return convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=is_template,
|
||||
)
|
||||
return None
|
||||
|
||||
@db_operation("get_workflow_by_permanent_id")
|
||||
async def get_workflow_by_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str | None = None,
|
||||
version: int | None = None,
|
||||
ignore_version: int | None = None,
|
||||
filter_deleted: bool = True,
|
||||
) -> Workflow | None:
|
||||
get_workflow_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
|
||||
if filter_deleted:
|
||||
get_workflow_query = exclude_deleted(get_workflow_query, WorkflowModel)
|
||||
if organization_id:
|
||||
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
|
||||
if version:
|
||||
get_workflow_query = get_workflow_query.filter_by(version=version)
|
||||
if ignore_version:
|
||||
get_workflow_query = get_workflow_query.filter(WorkflowModel.version != ignore_version)
|
||||
get_workflow_query = get_workflow_query.order_by(WorkflowModel.version.desc())
|
||||
async with self.Session() as session:
|
||||
if workflow := (await session.scalars(get_workflow_query)).first():
|
||||
is_template = (
|
||||
await self.is_workflow_template(
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if organization_id
|
||||
else False
|
||||
)
|
||||
return convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=is_template,
|
||||
)
|
||||
return None
|
||||
|
||||
@db_operation("get_workflow_for_workflow_run")
|
||||
async def get_workflow_for_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
organization_id: str | None = None,
|
||||
filter_deleted: bool = True,
|
||||
) -> Workflow | None:
|
||||
get_workflow_query = select(WorkflowModel)
|
||||
|
||||
if filter_deleted:
|
||||
# Can't use exclude_deleted() here — it appends a WHERE clause, but the
|
||||
# deleted_at filter must precede the JOIN to avoid filtering on the wrong table.
|
||||
get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))
|
||||
|
||||
get_workflow_query = get_workflow_query.join(
|
||||
WorkflowRunModel,
|
||||
WorkflowRunModel.workflow_id == WorkflowModel.workflow_id,
|
||||
)
|
||||
|
||||
if organization_id:
|
||||
get_workflow_query = get_workflow_query.filter(WorkflowRunModel.organization_id == organization_id)
|
||||
|
||||
get_workflow_query = get_workflow_query.filter(WorkflowRunModel.workflow_run_id == workflow_run_id)
|
||||
async with self.Session() as session:
|
||||
if workflow := (await session.scalars(get_workflow_query)).first():
|
||||
is_template = (
|
||||
await self.is_workflow_template(
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if organization_id
|
||||
else False
|
||||
)
|
||||
return convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=is_template,
|
||||
)
|
||||
return None
|
||||
|
||||
@db_operation("get_workflow_versions_by_permanent_id")
|
||||
async def get_workflow_versions_by_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str | None = None,
|
||||
filter_deleted: bool = True,
|
||||
) -> list[Workflow]:
|
||||
"""
|
||||
Get all versions of a workflow by its permanent ID, ordered by version descending (newest first).
|
||||
"""
|
||||
get_workflows_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
|
||||
if filter_deleted:
|
||||
get_workflows_query = get_workflows_query.filter(WorkflowModel.deleted_at.is_(None))
|
||||
if organization_id:
|
||||
get_workflows_query = get_workflows_query.filter_by(organization_id=organization_id)
|
||||
get_workflows_query = get_workflows_query.order_by(WorkflowModel.version.desc())
|
||||
|
||||
async with self.Session() as session:
|
||||
workflows = (await session.scalars(get_workflows_query)).all()
|
||||
template_permanent_ids: set[str] = set()
|
||||
if workflows and organization_id:
|
||||
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
|
||||
|
||||
return [
|
||||
convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=workflow.workflow_permanent_id in template_permanent_ids,
|
||||
)
|
||||
for workflow in workflows
|
||||
]
|
||||
|
||||
@db_operation("get_workflows_by_permanent_ids")
|
||||
async def get_workflows_by_permanent_ids(
|
||||
self,
|
||||
workflow_permanent_ids: list[str],
|
||||
organization_id: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
title: str = "",
|
||||
statuses: list[WorkflowStatus] | None = None,
|
||||
) -> list[Workflow]:
|
||||
"""
|
||||
Get all workflows with the latest version for the organization.
|
||||
"""
|
||||
if page < 1:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
db_page = page - 1
|
||||
async with self.Session() as session:
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.workflow_permanent_id.in_(workflow_permanent_ids))
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
main_query = select(WorkflowModel).join(
|
||||
subquery,
|
||||
(WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
if organization_id:
|
||||
main_query = main_query.where(WorkflowModel.organization_id == organization_id)
|
||||
if title:
|
||||
main_query = main_query.where(WorkflowModel.title.ilike(f"%{title}%"))
|
||||
if statuses:
|
||||
main_query = main_query.where(WorkflowModel.status.in_(statuses))
|
||||
main_query = (
|
||||
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
|
||||
)
|
||||
workflows = (await session.scalars(main_query)).all()
|
||||
|
||||
# Map template status by permanent_id so API responses surface is_template
|
||||
template_permanent_ids: set[str] = set()
|
||||
if workflows and organization_id:
|
||||
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
|
||||
|
||||
return [
|
||||
convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=workflow.workflow_permanent_id in template_permanent_ids,
|
||||
)
|
||||
for workflow in workflows
|
||||
]
|
||||
|
||||
@db_operation("get_workflows_by_organization_id")
|
||||
async def get_workflows_by_organization_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
only_saved_tasks: bool = False,
|
||||
only_workflows: bool = False,
|
||||
only_templates: bool = False,
|
||||
search_key: str | None = None,
|
||||
folder_id: str | None = None,
|
||||
statuses: list[WorkflowStatus] | None = None,
|
||||
) -> list[Workflow]:
|
||||
"""
|
||||
Get all workflows with the latest version for the organization.
|
||||
|
||||
Search semantics:
|
||||
- If `search_key` is provided, its value is used as a unified search term for
|
||||
`workflows.title`, `folders.title`, and workflow parameter metadata (key, description, and default_value).
|
||||
- If `search_key` is not provided, no search filtering is applied.
|
||||
- Parameter metadata search excludes soft-deleted parameter rows across parameter tables.
|
||||
"""
|
||||
if page < 1:
|
||||
raise ValueError(f"Page must be greater than 0, got {page}")
|
||||
db_page = page - 1
|
||||
async with self.Session() as session:
|
||||
subquery = (
|
||||
select(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
func.max(WorkflowModel.version).label("max_version"),
|
||||
)
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
.group_by(
|
||||
WorkflowModel.organization_id,
|
||||
WorkflowModel.workflow_permanent_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
main_query = (
|
||||
select(WorkflowModel)
|
||||
.join(
|
||||
subquery,
|
||||
(WorkflowModel.organization_id == subquery.c.organization_id)
|
||||
& (WorkflowModel.workflow_permanent_id == subquery.c.workflow_permanent_id)
|
||||
& (WorkflowModel.version == subquery.c.max_version),
|
||||
)
|
||||
.outerjoin(
|
||||
FolderModel,
|
||||
(WorkflowModel.folder_id == FolderModel.folder_id)
|
||||
& (FolderModel.organization_id == WorkflowModel.organization_id),
|
||||
)
|
||||
)
|
||||
if only_saved_tasks:
|
||||
main_query = main_query.where(WorkflowModel.is_saved_task.is_(True))
|
||||
elif only_workflows:
|
||||
main_query = main_query.where(WorkflowModel.is_saved_task.is_(False))
|
||||
if only_templates:
|
||||
# Filter by workflow_templates table (templates at permanent_id level)
|
||||
template_subquery = select(WorkflowTemplateModel.workflow_permanent_id).where(
|
||||
WorkflowTemplateModel.organization_id == organization_id,
|
||||
WorkflowTemplateModel.deleted_at.is_(None),
|
||||
)
|
||||
main_query = main_query.where(WorkflowModel.workflow_permanent_id.in_(template_subquery))
|
||||
if statuses:
|
||||
main_query = main_query.where(WorkflowModel.status.in_(statuses))
|
||||
if folder_id:
|
||||
main_query = main_query.where(WorkflowModel.folder_id == folder_id)
|
||||
if search_key:
|
||||
search_like = f"%{search_key}%"
|
||||
title_like = WorkflowModel.title.ilike(search_like)
|
||||
folder_title_like = FolderModel.title.ilike(search_like)
|
||||
|
||||
parameter_filters = [
|
||||
# WorkflowParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(WorkflowParameterModel)
|
||||
.where(WorkflowParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(WorkflowParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
WorkflowParameterModel.key.ilike(search_like),
|
||||
WorkflowParameterModel.description.ilike(search_like),
|
||||
WorkflowParameterModel.default_value.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# OutputParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(OutputParameterModel)
|
||||
.where(OutputParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(OutputParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
OutputParameterModel.key.ilike(search_like),
|
||||
OutputParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# AWSSecretParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(AWSSecretParameterModel)
|
||||
.where(AWSSecretParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(AWSSecretParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
AWSSecretParameterModel.key.ilike(search_like),
|
||||
AWSSecretParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# BitwardenLoginCredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(BitwardenLoginCredentialParameterModel)
|
||||
.where(BitwardenLoginCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(BitwardenLoginCredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
BitwardenLoginCredentialParameterModel.key.ilike(search_like),
|
||||
BitwardenLoginCredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# BitwardenSensitiveInformationParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(BitwardenSensitiveInformationParameterModel)
|
||||
.where(BitwardenSensitiveInformationParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(BitwardenSensitiveInformationParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
BitwardenSensitiveInformationParameterModel.key.ilike(search_like),
|
||||
BitwardenSensitiveInformationParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# BitwardenCreditCardDataParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(BitwardenCreditCardDataParameterModel)
|
||||
.where(BitwardenCreditCardDataParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(BitwardenCreditCardDataParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
BitwardenCreditCardDataParameterModel.key.ilike(search_like),
|
||||
BitwardenCreditCardDataParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# OnePasswordCredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(OnePasswordCredentialParameterModel)
|
||||
.where(OnePasswordCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(OnePasswordCredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
OnePasswordCredentialParameterModel.key.ilike(search_like),
|
||||
OnePasswordCredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# AzureVaultCredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(AzureVaultCredentialParameterModel)
|
||||
.where(AzureVaultCredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(AzureVaultCredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
AzureVaultCredentialParameterModel.key.ilike(search_like),
|
||||
AzureVaultCredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
# CredentialParameterModel
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(CredentialParameterModel)
|
||||
.where(CredentialParameterModel.workflow_id == WorkflowModel.workflow_id)
|
||||
.where(CredentialParameterModel.deleted_at.is_(None))
|
||||
.where(
|
||||
or_(
|
||||
CredentialParameterModel.key.ilike(search_like),
|
||||
CredentialParameterModel.description.ilike(search_like),
|
||||
)
|
||||
)
|
||||
),
|
||||
]
|
||||
main_query = main_query.where(or_(title_like, folder_title_like, or_(*parameter_filters)))
|
||||
main_query = (
|
||||
main_query.order_by(WorkflowModel.created_at.desc()).limit(page_size).offset(db_page * page_size)
|
||||
)
|
||||
workflows = (await session.scalars(main_query)).all()
|
||||
template_permanent_ids: set[str] = set()
|
||||
if workflows and organization_id:
|
||||
template_permanent_ids = await self.get_org_template_permanent_ids(organization_id)
|
||||
|
||||
return [
|
||||
convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=workflow.workflow_permanent_id in template_permanent_ids,
|
||||
)
|
||||
for workflow in workflows
|
||||
]
|
||||
|
||||
@db_operation("update_workflow")
|
||||
async def update_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
organization_id: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
workflow_definition: dict[str, Any] | None = None,
|
||||
version: int | None = None,
|
||||
run_with: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
status: str | None = None,
|
||||
import_error: str | None = None,
|
||||
) -> Workflow:
|
||||
async with self.Session() as session:
|
||||
get_workflow_query = exclude_deleted(
|
||||
select(WorkflowModel).filter_by(workflow_id=workflow_id), WorkflowModel
|
||||
)
|
||||
if organization_id:
|
||||
get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
|
||||
if workflow := (await session.scalars(get_workflow_query)).first():
|
||||
if title is not None:
|
||||
workflow.title = title
|
||||
if description is not None:
|
||||
workflow.description = description
|
||||
if workflow_definition is not None:
|
||||
workflow.workflow_definition = workflow_definition
|
||||
if version is not None:
|
||||
workflow.version = version
|
||||
if run_with is not None:
|
||||
workflow.run_with = run_with
|
||||
if cache_key is not None:
|
||||
workflow.cache_key = cache_key
|
||||
if status is not None:
|
||||
workflow.status = status
|
||||
if import_error is not None:
|
||||
workflow.import_error = import_error
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
is_template = (
|
||||
await self.is_workflow_template(
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if organization_id
|
||||
else False
|
||||
)
|
||||
return convert_to_workflow(
|
||||
workflow,
|
||||
self.debug_enabled,
|
||||
is_template=is_template,
|
||||
)
|
||||
else:
|
||||
raise NotFoundError("Workflow not found")
|
||||
|
||||
@db_operation("soft_delete_workflow_by_permanent_id")
|
||||
async def soft_delete_workflow_by_permanent_id(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> None:
|
||||
async with self.Session() as session:
|
||||
# soft delete the workflow by setting the deleted_at field
|
||||
update_deleted_at_query = (
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowModel.deleted_at.is_(None))
|
||||
)
|
||||
if organization_id:
|
||||
update_deleted_at_query = update_deleted_at_query.filter_by(organization_id=organization_id)
|
||||
update_deleted_at_query = update_deleted_at_query.values(deleted_at=datetime.utcnow())
|
||||
await session.execute(update_deleted_at_query)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("add_workflow_template")
|
||||
async def add_workflow_template(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
"""Add a workflow to the templates table."""
|
||||
async with self.Session() as session:
|
||||
existing = (
|
||||
await session.scalars(
|
||||
select(WorkflowTemplateModel)
|
||||
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
existing.deleted_at = None
|
||||
await session.commit()
|
||||
return
|
||||
template = WorkflowTemplateModel(
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
session.add(template)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("remove_workflow_template")
|
||||
async def remove_workflow_template(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
"""Soft delete a workflow from the templates table."""
|
||||
async with self.Session() as session:
|
||||
update_deleted_at_query = (
|
||||
update(WorkflowTemplateModel)
|
||||
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
.where(WorkflowTemplateModel.deleted_at.is_(None))
|
||||
.values(deleted_at=datetime.utcnow())
|
||||
)
|
||||
await session.execute(update_deleted_at_query)
|
||||
await session.commit()
|
||||
|
||||
@db_operation("get_org_template_permanent_ids")
|
||||
async def get_org_template_permanent_ids(
|
||||
self,
|
||||
organization_id: str,
|
||||
) -> set[str]:
|
||||
"""Get all workflow_permanent_ids that are templates for an organization."""
|
||||
async with self.Session() as session:
|
||||
result = await session.scalars(
|
||||
select(WorkflowTemplateModel.workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
.where(WorkflowTemplateModel.deleted_at.is_(None))
|
||||
)
|
||||
return set(result.all())
|
||||
|
||||
@db_operation("is_workflow_template")
|
||||
async def is_workflow_template(
|
||||
self,
|
||||
workflow_permanent_id: str,
|
||||
organization_id: str,
|
||||
) -> bool:
|
||||
"""Check if a workflow is marked as a template."""
|
||||
async with self.Session() as session:
|
||||
result = (
|
||||
await session.scalars(
|
||||
select(WorkflowTemplateModel)
|
||||
.where(WorkflowTemplateModel.workflow_permanent_id == workflow_permanent_id)
|
||||
.where(WorkflowTemplateModel.organization_id == organization_id)
|
||||
.where(WorkflowTemplateModel.deleted_at.is_(None))
|
||||
)
|
||||
).first()
|
||||
return result is not None
|
||||
|
|
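
The unified search documented in get_workflows_by_organization_id above is easiest to see from the call side. A minimal usage sketch, assuming the mixin is exposed as a `workflows` attribute on the database object (that attribute name is inferred from the mixin's name and from the other repository attributes in this commit, not shown in this file):

    # One search term is matched against workflow titles, folder titles,
    # and non-deleted parameter metadata (key / description / default_value).
    workflows = await app.DATABASE.workflows.get_workflows_by_organization_id(
        organization_id="o_123",  # illustrative id
        page=1,
        page_size=10,
        search_key="invoice",
        statuses=[WorkflowStatus.published],
    )
    # Each returned Workflow has is_template resolved against the
    # organization's workflow_templates rows.

Note that pagination is applied after the latest-version subquery join, so page_size counts distinct workflows, not versions.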
@@ -29,7 +29,9 @@ async def await_browser_session(
    try:
        async with asyncio.timeout(timeout):
            while True:
                persistent_browser_session = await db.get_persistent_browser_session(session_id, organization_id)
                persistent_browser_session = await db.browser_sessions.get_persistent_browser_session(
                    session_id, organization_id
                )
                if persistent_browser_session is None:
                    raise Exception(f"Persistent browser session not found for {session_id}")

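The hunk above shows the call-site pattern repeated throughout this commit: the same method moves from the database object itself to a typed repository attribute on it. A minimal before/after sketch, using names taken directly from the hunks in this diff (`db` is the database instance passed into await_browser_session):

    # Before: flat method on the database object.
    session = await db.get_persistent_browser_session(session_id, organization_id)

    # After: identical call, routed through the typed repository attribute.
    session = await db.browser_sessions.get_persistent_browser_session(
        session_id, organization_id
    )

Behavior is unchanged; only the access path moves, which is why most hunks below are one-line rewrites plus reflowed arguments.
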
@@ -13,12 +13,10 @@ from skyvern.forge.sdk.db.base_repository import BaseRepository
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
    BrowserProfileModel,
    DebugSessionModel,
    PersistentBrowserSessionModel,
)
from skyvern.forge.sdk.db.utils import serialize_proxy_location
from skyvern.forge.sdk.schemas.browser_profiles import BrowserProfile
from skyvern.forge.sdk.schemas.debug_sessions import DebugSession
from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
    Extensions,
    PersistentBrowserSession,

@@ -457,19 +455,3 @@ class BrowserSessionsRepository(BaseRepository):
                select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
            )
            return result.scalars().all()

    @db_operation("get_debug_session_by_browser_session_id")
    async def get_debug_session_by_browser_session_id(
        self,
        browser_session_id: str,
        organization_id: str,
    ) -> DebugSession | None:
        async with self.Session() as session:
            query = (
                select(DebugSessionModel)
                .filter_by(browser_session_id=browser_session_id)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
            )
            model = (await session.scalars(query)).first()
            return DebugSession.model_validate(model) if model else None

@@ -111,27 +111,20 @@ class DebugRepository(BaseRepository):

        await session.commit()

    @db_operation("get_latest_debug_session_for_user")
    async def get_latest_debug_session_for_user(
    @db_operation("get_debug_session_by_browser_session_id")
    async def get_debug_session_by_browser_session_id(
        self,
        *,
        browser_session_id: str,
        organization_id: str,
        user_id: str,
        workflow_permanent_id: str,
    ) -> DebugSession | None:
        async with self.Session() as session:
            query = (
                select(DebugSessionModel)
                .filter_by(browser_session_id=browser_session_id)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
                .filter_by(status="created")
                .filter_by(user_id=user_id)
                .filter_by(workflow_permanent_id=workflow_permanent_id)
                .order_by(DebugSessionModel.created_at.desc())
            )

            model = (await session.scalars(query)).first()

            return DebugSession.model_validate(model) if model else None

    @db_operation("get_debug_session_by_id")

@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime, timedelta, timezone

import structlog
from sqlalchemy import and_, delete, distinct, func, select, update
from sqlalchemy import and_, delete, distinct, func, or_, select, update
from sqlalchemy.dialects.postgresql import insert

from skyvern.forge.sdk.db._error_handling import db_operation

@@ -699,12 +699,18 @@ class ScriptsRepository(BaseRepository):
        page_size: int = 50,
        created_after: datetime | None = None,
        created_before: datetime | None = None,
    ) -> tuple[list[WorkflowRunModel], int, dict[str, int]]:
        """Get workflow runs associated with a script, with total count and status counts.
    ) -> tuple[list[WorkflowRunModel], int, dict[str, int], float | None]:
        """Get workflow runs associated with a script, with total count, status counts,
        and average AI fallbacks per run.

        Returns (runs, total_count, status_counts) where runs is limited by page_size,
        total_count is derived from the status_counts GROUP BY, and status_counts is a
        GROUP BY aggregation of statuses across all runs.
        Includes actual script runs (run_with='code', 'code_v2', or NULL), excluding
        explicit agent runs. run_with is NULL when the workflow ran in auto mode and
        should_run_script() resolved to code via fallback (e.g. code_version >= 1).

        Returns (runs, total_count, status_counts, avg_fallbacks_per_run) where runs
        is limited by page_size, total_count is derived from the status_counts GROUP BY,
        status_counts is a GROUP BY aggregation of statuses across all runs, and
        avg_fallbacks_per_run is the average number of fallback episodes per run.

        If created_after/created_before are provided, filters by the workflow_script
        entry's created_at (not the run's created_at), scoping to the version that

@@ -730,10 +736,19 @@
                    WorkflowScriptModel.created_at < created_before,
                )

            # Base filter for workflow runs
            # Base filter for workflow runs - only include actual script runs.
            # run_with may be NULL when the workflow ran in auto mode and
            # should_run_script() resolved to code mode via fallback (e.g.
            # code_version >= 1 or adaptive_caching). NULL is therefore
            # treated as a code run here; explicit "agent" runs are excluded
            # by not appearing in the workflow_scripts join above.
            base_filters = [
                WorkflowRunModel.workflow_run_id.in_(run_ids_subquery),
                WorkflowRunModel.organization_id == organization_id,
                or_(
                    WorkflowRunModel.run_with.in_(["code", "code_v2"]),
                    WorkflowRunModel.run_with.is_(None),
                ),
            ]

            # Count statuses via GROUP BY (also gives us total_count)

@@ -744,7 +759,7 @@
            total_count = sum(status_counts.values())

            if total_count == 0:
                return [], 0, {}
                return [], 0, {}, None

            # Get the actual workflow runs (paginated)
            runs_query = (

@@ -755,7 +770,27 @@
            )
            runs = list((await session.scalars(runs_query)).all())

            return runs, total_count, status_counts
            # Compute average AI fallbacks per run over the last 20 runs.
            max_fallback_sample = 20
            recent_run_ids = (
                select(WorkflowRunModel.workflow_run_id)
                .filter(*base_filters)
                .order_by(WorkflowRunModel.created_at.desc())
                .limit(max_fallback_sample)
            )
            total_fallbacks_result = await session.execute(
                select(func.count())
                .select_from(ScriptFallbackEpisodeModel)
                .filter(
                    ScriptFallbackEpisodeModel.workflow_run_id.in_(recent_run_ids),
                    ScriptFallbackEpisodeModel.organization_id == organization_id,
                )
            )
            total_fallbacks = total_fallbacks_result.scalar() or 0
            sample_size = min(total_count, max_fallback_sample)
            avg_fallbacks_per_run = round(total_fallbacks / sample_size, 2)

            return runs, total_count, status_counts, avg_fallbacks_per_run

    @db_operation("get_script_run_stats")
    async def get_script_run_stats(

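Callers of this method now unpack a 4-tuple instead of a 3-tuple. The enclosing method's name sits above the visible hunk; `get_workflow_runs_for_script` is used below purely as an illustrative stand-in, and the arguments are elided. A sketch of the new contract and the averaging arithmetic:

    runs, total_count, status_counts, avg_fallbacks_per_run = (
        await repo.get_workflow_runs_for_script(  # hypothetical name, args elided
            organization_id="o_123",
        )
    )
    # The average is taken over at most the 20 most recent script runs:
    # e.g. 7 fallback episodes across min(total_count=50, 20) == 20 sampled
    # runs gives round(7 / 20, 2) == 0.35. With zero runs the method returns
    # ([], 0, {}, None), so callers should treat None as "no data".

Capping the sample at 20 keeps the fallback count a rolling health signal for the current script version rather than a lifetime average.
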
@@ -36,14 +36,14 @@ class BackgroundTaskExecutor(AsyncExecutor):
        if organization is None:
            raise OrganizationNotFound(organization_id)

        step = await app.DATABASE.create_step(
        step = await app.DATABASE.tasks.create_step(
            task_id,
            order=0,
            retry_index=0,
            organization_id=organization_id,
        )

        task = await app.DATABASE.update_task(
        task = await app.DATABASE.tasks.update_task(
            task_id,
            status=TaskStatus.running,
            organization_id=organization_id,

@@ -51,7 +51,7 @@

        close_browser_on_completion = browser_session_id is None and not task.browser_address

        run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id)
        run_obj = await app.DATABASE.tasks.get_run(run_id=task_id, organization_id=organization_id)
        engine = RunEngine.skyvern_v1
        if run_obj and run_obj.task_run_type == RunType.openai_cua:
            engine = RunEngine.openai_cua

@@ -136,17 +136,17 @@
        if organization is None:
            raise OrganizationNotFound(organization_id)

        task_v2 = await app.DATABASE.get_task_v2(task_v2_id=task_v2_id, organization_id=organization_id)
        task_v2 = await app.DATABASE.observer.get_task_v2(task_v2_id=task_v2_id, organization_id=organization_id)
        if not task_v2 or not task_v2.workflow_run_id:
            raise ValueError("No task v2 or no workflow run associated with task v2")

        # mark task v2 as queued
        await app.DATABASE.update_task_v2(
        await app.DATABASE.observer.update_task_v2(
            task_v2_id=task_v2_id,
            status=TaskV2Status.queued,
            organization_id=organization_id,
        )
        await app.DATABASE.update_workflow_run(
        await app.DATABASE.workflow_runs.update_workflow_run(
            workflow_run_id=task_v2.workflow_run_id,
            status=WorkflowRunStatus.queued,
        )

@@ -127,7 +127,7 @@ async def _save_log_artifacts(
        return
    log_json = json.dumps(log, cls=SkyvernJSONLogEncoder, indent=2)

    log_artifact = await app.DATABASE.get_artifact_by_entity_id(
    log_artifact = await app.DATABASE.artifacts.get_artifact_by_entity_id(
        artifact_type=ArtifactType.SKYVERN_LOG_RAW,
        step_id=step_id,
        task_id=task_id,

@@ -158,7 +158,7 @@

    formatted_log = SkyvernLogEncoder.encode(log)

    formatted_log_artifact = await app.DATABASE.get_artifact_by_entity_id(
    formatted_log_artifact = await app.DATABASE.artifacts.get_artifact_by_entity_id(
        artifact_type=ArtifactType.SKYVERN_LOG,
        step_id=step_id,
        task_id=task_id,

@@ -304,7 +304,7 @@ async def run_task(
            max_steps_override=run_request.max_steps,
            browser_session_id=run_request.browser_session_id,
        )
        refreshed_task_v2 = await app.DATABASE.get_task_v2(
        refreshed_task_v2 = await app.DATABASE.observer.get_task_v2(
            task_v2_id=task_v2.observer_cruise_id, organization_id=current_org.organization_id
        )
        task_v2 = refreshed_task_v2 if refreshed_task_v2 else task_v2

@@ -1416,7 +1416,7 @@ async def get_artifact(
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> Artifact:
    analytics.capture("skyvern-oss-artifact-get")
    artifact = await app.DATABASE.get_artifact_by_id(
    artifact = await app.DATABASE.artifacts.get_artifact_by_id(
        artifact_id=artifact_id,
        organization_id=current_org.organization_id,
    )

@@ -1488,14 +1488,14 @@ async def get_artifact_content(
                status_code=http_status.HTTP_403_FORBIDDEN,
                detail="Invalid or expired artifact URL",
            )
        artifact = await app.DATABASE.get_artifact_by_id_no_org(artifact_id=artifact_id)
        artifact = await app.DATABASE.artifacts.get_artifact_by_id_no_org(artifact_id=artifact_id)
    else:
        # Standard org-auth path (existing behaviour).
        current_org = await org_auth_service.get_current_org(
            x_api_key=x_api_key,
            authorization=authorization,
        )
        artifact = await app.DATABASE.get_artifact_by_id(
        artifact = await app.DATABASE.artifacts.get_artifact_by_id(
            artifact_id=artifact_id,
            organization_id=current_org.organization_id,
        )

@@ -1542,7 +1542,7 @@ async def get_run_artifacts(
) -> Response:
    analytics.capture("skyvern-oss-run-artifacts-get")
    # Get artifacts as a list (not grouped by type)
    artifacts = await app.DATABASE.get_artifacts_for_run(
    artifacts = await app.DATABASE.artifacts.get_artifacts_for_run(
        run_id=run_id,
        organization_id=current_org.organization_id,
        artifact_types=artifact_type,

@@ -1642,7 +1642,9 @@ async def get_run_timeline(

    # Handle task_v2 runs by getting their associated workflow_run_id
    if run_response.run_type == RunType.task_v2:
        task_v2 = await app.DATABASE.get_task_v2(task_v2_id=run_id, organization_id=current_org.organization_id)
        task_v2 = await app.DATABASE.observer.get_task_v2(
            task_v2_id=run_id, organization_id=current_org.organization_id
        )
        if not task_v2:
            raise HTTPException(
                status_code=http_status.HTTP_404_NOT_FOUND,

@@ -1902,7 +1904,7 @@ async def cancel_task(
    x_api_key: Annotated[str | None, Header()] = None,
) -> None:
    analytics.capture("skyvern-oss-agent-task-get")
    task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    task_obj = await app.DATABASE.tasks.get_task(task_id, organization_id=current_org.organization_id)
    if not task_obj:
        raise HTTPException(
            status_code=http_status.HTTP_404_NOT_FOUND,

@@ -1914,7 +1916,7 @@


async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api_key: str | None = None) -> None:
    workflow_run = await app.DATABASE.get_workflow_run(
    workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

@@ -1929,7 +1931,7 @@ async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api
        await app.PERSISTENT_SESSIONS_MANAGER.release_browser_session(workflow_run.browser_session_id, organization_id)

    # get all the child workflow runs and cancel them
    child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
    child_workflow_runs = await app.DATABASE.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(
        parent_workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

@@ -1949,7 +1951,7 @@ async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api


async def _continue_workflow_run(workflow_run_id: str, organization_id: str) -> None:
    workflow_run = await app.DATABASE.get_workflow_run(
    workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
        status=WorkflowRunStatus.paused,

@@ -2028,7 +2030,7 @@ async def retry_webhook(
    x_api_key: Annotated[str | None, Header()] = None,
) -> TaskResponse:
    analytics.capture("skyvern-oss-agent-task-retry-webhook")
    task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
    task_obj = await app.DATABASE.tasks.get_task(task_id, organization_id=current_org.organization_id)
    if not task_obj:
        raise HTTPException(
            status_code=http_status.HTTP_404_NOT_FOUND,

@@ -2036,7 +2038,7 @@
    )

    # get latest step
    latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=current_org.organization_id)
    latest_step = await app.DATABASE.tasks.get_latest_step(task_id, organization_id=current_org.organization_id)
    if not latest_step:
        return await app.agent.build_task_response(task=task_obj)

@@ -2088,7 +2090,7 @@ async def get_tasks(
            status_code=http_status.HTTP_400_BAD_REQUEST,
            detail="only_standalone_tasks and workflow_run_id cannot be used together",
        )
    tasks = await app.DATABASE.get_tasks(
    tasks = await app.DATABASE.tasks.get_tasks(
        page,
        page_size,
        task_status=task_status,

@@ -2137,7 +2139,7 @@ async def get_runs(
    if page > 10:
        return []

    runs = await app.DATABASE.get_all_runs(
    runs = await app.DATABASE.workflow_runs.get_all_runs(
        current_org.organization_id, page=page, page_size=page_size, status=status, search_key=search_key
    )
    return ORJSONResponse([run.model_dump() for run in runs])

@@ -2174,7 +2176,7 @@ async def get_runs_v2(
) -> Response:
    analytics.capture("skyvern-oss-agent-runs-v2-get")

    rows = await app.DATABASE.get_all_runs_v2(
    rows = await app.DATABASE.workflow_runs.get_all_runs_v2(
        current_org.organization_id,
        page=page,
        page_size=page_size,

@@ -2208,7 +2210,7 @@ async def get_steps(
    :return: List of steps for a task with pagination.
    """
    analytics.capture("skyvern-oss-agent-task-steps-get")
    steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
    steps = await app.DATABASE.tasks.get_task_steps(task_id, organization_id=current_org.organization_id)
    return ORJSONResponse([step.model_dump(exclude_none=True) for step in steps])

@@ -2255,7 +2257,10 @@ async def get_artifacts(
    params = {
        entity_type_to_param[entity_type]: entity_id,
    }
    artifacts = await app.DATABASE.get_artifacts_by_entity_id(organization_id=current_org.organization_id, **params)  # type: ignore
    artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
        organization_id=current_org.organization_id,
        **params,  # type: ignore[arg-type]
    )

    signed_urls = await app.ARTIFACT_MANAGER.get_share_links_with_bundle_support(artifacts)
    for i, artifact in enumerate(artifacts):

@@ -2289,7 +2294,7 @@ async def get_step_artifacts(
    :return: List of artifacts for a list of steps.
    """
    analytics.capture("skyvern-oss-agent-task-step-artifacts-get")
    artifacts = await app.DATABASE.get_artifacts_for_task_step(
    artifacts = await app.DATABASE.artifacts.get_artifacts_for_task_step(
        task_id,
        step_id,
        organization_id=current_org.organization_id,

@@ -2318,7 +2323,7 @@ async def get_actions(
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[Action]:
    analytics.capture("skyvern-oss-agent-task-actions-get")
    actions = await app.DATABASE.get_task_actions(task_id, organization_id=current_org.organization_id)
    actions = await app.DATABASE.tasks.get_task_actions(task_id, organization_id=current_org.organization_id)
    return actions

@@ -2602,7 +2607,7 @@ async def get_workflow_run_with_workflow_id(
    )
    return_dict = workflow_run_status_response.model_dump(by_alias=True)

    browser_session = await app.DATABASE.get_persistent_browser_session_by_runnable_id(
    browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session_by_runnable_id(
        runnable_id=workflow_run_id,
        organization_id=current_org.organization_id,
    )

@@ -2966,7 +2971,7 @@ async def suggest(
    )

    try:
        new_ai_suggestion = await app.DATABASE.create_ai_suggestion(
        new_ai_suggestion = await app.DATABASE.workflow_params.create_ai_suggestion(
            organization_id=current_org.organization_id,
            ai_suggestion_type=ai_suggestion_type,
        )

@@ -3239,7 +3244,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
    """

    # get task v2 by workflow run id
    task_v2_obj = await app.DATABASE.get_task_v2_by_workflow_run_id(
    task_v2_obj = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

@@ -143,7 +143,7 @@ async def list_browser_profiles(
        include_deleted=include_deleted,
    )

    profiles = await app.DATABASE.list_browser_profiles(
    profiles = await app.DATABASE.browser_sessions.list_browser_profiles(
        organization_id=organization_id,
        include_deleted=include_deleted,
    )

@@ -199,7 +199,7 @@ async def get_browser_profile(
        browser_profile_id=profile_id,
    )

    profile = await app.DATABASE.get_browser_profile(
    profile = await app.DATABASE.browser_sessions.get_browser_profile(
        profile_id=profile_id,
        organization_id=organization_id,
    )

@@ -264,7 +264,7 @@ async def delete_browser_profile(
    )

    try:
        await app.DATABASE.delete_browser_profile(
        await app.DATABASE.browser_sessions.delete_browser_profile(
            profile_id=profile_id,
            organization_id=organization_id,
        )

@@ -290,7 +290,9 @@ async def _create_profile_from_session(
    description: str | None,
    browser_session_id: str,
) -> BrowserProfile:
    browser_session = await app.DATABASE.get_persistent_browser_session(browser_session_id, organization_id)
    browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
        browser_session_id, organization_id
    )
    if browser_session is None:
        LOG.warning(
            "Browser session not found for profile creation",

@@ -318,7 +320,7 @@
    )

    try:
        profile = await app.DATABASE.create_browser_profile(
        profile = await app.DATABASE.browser_sessions.create_browser_profile(
            organization_id=organization_id,
            name=name,
            description=description,

@@ -334,7 +336,9 @@
        )
    except Exception:
        # Rollback: delete the profile if storage fails
        await app.DATABASE.delete_browser_profile(profile.browser_profile_id, organization_id=organization_id)
        await app.DATABASE.browser_sessions.delete_browser_profile(
            profile.browser_profile_id, organization_id=organization_id
        )
        LOG.error(
            "Failed to store browser profile artifacts, rolled back profile creation",
            organization_id=organization_id,

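Both profile-creation paths in this file use the compensation pattern the hunk above touches: create the database row first, attempt artifact storage, and delete the row if storage fails so no orphan profile remains. A condensed sketch of the pattern (the storage helper name and the final re-raise are illustrative, not taken from the diff):

    profile = await app.DATABASE.browser_sessions.create_browser_profile(
        organization_id=organization_id,
        name=name,
        description=description,
    )
    try:
        await store_profile_artifacts(profile)  # hypothetical storage helper
    except Exception:
        # Compensate: remove the half-created profile row.
        await app.DATABASE.browser_sessions.delete_browser_profile(
            profile.browser_profile_id, organization_id=organization_id
        )
        raise
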
@ -359,7 +363,7 @@ async def _create_profile_from_workflow_run(
|
|||
description: str | None,
|
||||
workflow_run_id: str,
|
||||
) -> BrowserProfile:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
||||
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id, organization_id=organization_id)
|
||||
if not workflow_run:
|
||||
LOG.warning(
|
||||
"Workflow run not found for profile creation",
|
||||
|
|
@ -421,7 +425,7 @@ async def _create_profile_from_workflow_run(
|
|||
)
|
||||
|
||||
try:
|
||||
profile = await app.DATABASE.create_browser_profile(
|
||||
profile = await app.DATABASE.browser_sessions.create_browser_profile(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
description=description,
|
||||
|
|
@ -443,7 +447,9 @@ async def _create_profile_from_workflow_run(
|
|||
)
|
||||
except Exception:
|
||||
# Rollback: delete the profile if storage fails
|
||||
await app.DATABASE.delete_browser_profile(profile.browser_profile_id, organization_id=organization_id)
|
||||
await app.DATABASE.browser_sessions.delete_browser_profile(
|
||||
profile.browser_profile_id, organization_id=organization_id
|
||||
)
|
||||
LOG.error(
|
||||
"Failed to store browser profile artifacts, rolled back profile creation",
|
||||
organization_id=organization_id,
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ async def get_browser_sessions_all(
|
|||
"""Get all browser sessions for the organization"""
|
||||
analytics.capture("skyvern-oss-agent-browser-sessions-get-all")
|
||||
|
||||
browser_sessions = await app.DATABASE.get_persistent_browser_sessions_history(
|
||||
browser_sessions = await app.DATABASE.browser_sessions.get_persistent_browser_sessions_history(
|
||||
current_org.organization_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
|
@ -91,7 +91,7 @@ async def create_browser_session(
|
|||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> BrowserSessionResponse:
|
||||
if browser_session_request.browser_profile_id:
|
||||
profile = await app.DATABASE.get_browser_profile(
|
||||
profile = await app.DATABASE.browser_sessions.get_browser_profile(
|
||||
browser_session_request.browser_profile_id,
|
||||
current_org.organization_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ async def send_totp_code(
|
|||
)
|
||||
# validate task_id, workflow_id, workflow_run_id are valid ids in db if provided
|
||||
if data.task_id:
|
||||
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
|
||||
task = await app.DATABASE.tasks.get_task(data.task_id, curr_org.organization_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
|
||||
if data.workflow_id:
|
||||
|
|
@ -171,7 +171,7 @@ async def send_totp_code(
|
|||
if not workflow:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
|
||||
if data.workflow_run_id:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
||||
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
||||
if not workflow_run:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid workflow run id: {data.workflow_run_id}")
|
||||
content = data.content.strip()
|
||||
|
|
@ -679,7 +679,7 @@ async def test_credential(
|
|||
# Check if the credential already has a browser profile
|
||||
existing_browser_profile_id = credential.browser_profile_id
|
||||
if existing_browser_profile_id:
|
||||
profile = await app.DATABASE.get_browser_profile(
|
||||
profile = await app.DATABASE.browser_sessions.get_browser_profile(
|
||||
profile_id=existing_browser_profile_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
|
@ -891,7 +891,9 @@ async def get_test_credential_status(
|
|||
) -> TestCredentialStatusResponse:
|
||||
organization_id = current_org.organization_id
|
||||
|
||||
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise HTTPException(status_code=404, detail=f"Workflow run {workflow_run_id} not found")
|
||||
|
||||
|
|
@ -1054,7 +1056,7 @@ async def _create_browser_profile_after_workflow(
|
|||
|
||||
try:
|
||||
for _ in range(max_polls):
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
if not workflow_run:
|
||||
|
|
@ -1130,7 +1132,7 @@ async def _create_browser_profile_after_workflow(
|
|||
|
||||
# Create the browser profile in DB
|
||||
profile_name = f"Profile - {credential_name} ({credential_id})"
|
||||
profile = await app.DATABASE.create_browser_profile(
|
||||
profile = await app.DATABASE.browser_sessions.create_browser_profile(
|
||||
organization_id=organization_id,
|
||||
name=profile_name,
|
||||
description=f"Browser profile from credential test for {credential_name}",
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ async def get_or_create_debug_session_by_user_and_workflow_permanent_id(
|
|||
)
|
||||
|
||||
# Skip renewal for sessions that haven't started yet (browser still launching)
|
||||
session = await app.DATABASE.get_persistent_browser_session(
|
||||
session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
|
||||
debug_session.browser_session_id,
|
||||
current_org.organization_id,
|
||||
)
|
||||
|
|
@ -129,7 +129,7 @@ async def new_debug_session(
|
|||
"""
|
||||
|
||||
if current_user_id:
|
||||
debug_session = await app.DATABASE.debug.get_latest_debug_session_for_user(
|
||||
debug_session = await app.DATABASE.debug.get_debug_session(
|
||||
organization_id=current_org.organization_id,
|
||||
user_id=current_user_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
|
|
@ -172,7 +172,7 @@ async def new_debug_session(
|
|||
|
||||
for debug_session in completed_debug_sessions:
|
||||
try:
|
||||
browser_session = await app.DATABASE.get_persistent_browser_session(
|
||||
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
|
||||
debug_session.browser_session_id,
|
||||
current_org.organization_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -52,13 +52,13 @@ async def _load_main_script_content(
|
|||
script_revision_id: str,
|
||||
) -> str | None:
|
||||
"""Load the main.py content from a script revision, if it exists."""
|
||||
script_files = await app.DATABASE.get_script_files(
|
||||
script_files = await app.DATABASE.scripts.get_script_files(
|
||||
script_revision_id=script_revision_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
for f in script_files:
|
||||
if f.file_path == "main.py" and f.artifact_id:
|
||||
artifact = await app.DATABASE.get_artifact_by_id(f.artifact_id, organization_id)
|
||||
artifact = await app.DATABASE.artifacts.get_artifact_by_id(f.artifact_id, organization_id)
|
||||
if artifact:
|
||||
data = await app.STORAGE.retrieve_artifact(artifact)
|
||||
if data:
|
||||
|
|
@ -81,7 +81,7 @@ async def get_script_blocks_response(
|
|||
script_id: str | None = None,
|
||||
version: int | None = None,
|
||||
) -> ScriptBlocksResponse:
|
||||
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
|
||||
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
|
||||
script_revision_id=script_revision_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
|
@ -117,7 +117,7 @@ async def get_script_blocks_response(
|
|||
)
|
||||
continue
|
||||
|
||||
script_file = await app.DATABASE.get_script_file_by_id(
|
||||
script_file = await app.DATABASE.scripts.get_script_file_by_id(
|
||||
script_revision_id=script_revision_id,
|
||||
file_id=script_file_id,
|
||||
organization_id=organization_id,
|
||||
|
|
@ -147,7 +147,7 @@ async def get_script_blocks_response(
|
|||
)
|
||||
continue
|
||||
|
||||
artifact = await app.DATABASE.get_artifact_by_id(
|
||||
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
|
||||
artifact_id,
|
||||
organization_id,
|
||||
)
|
||||
|
|
@ -261,7 +261,7 @@ async def get_script(
|
|||
script_id=script_id,
|
||||
)
|
||||
|
||||
script = await app.DATABASE.get_script(
|
||||
script = await app.DATABASE.scripts.get_script(
|
||||
script_id=script_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
|
|
@ -287,7 +287,7 @@ async def get_script_versions(
|
|||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> ScriptVersionListResponse:
|
||||
"""List all versions of a script."""
|
||||
scripts = await app.DATABASE.get_script_versions(
|
||||
scripts = await app.DATABASE.scripts.get_script_versions(
|
||||
script_id=script_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
|
|
@ -319,7 +319,7 @@ async def get_script_version_code(
|
|||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> ScriptBlocksResponse:
|
||||
"""Get a specific version's code blocks."""
|
||||
script = await app.DATABASE.get_script(
|
||||
script = await app.DATABASE.scripts.get_script(
|
||||
script_id=script_id,
|
||||
organization_id=current_org.organization_id,
|
||||
version=version,
|
||||
|
|
@ -358,8 +358,8 @@ async def compare_script_versions(
|
|||
organization_id = current_org.organization_id
|
||||
|
||||
base_script, compare_script = await asyncio.gather(
|
||||
app.DATABASE.get_script(script_id=script_id, organization_id=organization_id, version=base),
|
||||
app.DATABASE.get_script(script_id=script_id, organization_id=organization_id, version=compare),
|
||||
app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id, version=base),
|
||||
app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id, version=compare),
|
||||
)
|
||||
if not base_script:
|
||||
raise HTTPException(status_code=404, detail=f"Base version {base} not found")
|
||||
|
|
@ -418,7 +418,7 @@ async def get_script_version_detail(
"""Get full detail for a specific script version, including code blocks and metadata."""
organization_id = current_org.organization_id

script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=organization_id,
version=version,

@ -436,7 +436,7 @@ async def get_script_version_detail(
script_id=script.script_id,
version=script.version,
),
app.DATABASE.get_fallback_episodes_count(
app.DATABASE.scripts.get_fallback_episodes_count(
organization_id=organization_id,
script_revision_id=script.script_revision_id,
),

@ -492,7 +492,7 @@ async def get_scripts(
page_size=page_size,
)

scripts = await app.DATABASE.get_scripts(
scripts = await app.DATABASE.scripts.get_scripts(
organization_id=current_org.organization_id,
page=page,
page_size=page_size,

@ -535,7 +535,7 @@ async def deploy_script(

try:
# Get the latest version of the script
latest_script = await app.DATABASE.get_latest_script_version(
latest_script = await app.DATABASE.scripts.get_latest_script_version(
script_id=script_id,
organization_id=current_org.organization_id,
)

@ -545,7 +545,7 @@ async def deploy_script(

# Create a new version of the script
new_version = latest_script.version + 1
new_script = await app.DATABASE.create_script(
new_script = await app.DATABASE.scripts.create_script(
organization_id=current_org.organization_id,
run_id=latest_script.run_id,
script_id=script_id,

@ -553,7 +553,7 @@ async def deploy_script(
)

# Fetch source files from the base revision to build the old->new ID mapping
source_files = await app.DATABASE.get_script_files(
source_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=latest_script.script_revision_id,
organization_id=current_org.organization_id,
)

@ -580,7 +580,7 @@ async def deploy_script(
file_path=file.path,
data=content_bytes,
)
new_file = await app.DATABASE.create_script_file(
new_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=current_org.organization_id,

@ -601,7 +601,7 @@ async def deploy_script(
for f in source_files:
if f.file_path in deployed_file_paths:
continue
new_file = await app.DATABASE.create_script_file(
new_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=current_org.organization_id,

@ -617,13 +617,13 @@ async def deploy_script(
old_to_new_file_id[f.file_id] = new_file.file_id

# Copy existing script blocks, re-pointing file IDs to the new revision's files
existing_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
existing_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=latest_script.script_revision_id,
organization_id=current_org.organization_id,
)
for sb in existing_blocks:
new_file_id = old_to_new_file_id.get(sb.script_file_id, sb.script_file_id) if sb.script_file_id else None
await app.DATABASE.create_script_block(
await app.DATABASE.scripts.create_script_block(
organization_id=current_org.organization_id,
script_id=new_script.script_id,
script_revision_id=new_script.script_revision_id,
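deploy_script copies the base revision's files into the new revision while recording an old-to-new file ID map, then re-points each copied block at the new file. A simplified sketch of that remap step, with copy_file standing in for the create_script_file call:

def remap_script_blocks(source_files, existing_blocks, copy_file):
    # Build the old -> new file ID mapping while copying files.
    old_to_new_file_id: dict[str, str] = {}
    for f in source_files:
        new_file = copy_file(f)  # persists the file under the new revision
        old_to_new_file_id[f.file_id] = new_file.file_id

    # Re-point each block; blocks without a file keep None, and an unknown
    # ID falls back to the old ID rather than dropping the reference.
    for sb in existing_blocks:
        new_file_id = (
            old_to_new_file_id.get(sb.script_file_id, sb.script_file_id)
            if sb.script_file_id
            else None
        )
        yield sb, new_file_id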
@ -716,7 +716,7 @@ async def get_workflow_script_blocks(
include_main_script = True
workflow_run_id = block_script_request.workflow_run_id
if workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)

@ -727,7 +727,7 @@ async def get_workflow_script_blocks(
# get_workflow_script() always resolves the latest version for a
# cache_key_value, but the Code tab should show the version that was
# active when this run executed (SKY-8448).
workflow_script = await app.DATABASE.get_workflow_script(
workflow_script = await app.DATABASE.scripts.get_workflow_script(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,

@ -833,7 +833,7 @@ async def get_workflow_cache_key_values(
) -> ScriptCacheKeyValuesResponse:
# TODO(jdo): concurrent-ize

values = await app.DATABASE.get_workflow_cache_key_values(
values = await app.DATABASE.scripts.get_workflow_cache_key_values(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key=cache_key,

@ -842,13 +842,13 @@ async def get_workflow_cache_key_values(
filter=filter,
)

total_count = await app.DATABASE.get_workflow_cache_key_count(
total_count = await app.DATABASE.scripts.get_workflow_cache_key_count(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key=cache_key,
)

filtered_count = await app.DATABASE.get_workflow_cache_key_count(
filtered_count = await app.DATABASE.scripts.get_workflow_cache_key_count(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key=cache_key,

@ -889,7 +889,7 @@ async def list_workflow_scripts(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")

workflow_scripts = await app.DATABASE.get_workflow_scripts_by_permanent_id(
workflow_scripts = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
statuses=[ScriptStatus.published],

@ -918,11 +918,11 @@ async def list_workflow_scripts(
# These are independent -- run in parallel.
rep_script_ids = [ws.script_id for ws in representatives]
version_stats, run_stats = await asyncio.gather(
app.DATABASE.get_script_version_stats(
app.DATABASE.scripts.get_script_version_stats(
organization_id=organization_id,
script_ids=rep_script_ids,
),
app.DATABASE.get_script_run_stats(
app.DATABASE.scripts.get_script_run_stats(
organization_id=organization_id,
script_ids=rep_script_ids,
),

@ -980,7 +980,7 @@ async def get_script_runs(
organization_id = current_org.organization_id

# Verify script exists
script = await app.DATABASE.get_script(script_id=script_id, organization_id=organization_id)
script = await app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id)
if not script:
raise HTTPException(status_code=404, detail="Script not found")

@ -989,7 +989,7 @@ async def get_script_runs(

if version is not None:
# Get all versions to determine the time window for this version
all_versions = await app.DATABASE.get_script_versions(
all_versions = await app.DATABASE.scripts.get_script_versions(
script_id=script_id,
organization_id=organization_id,
)

@ -1006,7 +1006,7 @@ async def get_script_runs(
if not version_found:
raise HTTPException(status_code=404, detail=f"Script version {version} not found")

runs, total_count, status_counts, avg_fallbacks_per_run = await app.DATABASE.get_workflow_runs_for_script(
runs, total_count, status_counts, avg_fallbacks_per_run = await app.DATABASE.scripts.get_workflow_runs_for_script(
organization_id=organization_id,
script_id=script_id,
page_size=page_size,
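To scope runs to one version, get_script_runs loads every version and derives the window during which the requested version was current. The windowing itself is outside this hunk, so the sketch below is an assumption about how created_at timestamps bound it:

def version_window(all_versions, version: int):
    # Assumed fields: .version and .created_at on each script row.
    ordered = sorted(all_versions, key=lambda s: s.version)
    for i, s in enumerate(ordered):
        if s.version == version:
            start = s.created_at
            # The next version's creation time closes the window; the
            # latest version's window is open-ended.
            end = ordered[i + 1].created_at if i + 1 < len(ordered) else None
            return start, end
    return None  # caller raises the 404 for an unknown version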
@ -1063,7 +1063,7 @@ async def delete_workflow_cache_key_value(
raise HTTPException(status_code=404, detail="Workflow not found")

# Delete the cache key value
deleted = await app.DATABASE.delete_workflow_cache_key_value(
deleted = await app.DATABASE.scripts.delete_workflow_cache_key_value(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=cache_key_value,

@ -1130,7 +1130,7 @@ async def clear_workflow_cache(
raise HTTPException(status_code=404, detail="Workflow not found")

# Clear database cache (soft delete)
deleted_count = await app.DATABASE.delete_workflow_scripts_by_permanent_id(
deleted_count = await app.DATABASE.scripts.delete_workflow_scripts_by_permanent_id(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
)

@ -1185,7 +1185,7 @@ async def pin_workflow_script(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")

result = await app.DATABASE.pin_workflow_script(
result = await app.DATABASE.scripts.pin_workflow_script(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=data.cache_key_value,

@ -1233,7 +1233,7 @@ async def unpin_workflow_script(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")

result = await app.DATABASE.unpin_workflow_script(
result = await app.DATABASE.scripts.unpin_workflow_script(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=data.cache_key_value,

@ -1299,7 +1299,7 @@ async def review_script_with_instructions(
run_parameter_values: dict[str, str] = {}
latest_script = None
if data.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=data.workflow_run_id,
organization_id=organization_id,
)

@ -1308,17 +1308,17 @@ async def review_script_with_instructions(
if workflow_run.workflow_permanent_id != workflow_permanent_id:
raise HTTPException(status_code=400, detail="Workflow run does not belong to this workflow")
# Look up the specific script used by this run
run_workflow_script = await app.DATABASE.get_workflow_script(
run_workflow_script = await app.DATABASE.scripts.get_workflow_script(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=data.workflow_run_id,
)
if run_workflow_script:
latest_script = await app.DATABASE.get_latest_script_version(
latest_script = await app.DATABASE.scripts.get_latest_script_version(
script_id=run_workflow_script.script_id,
organization_id=organization_id,
)
episodes = await app.DATABASE.get_fallback_episodes(
episodes = await app.DATABASE.scripts.get_fallback_episodes(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=data.workflow_run_id,
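The review endpoint resolves the script run-first: the workflow script actually recorded for the run, then the newest revision of that specific script, then the run's fallback episodes. A sketch of that ordering, with scripts_repo standing in for app.DATABASE.scripts:

async def resolve_run_script(scripts_repo, organization_id: str, workflow_permanent_id: str, workflow_run_id: str):
    # 1. The script pinned to this run, not the latest for the cache key.
    run_workflow_script = await scripts_repo.get_workflow_script(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        workflow_run_id=workflow_run_id,
    )
    if run_workflow_script is None:
        return None
    # 2. The newest revision of that specific script.
    return await scripts_repo.get_latest_script_version(
        script_id=run_workflow_script.script_id,
        organization_id=organization_id,
    )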
@ -1326,7 +1326,7 @@ async def review_script_with_instructions(
page_size=50,
)
try:
run_param_tuples = await app.DATABASE.get_workflow_run_parameters(
run_param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=data.workflow_run_id,
)
for wf_param, run_param in run_param_tuples:

@ -1436,7 +1436,7 @@ async def get_fallback_episodes(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")

episodes = await app.DATABASE.get_fallback_episodes(
episodes = await app.DATABASE.scripts.get_fallback_episodes(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
page=page,

@ -1446,7 +1446,7 @@ async def get_fallback_episodes(
reviewed=reviewed,
fallback_type=fallback_type,
)
total_count = await app.DATABASE.get_fallback_episodes_count(
total_count = await app.DATABASE.scripts.get_fallback_episodes_count(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,

@ -1485,7 +1485,7 @@ async def get_fallback_episode(
if not workflow:
raise HTTPException(status_code=404, detail="Workflow not found")

episode = await app.DATABASE.get_fallback_episode(
episode = await app.DATABASE.scripts.get_fallback_episode(
episode_id=episode_id,
organization_id=current_org.organization_id,
)


@ -56,7 +56,7 @@ async def run_sdk_action(

# Use existing workflow_run_id if provided, otherwise create a new one
if action_request.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=action_request.workflow_run_id,
organization_id=organization_id,
)

@ -90,12 +90,12 @@ async def run_sdk_action(
organization=organization,
version=None,
)
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
status=WorkflowRunStatus.completed,
)

task = await app.DATABASE.create_task(
task = await app.DATABASE.tasks.create_task(
organization_id=organization_id,
url=action_request.url,
navigation_goal=action.get_navigation_goal(),

@ -107,14 +107,14 @@ async def run_sdk_action(
browser_address=browser_address,
)

step = await app.DATABASE.create_step(
step = await app.DATABASE.tasks.create_step(
task.task_id,
order=0,
retry_index=0,
organization_id=organization.organization_id,
)

await app.DATABASE.create_workflow_run_block(
await app.DATABASE.observer.create_workflow_run_block(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=organization_id,
block_type=BlockType.ACTION,

@ -221,13 +221,13 @@ async def run_sdk_action(
model=action.model,
)
result = prompt_result
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
status=TaskStatus.completed,
)
except ScrapingFailed as e:
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
status=TaskStatus.failed,

@ -240,7 +240,7 @@ async def run_sdk_action(
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.reason or str(e))
except Exception as e:
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
status=TaskStatus.failed,


@ -271,7 +271,7 @@ async def cdp_input_stream(
try:
deadline = time.monotonic() + 120
while True:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)


@ -86,7 +86,7 @@ async def task_stream(
)
return

task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
await websocket.send_json(

@ -210,7 +210,7 @@ async def workflow_run_streaming(
)
return

workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -403,7 +403,7 @@ async def _local_screencast_for_workflow_run(
async def wait_for_running() -> str | None:
deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT
while True:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -423,14 +423,14 @@ async def _local_screencast_for_workflow_run(
await asyncio.sleep(1)

async def check_finalized() -> bool:
wr = await app.DATABASE.get_workflow_run(
wr = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return wr is None or wr.status.is_final()

async def get_current_status() -> str | None:
wr = await app.DATABASE.get_workflow_run(
wr = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
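The screencast helpers above all share one polling shape: re-read the run from the database once a second until it is running, gone, or a monotonic deadline passes. A self-contained sketch; fetch_status stands in for the workflow_runs repository call:

import asyncio
import time

async def wait_until_running(fetch_status, timeout_sec: float = 120.0) -> str | None:
    # time.monotonic is immune to wall-clock adjustments, unlike time.time.
    deadline = time.monotonic() + timeout_sec
    while True:
        status = await fetch_status()
        if status == "running":
            return status
        if status is None or time.monotonic() > deadline:
            return None  # run vanished or we timed out
        await asyncio.sleep(1)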
@ -457,7 +457,7 @@ async def _local_screencast_for_task(
nonlocal task_workflow_run_id
deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT
while True:
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if not task:
return "not_found"
if task.status.is_final():

@ -475,11 +475,11 @@ async def _local_screencast_for_task(
await asyncio.sleep(1)

async def check_finalized() -> bool:
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
return task is None or task.status.is_final()

async def get_current_status() -> str | None:
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
return task.status if task else None

await _run_local_screencast(


@ -117,7 +117,7 @@ async def verify_task(
with it.
"""

task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)

if not task:
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)

@ -171,7 +171,7 @@ async def verify_workflow_run(
with it.
"""

workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)


@ -66,7 +66,7 @@ class RunInfo:


async def _get_debug_artifact(organization_id: str, workflow_run_id: str) -> Artifact | None:
artifacts = await app.DATABASE.get_artifacts_for_run(
artifacts = await app.DATABASE.artifacts.get_artifacts_for_run(
run_id=workflow_run_id, organization_id=organization_id, artifact_types=[ArtifactType.VISIBLE_ELEMENTS_TREE]
)
return artifacts[0] if isinstance(artifacts, list) and artifacts else None

@ -76,7 +76,7 @@ async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None)
if not workflow_run_id:
return None

blocks = await app.DATABASE.get_workflow_run_blocks(
blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not blocks:

@ -651,7 +651,7 @@ async def workflow_copilot_chat_post(
)

if chat_request.workflow_copilot_chat_id:
chat = await app.DATABASE.get_workflow_copilot_chat_by_id(
chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
organization_id=organization.organization_id,
workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
)

@ -660,12 +660,12 @@ async def workflow_copilot_chat_post(
if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
else:
chat = await app.DATABASE.create_workflow_copilot_chat(
chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_permanent_id=chat_request.workflow_permanent_id,
)

chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(
chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
)
global_llm_context = None

@ -719,20 +719,20 @@ async def workflow_copilot_chat_post(
return

if updated_workflow and chat.auto_accept is not True:
await app.DATABASE.update_workflow_copilot_chat(
await app.DATABASE.workflow_params.update_workflow_copilot_chat(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
proposed_workflow=updated_workflow.model_dump(mode="json"),
)

await app.DATABASE.create_workflow_copilot_chat_message(
await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
sender=WorkflowCopilotChatSender.USER,
content=chat_request.message,
)

assistant_message = await app.DATABASE.create_workflow_copilot_chat_message(
assistant_message = await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
organization_id=chat.organization_id,
workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
sender=WorkflowCopilotChatSender.AI,

@ -791,12 +791,14 @@ async def workflow_copilot_chat_history(
workflow_permanent_id: str,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowCopilotChatHistoryResponse:
latest_chat = await app.DATABASE.get_latest_workflow_copilot_chat(
latest_chat = await app.DATABASE.workflow_params.get_latest_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_permanent_id=workflow_permanent_id,
)
if latest_chat:
chat_messages = await app.DATABASE.get_workflow_copilot_chat_messages(latest_chat.workflow_copilot_chat_id)
chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
latest_chat.workflow_copilot_chat_id
)
else:
chat_messages = []
return WorkflowCopilotChatHistoryResponse(

@ -814,7 +816,7 @@ async def workflow_copilot_clear_proposed_workflow(
clear_request: WorkflowCopilotClearProposedWorkflowRequest,
organization: Organization = Depends(org_auth_service.get_current_org),
) -> None:
updated_chat = await app.DATABASE.update_workflow_copilot_chat(
updated_chat = await app.DATABASE.workflow_params.update_workflow_copilot_chat(
organization_id=organization.organization_id,
workflow_copilot_chat_id=clear_request.workflow_copilot_chat_id,
proposed_workflow=None,


@ -207,7 +207,7 @@ class Block(BaseModel, abc.ABC):
parameter=self.output_parameter,
value=value,
)
await app.DATABASE.create_or_update_workflow_run_output_parameter(
await app.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=value,

@ -238,7 +238,7 @@ class Block(BaseModel, abc.ABC):
output_parameter_value = {"value": output_parameter_value}

if workflow_run_block_id:
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
output=output_parameter_value,
status=status,

@ -495,7 +495,7 @@ class Block(BaseModel, abc.ABC):
LOG.exception("Failed to generate description for the workflow run block", error=e)

if description:
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
description=description,
organization_id=organization_id,

@ -522,7 +522,7 @@ class Block(BaseModel, abc.ABC):
if isinstance(self, BaseTaskBlock):
engine = self.engine

workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.create_workflow_run_block(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
parent_workflow_run_block_id=parent_workflow_run_block_id,

@ -707,7 +709,7 @@ class BaseTaskBlock(Block):
"""
Returns the order and retry for the next task in the workflow run as a tuple.
"""
last_task_for_workflow_run = await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
last_task_for_workflow_run = await app.DATABASE.tasks.get_last_task_for_workflow_run(
workflow_run_id=workflow_run_id
)
# If there is no previous task, the order will be 0 and the retry will be 0.
if last_task_for_workflow_run is None:
return 0, 0
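Only the no-previous-task branch of the order/retry rule is visible in this hunk. A hedged sketch of the full rule; the retry branch is an assumption, not shown in the diff:

def next_order_and_retry(last_task, is_retry: bool) -> tuple[int, int]:
    if last_task is None:
        return 0, 0  # first task in the workflow run (shown in the diff)
    if is_retry:
        # Assumed: a retry keeps the order and bumps the retry counter.
        return last_task.order, last_task.retry + 1
    return last_task.order + 1, 0  # assumed: next task advances the order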
@ -746,7 +748,7 @@ class BaseTaskBlock(Block):
This helper method consolidates the error detection logic that was previously
duplicated across multiple exception handlers in the execute method.
"""
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=organization_id,

@ -764,7 +766,7 @@ class BaseTaskBlock(Block):
if detected_errors:
# Only pass new errors — update_task() appends to existing errors
new_errors = [error.model_dump() for error in detected_errors]
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task.task_id,
organization_id=organization_id,
errors=new_errors,

@ -867,7 +869,7 @@ class BaseTaskBlock(Block):
task_order=task_order,
task_retry=task_retry,
)
workflow_run_block = await app.DATABASE.update_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task.task_id,
organization_id=organization_id,

@ -1011,7 +1013,7 @@ class BaseTaskBlock(Block):
current_context.task_id = None

# Check task status
updated_task = await app.DATABASE.get_task(
updated_task = await app.DATABASE.tasks.get_task(
task_id=task.task_id, organization_id=workflow_run.organization_id
)
if not updated_task:

@ -1941,7 +1943,7 @@ class ForLoopBlock(Block):
)
try:
if block_output.workflow_run_block_id:
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=block_output.workflow_run_block_id,
organization_id=organization_id,
current_value=str(loop_over_value),

@ -2106,7 +2108,7 @@ class ForLoopBlock(Block):
organization_id=organization_id,
)

await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
loop_values=loop_over_values,

@ -2586,7 +2588,9 @@ class TextPromptBlock(Block):
artifacts_to_persist: list[tuple[ArtifactType, bytes]] = []
if workflow_run_block_id:
try:
workflow_run_block = await app.DATABASE.get_workflow_run_block(workflow_run_block_id, organization_id)
workflow_run_block = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id, organization_id
)
if workflow_run_block:
artifacts_to_persist.append((ArtifactType.LLM_PROMPT, prompt.encode("utf-8")))
except Exception as e:

@ -2645,7 +2649,7 @@ class TextPromptBlock(Block):
)
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
prompt=self.prompt,

@ -3444,7 +3448,7 @@ class SendEmailBlock(Block):
**kwargs: dict,
) -> BlockResult:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
recipients=self.recipients,

@ -4140,7 +4144,7 @@ class WaitBlock(Block):
**kwargs: dict,
) -> BlockResult:
# TODO: we need to support to interrupt the sleep when the workflow run failed/cancelled/terminated
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
wait_sec=self.wait_sec,

@ -4246,7 +4250,7 @@ class HumanInteractionBlock(BaseTaskBlock):
organization_id=organization_id,
)

await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
recipients=self.recipients,

@ -4265,12 +4269,12 @@ class HumanInteractionBlock(BaseTaskBlock):
browser_session_id=browser_session_id,
)

await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
status=WorkflowRunStatus.paused,
)

workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -4372,7 +4376,7 @@ class HumanInteractionBlock(BaseTaskBlock):
organization_id=organization_id,
)

workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -4574,7 +4578,7 @@ class TaskV2Block(Block):
organization = await app.DATABASE.organizations.get_organization(organization_id)
if not organization:
raise ValueError(f"Organization not found {organization_id}")
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id, organization_id)
if not workflow_run:
raise ValueError(f"WorkflowRun not found {workflow_run_id} when running TaskV2Block")
try:

@ -4588,15 +4592,15 @@ class TaskV2Block(Block):
totp_verification_url=resolved_totp_verification_url,
max_screenshot_scrolling_times=workflow_run.max_screenshot_scrolls,
)
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2.observer_cruise_id, status=TaskV2Status.queued, organization_id=organization_id
)
if task_v2.workflow_run_id:
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=task_v2.workflow_run_id,
status=WorkflowRunStatus.queued,
)
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
block_workflow_run_id=task_v2.workflow_run_id,

@ -4642,7 +4646,9 @@ class TaskV2Block(Block):
failure_reason: str | None = None
task_v2_workflow_run_id = task_v2.workflow_run_id
if task_v2_workflow_run_id:
task_v2_workflow_run = await app.DATABASE.get_workflow_run(task_v2_workflow_run_id, organization_id)
task_v2_workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
task_v2_workflow_run_id, organization_id
)
if task_v2_workflow_run:
failure_reason = task_v2_workflow_run.failure_reason

@ -5237,7 +5243,7 @@ class PrintPageBlock(Block):
return None, None

try:
workflow_run_block = await app.DATABASE.get_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id,
organization_id=artifact_org_id,
)

@ -5269,7 +5275,7 @@ class PrintPageBlock(Block):
# Generate a downloadable URL for the artifact
artifact_url = None
try:
artifact = await app.DATABASE.get_artifact_by_id(artifact_id, organization_id=artifact_org_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(artifact_id, organization_id=artifact_org_id)
if artifact:
artifact_url = await app.ARTIFACT_MANAGER.get_share_link(artifact)
except Exception:

@ -6575,7 +6581,7 @@ class WorkflowTriggerBlock(Block):
f"Workflow trigger depth exceeds maximum of {self.MAX_TRIGGER_DEPTH}. "
"This may indicate a circular workflow trigger chain."
)
run = await app.DATABASE.get_workflow_run(current_run_id)
run = await app.DATABASE.workflow_runs.get_workflow_run(current_run_id)
if not run or not run.parent_workflow_run_id:
break
current_run_id = run.parent_workflow_run_id
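WorkflowTriggerBlock guards against circular trigger chains by walking parent_workflow_run_id upward and refusing once the chain is deeper than MAX_TRIGGER_DEPTH. A sketch of that walk; get_run stands in for the workflow_runs repository lookup:

async def check_trigger_depth(get_run, run_id: str, max_depth: int) -> int:
    depth = 0
    current_run_id = run_id
    while depth <= max_depth:
        run = await get_run(current_run_id)
        if not run or not run.parent_workflow_run_id:
            return depth  # reached the root of the chain
        current_run_id = run.parent_workflow_run_id
        depth += 1
    raise ValueError(
        f"Workflow trigger depth exceeds maximum of {max_depth}. "
        "This may indicate a circular workflow trigger chain."
    )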
@ -6721,7 +6727,7 @@ class WorkflowTriggerBlock(Block):
elif self.wait_for_completion:
# Sync mode: child runs inline in the same process, so it needs
# its own persistent session to avoid sharing the parent's browser.
parent_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id)
parent_workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id)
proxy_location = parent_workflow_run.proxy_location if parent_workflow_run else None
try:
child_browser_session = await app.PERSISTENT_SESSIONS_MANAGER.create_session(


@ -363,14 +363,14 @@ class WorkflowService:
target_labels = set(block_labels_to_disable)

for candidate in candidates:
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=candidate.script_id,
organization_id=organization_id,
)
if not script:
continue

script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)

@ -400,7 +400,7 @@ class WorkflowService:
"""Remove cached run signatures for the supplied block groups to force regeneration."""
for group in groups:
for block in group.blocks_to_clear:
await app.DATABASE.update_script_block(
await app.DATABASE.scripts.update_script_block(
script_block_id=block.script_block_id,
organization_id=organization_id,
clear_run_signature=True,

@ -451,7 +451,7 @@ class WorkflowService:
if not artifact_ids or not organization_id:
return []

artifacts = await app.DATABASE.get_artifacts_by_ids(artifact_ids, organization_id)
artifacts = await app.DATABASE.artifacts.get_artifacts_by_ids(artifact_ids, organization_id)
if not artifacts:
return []

@ -988,7 +988,7 @@ class WorkflowService:
if browser_session:
browser_session_id = browser_session.persistent_browser_session_id
close_browser_on_completion = True
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
browser_session_id=browser_session_id,
)

@ -1064,7 +1064,7 @@ class WorkflowService:

# Refresh workflow_run from DB to pick up status/failure_reason
# set by _execute_workflow_blocks.
if refreshed_workflow_run := await app.DATABASE.get_workflow_run(
if refreshed_workflow_run := await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
):

@ -1110,7 +1110,7 @@ class WorkflowService:
should_trigger_reviewer = True
current_ctx = skyvern_context.current()
if current_ctx and current_ctx.script_id:
latest_script = await app.DATABASE.get_latest_script_version(
latest_script = await app.DATABASE.scripts.get_latest_script_version(
script_id=current_ctx.script_id,
organization_id=workflow.organization_id,
)

@ -1257,7 +1257,7 @@ class WorkflowService:
context.script_id = script.script_id
context.script_revision_id = script.script_revision_id
try:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)

@ -1271,7 +1271,7 @@ class WorkflowService:

if is_script_run:
# load the script files
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)

@ -1280,7 +1280,7 @@ class WorkflowService:
script_path = os.path.join(settings.TEMP_PATH, script.script_id, "main.py")
if os.path.exists(script_path):
# setup script run
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id
)
script_parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}

@ -1353,7 +1353,7 @@ class WorkflowService:
is_script_run = True
# Initialize RunContext with the browser page + parameters,
# same as the normal script loading path at line 1310.
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
script_parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}

@ -1559,7 +1559,7 @@ class WorkflowService:
# browser-automation code like page.classify for pure-Python conditionals).
fallback_type = "conditional_agent" if isinstance(block, ConditionalBlock) else "full_block"

episode = await app.DATABASE.create_fallback_episode(
episode = await app.DATABASE.scripts.create_fallback_episode(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,

@ -1824,7 +1824,7 @@ class WorkflowService:
block_executed_with_code = False

try:
if refreshed_workflow_run := await app.DATABASE.get_workflow_run(
if refreshed_workflow_run := await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
):

@ -1884,12 +1884,12 @@ class WorkflowService:
# Persist the browser_profile_id on the workflow_run so
# subsequent blocks create / reuse a browser with the
# saved profile (cookies, localStorage, etc.).
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
browser_profile_id=resolved_browser_profile_id,
)
workflow_run = (
await app.DATABASE.get_workflow_run(
await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -1932,7 +1932,7 @@ class WorkflowService:
)
profile_loaded = False
# Clear the profile so the normal login path doesn't reuse it
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
browser_profile_id=None,
)

@ -2027,7 +2027,7 @@ class WorkflowService:
exc_info=True,
)

workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -2215,12 +2215,12 @@ class WorkflowService:
fallback_wrb_id = workflow_run_block_result.workflow_run_block_id
if fallback_wrb_id:
try:
wrb = await app.DATABASE.get_workflow_run_block(
wrb = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id=fallback_wrb_id,
organization_id=organization_id,
)
if wrb and wrb.task_id:
actions = await app.DATABASE.get_task_actions(
actions = await app.DATABASE.tasks.get_task_actions(
task_id=wrb.task_id,
organization_id=organization_id,
)

@ -2232,7 +2232,7 @@ class WorkflowService:
exc_info=True,
)

await app.DATABASE.update_fallback_episode(
await app.DATABASE.scripts.update_fallback_episode(
episode_id=fallback_episode_id,
organization_id=organization_id,
agent_actions=agent_actions_summary,

@ -2306,7 +2306,7 @@ class WorkflowService:
}
)
cond_context = skyvern_context.current()
cond_episode = await app.DATABASE.create_fallback_episode(
cond_episode = await app.DATABASE.scripts.create_fallback_episode(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,

@ -2321,7 +2321,7 @@ class WorkflowService:
"expressions": expressions,
},
)
await app.DATABASE.update_fallback_episode(
await app.DATABASE.scripts.update_fallback_episode(
episode_id=cond_episode.episode_id,
organization_id=organization_id,
fallback_succeeded=True,
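The fallback-episode rows above follow a two-step lifecycle: create an episode when a block falls back to the agent, then update the same row afterwards with the outcome. A sketch of that lifecycle around a block execution; scripts_repo and run_block are stand-ins:

async def run_with_fallback_episode(scripts_repo, organization_id, workflow_permanent_id, workflow_run_id, run_block):
    episode = await scripts_repo.create_fallback_episode(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        workflow_run_id=workflow_run_id,
    )
    succeeded = await run_block()  # the agent takes over the block here
    await scripts_repo.update_fallback_episode(
        episode_id=episode.episode_id,
        organization_id=organization_id,
        fallback_succeeded=succeeded,
    )
    return succeeded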
@ -2450,7 +2450,7 @@ class WorkflowService:
# falls back to default_value on the workflow parameter).
if run_param_tuples is None:
try:
run_param_tuples = await app.DATABASE.get_workflow_run_parameters(
run_param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run_id,
)
except Exception:

@ -2486,7 +2486,7 @@ class WorkflowService:
)
if db_cred and db_cred.browser_profile_id:
# Verify the browser profile still exists before using it
profile = await app.DATABASE.get_browser_profile(
profile = await app.DATABASE.browser_sessions.get_browser_profile(
profile_id=db_cred.browser_profile_id,
organization_id=organization_id,
)

@ -3186,7 +3186,7 @@ class WorkflowService:
if not block_run:
continue

output_parameter = await app.DATABASE.get_workflow_run_output_parameter_by_id(
output_parameter = await app.DATABASE.workflow_runs.get_workflow_run_output_parameter_by_id(
workflow_run_id=block_run.workflow_run_id, output_parameter_id=block_run.output_parameter_id
)

@ -3323,7 +3323,7 @@ class WorkflowService:
previous_blocks=current_definition.get("blocks", []),
new_blocks=new_definition.get("blocks", []),
)
candidates = await app.DATABASE.get_workflow_scripts_by_permanent_id(
candidates = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
)

@ -3396,7 +3396,7 @@ class WorkflowService:

if len(to_delete) > 0:
try:
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
await app.DATABASE.scripts.delete_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
script_ids=[s.script_id for s in to_delete],

@ -3458,7 +3458,7 @@ class WorkflowService:
search_key: str | None = None,
error_code: str | None = None,
) -> list[WorkflowRun]:
return await app.DATABASE.get_workflow_runs(
return await app.DATABASE.workflow_runs.get_workflow_runs(
organization_id=organization_id,
page=page,
page_size=page_size,

@ -3473,7 +3473,7 @@ class WorkflowService:
organization_id: str,
status: list[WorkflowRunStatus] | None = None,
) -> int:
return await app.DATABASE.get_workflow_runs_count(
return await app.DATABASE.workflow_runs.get_workflow_runs_count(
organization_id=organization_id,
status=status,
)

@ -3488,7 +3488,7 @@ class WorkflowService:
search_key: str | None = None,
error_code: str | None = None,
) -> list[WorkflowRun]:
return await app.DATABASE.get_workflow_runs_for_workflow_permanent_id(
return await app.DATABASE.workflow_runs.get_workflow_runs_for_workflow_permanent_id(
workflow_permanent_id=workflow_permanent_id,
organization_id=organization_id,
page=page,

@ -3515,7 +3515,7 @@ class WorkflowService:
# validate the browser session or profile id
browser_profile_id = workflow_request.browser_profile_id
if workflow_request.browser_session_id:
browser_session = await app.DATABASE.get_persistent_browser_session(
browser_session = await app.DATABASE.browser_sessions.get_persistent_browser_session(
session_id=workflow_request.browser_session_id,
organization_id=organization_id,
)

@ -3531,7 +3531,7 @@ class WorkflowService:
)

if browser_profile_id:
browser_profile = await app.DATABASE.get_browser_profile(
browser_profile = await app.DATABASE.browser_sessions.get_browser_profile(
browser_profile_id,
organization_id=organization_id,
)

@ -3579,7 +3579,7 @@ class WorkflowService:
browser_session_id=browser_session_id,
)

return await app.DATABASE.create_workflow_run(
return await app.DATABASE.workflow_runs.create_workflow_run(
workflow_permanent_id=workflow_permanent_id,
workflow_id=workflow_id,
organization_id=organization_id,

@ -3612,7 +3612,7 @@ class WorkflowService:
ai_fallback: bool | None = None,
failure_category: list[dict] | None = None,
) -> WorkflowRun:
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
status=status,
failure_reason=failure_reason,

@ -3659,7 +3659,7 @@ class WorkflowService:
) -> None:
"""Fire-and-forget: propagate workflow_run status to task_runs."""
try:
await app.DATABASE.sync_task_run_status(
await app.DATABASE.tasks.sync_task_run_status(
organization_id=workflow_run.organization_id,
run_id=workflow_run_id,
status=status.value,

@ -3667,12 +3667,12 @@ class WorkflowService:
finished_at=workflow_run.finished_at,
)
# Also sync task_v2 if this workflow_run backs an observer_cruise
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=workflow_run.organization_id,
)
if task_v2:
await app.DATABASE.sync_task_run_status(
await app.DATABASE.tasks.sync_task_run_status(
organization_id=workflow_run.organization_id,
run_id=task_v2.observer_cruise_id,
status=status.value,
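The status sync fans out twice: workflow_run status is mirrored into task_runs, and, when the run backs an observer cruise, mirrored again under the cruise's run ID. A sketch that follows the calls in the hunk, with tasks_repo and observer_repo standing in for the typed repositories:

async def sync_statuses(tasks_repo, observer_repo, workflow_run, status) -> None:
    await tasks_repo.sync_task_run_status(
        organization_id=workflow_run.organization_id,
        run_id=workflow_run.workflow_run_id,
        status=status.value,
    )
    task_v2 = await observer_repo.get_task_v2_by_workflow_run_id(
        workflow_run_id=workflow_run.workflow_run_id,
        organization_id=workflow_run.organization_id,
    )
    if task_v2:
        # The same run also backs an observer cruise; mirror the status there.
        await tasks_repo.sync_task_run_status(
            organization_id=workflow_run.organization_id,
            run_id=task_v2.observer_cruise_id,
            status=status.value,
        )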
@ -3882,7 +3882,7 @@ class WorkflowService:
)

async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -3898,7 +3898,7 @@ class WorkflowService:
default_value: bool | int | float | str | dict | list | None = None,
description: str | None = None,
) -> WorkflowParameter:
return await app.DATABASE.create_workflow_parameter(
return await app.DATABASE.workflow_params.create_workflow_parameter(
workflow_id=workflow_id,
workflow_parameter_type=workflow_parameter_type,
key=key,

@ -3909,17 +3909,19 @@ class WorkflowService:
async def create_aws_secret_parameter(
self, workflow_id: str, aws_key: str, key: str, description: str | None = None
) -> AWSSecretParameter:
return await app.DATABASE.create_aws_secret_parameter(
return await app.DATABASE.workflow_params.create_aws_secret_parameter(
workflow_id=workflow_id, aws_key=aws_key, key=key, description=description
)

async def create_output_parameter(
self, workflow_id: str, key: str, description: str | None = None
) -> OutputParameter:
return await app.DATABASE.create_output_parameter(workflow_id=workflow_id, key=key, description=description)
return await app.DATABASE.workflow_params.create_output_parameter(
workflow_id=workflow_id, key=key, description=description
)

async def get_workflow_parameters(self, workflow_id: str) -> list[WorkflowParameter]:
return await app.DATABASE.get_workflow_parameters(workflow_id=workflow_id)
return await app.DATABASE.workflow_params.get_workflow_parameters(workflow_id=workflow_id)

async def create_workflow_run_parameter(
self,

@ -3929,7 +3931,7 @@ class WorkflowService:
) -> WorkflowRunParameter:
value = self._serialize_workflow_run_parameter_value(workflow_parameter, value)

return await app.DATABASE.create_workflow_run_parameter(
return await app.DATABASE.workflow_runs.create_workflow_run_parameter(
workflow_run_id=workflow_run_id,
workflow_parameter=workflow_parameter,
value=value,

@ -3945,7 +3947,7 @@ class WorkflowService:
for workflow_parameter, value in workflow_parameter_values
]

return await app.DATABASE.create_workflow_run_parameters(
return await app.DATABASE.workflow_runs.create_workflow_run_parameters(
workflow_run_id=workflow_run_id,
workflow_parameter_values=serialized_workflow_parameter_values,
)

@ -3960,27 +3962,27 @@ class WorkflowService:
async def get_workflow_run_parameter_tuples(
self, workflow_run_id: str
) -> list[tuple[WorkflowParameter, WorkflowRunParameter]]:
return await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
return await app.DATABASE.workflow_runs.get_workflow_run_parameters(workflow_run_id=workflow_run_id)

@staticmethod
async def get_workflow_output_parameters(workflow_id: str) -> list[OutputParameter]:
return await app.DATABASE.get_workflow_output_parameters(workflow_id=workflow_id)
return await app.DATABASE.workflow_params.get_workflow_output_parameters(workflow_id=workflow_id)

@staticmethod
async def get_workflow_run_output_parameters(
workflow_run_id: str,
) -> list[WorkflowRunOutputParameter]:
return await app.DATABASE.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id)
return await app.DATABASE.workflow_runs.get_workflow_run_output_parameters(workflow_run_id=workflow_run_id)

@staticmethod
async def get_output_parameter_workflow_run_output_parameter_tuples(
workflow_id: str,
workflow_run_id: str,
) -> list[tuple[OutputParameter, WorkflowRunOutputParameter]]:
workflow_run_output_parameters = await app.DATABASE.get_workflow_run_output_parameters(
workflow_run_output_parameters = await app.DATABASE.workflow_runs.get_workflow_run_output_parameters(
workflow_run_id=workflow_run_id
)
output_parameters = await app.DATABASE.get_workflow_output_parameters_by_ids(
output_parameters = await app.DATABASE.workflow_params.get_workflow_output_parameters_by_ids(
output_parameter_ids=[
workflow_run_output_parameter.output_parameter_id
for workflow_run_output_parameter in workflow_run_output_parameters

@ -3995,10 +3997,10 @@ class WorkflowService:
]

async def get_last_task_for_workflow_run(self, workflow_run_id: str) -> Task | None:
return await app.DATABASE.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)
return await app.DATABASE.tasks.get_last_task_for_workflow_run(workflow_run_id=workflow_run_id)

async def get_tasks_by_workflow_run_id(self, workflow_run_id: str) -> list[Task]:
return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
return await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)

async def get_recent_task_screenshot_artifacts(
self,

@ -4015,7 +4017,7 @@ class WorkflowService:
artifacts: list[Artifact] = []
if task_id:
artifacts = (
await app.DATABASE.get_latest_n_artifacts(
await app.DATABASE.artifacts.get_latest_n_artifacts(
task_id=task_id,
artifact_types=artifact_types,
organization_id=organization_id,

@ -4024,13 +4026,13 @@ class WorkflowService:
or []
)
elif task_v2_id:
action_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
action_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_ACTION,
task_v2_id=task_v2_id,
limit=limit,
)
final_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
final_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_FINAL,
task_v2_id=task_v2_id,

@ -4077,10 +4079,10 @@ class WorkflowService:
seen_artifact_ids: set[str] = set()

if workflow_run_tasks is None:
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)

for task in workflow_run_tasks[::-1]:
artifact = await app.DATABASE.get_latest_artifact(
artifact = await app.DATABASE.artifacts.get_latest_artifact(
task_id=task.task_id,
artifact_types=[ArtifactType.SCREENSHOT_ACTION, ArtifactType.SCREENSHOT_FINAL],
organization_id=organization_id,

@ -4092,13 +4094,13 @@ class WorkflowService:
break

if len(screenshot_artifacts) < limit:
action_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
action_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_ACTION,
workflow_run_id=workflow_run_id,
limit=limit,
)
final_artifacts = await app.DATABASE.get_artifacts_by_entity_id(
final_artifacts = await app.DATABASE.artifacts.get_artifacts_by_entity_id(
organization_id=organization_id,
artifact_type=ArtifactType.SCREENSHOT_FINAL,
workflow_run_id=workflow_run_id,

@ -4173,11 +4175,11 @@ class WorkflowService:

workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)

task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
screenshot_urls: list[str] | None = await self.get_recent_workflow_screenshot_urls(
workflow_run_id=workflow_run_id,
organization_id=organization_id,

@ -4201,7 +4203,7 @@ class WorkflowService:
LOG.warning("Timeout getting recordings", browser_session_id=workflow_run.browser_session_id)

if recording_url is None:
recording_artifact = await app.DATABASE.get_artifact_for_run(
recording_artifact = await app.DATABASE.artifacts.get_artifact_for_run(
run_id=task_v2.observer_cruise_id if task_v2 else workflow_run_id,
artifact_type=ArtifactType.RECORDING,
organization_id=organization_id,

@ -4239,7 +4241,9 @@ class WorkflowService:
workflow_run_id=workflow_run.workflow_run_id,
)

workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
workflow_parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run_id
)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
output_parameter_tuples: list[
tuple[OutputParameter, WorkflowRunOutputParameter]

@ -4270,7 +4274,7 @@ class WorkflowService:
# matching the task-level error format. Uses a lightweight query that only
# fetches blocks with non-null error_codes to avoid a full block load on
# every status poll.
block_errors = await app.DATABASE.get_workflow_run_block_errors(
block_errors = await app.DATABASE.workflow_runs.get_workflow_run_block_errors(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
for error_codes, failure_reason in block_errors:

@ -4286,13 +4290,13 @@ class WorkflowService:
total_steps = None
total_cost = None
if include_step_count or include_cost:
workflow_run_steps = await app.DATABASE.get_steps_by_task_ids(
workflow_run_steps = await app.DATABASE.tasks.get_steps_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks], organization_id=organization_id
)
total_steps = len(workflow_run_steps)

if include_cost:
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
|
||||
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
|
||||
workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
text_prompt_blocks = [
|
||||
|
|
@ -4353,7 +4357,7 @@ class WorkflowService:
|
|||
# tasks into the parent list for debug artifact persistence, and collect
|
||||
# child workflow_run IDs so cleanup_for_workflow_run can pop their orphaned
|
||||
# entries from self.pages (child skips clean_up_workflow).
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
child_workflow_runs = await app.DATABASE.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(
|
||||
parent_workflow_run_id=workflow_run.workflow_run_id,
|
||||
organization_id=workflow_run.organization_id,
|
||||
)
|
||||
|
|
@ -4514,7 +4518,7 @@ class WorkflowService:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
webhook_failure_reason="",
)

@ -4528,7 +4532,7 @@ class WorkflowService:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
webhook_failure_reason=f"Webhook failed with status code {resp.status_code}, error message: {resp.text}",
)

@ -4611,7 +4615,7 @@ class WorkflowService:
workflow: Workflow,
workflow_run: WorkflowRun,
) -> None:
last_step = await app.DATABASE.get_latest_step(
last_step = await app.DATABASE.tasks.get_latest_step(
task_id=last_task.task_id, organization_id=last_task.organization_id
)
if not last_step:

@ -4679,7 +4683,7 @@ class WorkflowService:
workflow_id=workflow_id,
)

await app.DATABASE.save_workflow_definition_parameters(workflow_definition.parameters)
await app.DATABASE.workflow_params.save_workflow_definition_parameters(workflow_definition.parameters)

return workflow_definition

@ -4830,7 +4834,7 @@ class WorkflowService:
@staticmethod
async def create_output_parameter_for_block(workflow_id: str, block_yaml: BLOCK_YAML_TYPES) -> OutputParameter:
output_parameter_key = f"{block_yaml.label}_output"
return await app.DATABASE.create_output_parameter(
return await app.DATABASE.workflow_params.create_output_parameter(
workflow_id=workflow_id,
key=output_parameter_key,
description=f"Output parameter for block {block_yaml.label}",

@ -4875,7 +4879,7 @@ class WorkflowService:
"""
build the tree structure of the workflow run timeline
"""
workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
workflow_run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -4884,7 +4888,7 @@ class WorkflowService:
task_id_to_block: dict[str, WorkflowRunBlock] = {
block.task_id: block for block in workflow_run_blocks if block.task_id
}
actions = await app.DATABASE.get_tasks_actions(task_ids=task_ids, organization_id=organization_id)
actions = await app.DATABASE.tasks.get_tasks_actions(task_ids=task_ids, organization_id=organization_id)
for action in actions:
if not action.task_id:
continue

@ -5005,7 +5009,7 @@ class WorkflowService:
return None

cached_block_labels: set[str] = set()
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=existing_script.script_revision_id,
organization_id=workflow.organization_id,
)

@ -5096,7 +5100,7 @@ class WorkflowService:
return

# Get the latest version number so we can increment it
version_stats = await app.DATABASE.get_script_version_stats(
version_stats = await app.DATABASE.scripts.get_script_version_stats(
organization_id=workflow.organization_id,
script_ids=[existing_script.script_id],
)

@ -5117,7 +5121,7 @@ class WorkflowService:
)

# Create a new version of the SAME script_id instead of a new script
regenerated_script = await app.DATABASE.create_script(
regenerated_script = await app.DATABASE.scripts.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
script_id=existing_script.script_id,
@ -5135,7 +5139,7 @@ class WorkflowService:

# If generation failed (e.g. syntax error), clean up the empty script row
# to avoid orphaned versions that skip version numbers on next regeneration.
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=regenerated_script.script_revision_id,
organization_id=workflow.organization_id,
)

@ -5145,7 +5149,7 @@ class WorkflowService:
script_id=regenerated_script.script_id,
version=regenerated_script.version,
)
await app.DATABASE.soft_delete_script_by_revision(
await app.DATABASE.scripts.soft_delete_script_by_revision(
script_revision_id=regenerated_script.script_revision_id,
organization_id=workflow.organization_id,
)

@ -5194,7 +5198,7 @@ class WorkflowService:
await _regenerate_script()
return

created_script = await app.DATABASE.create_script(
created_script = await app.DATABASE.scripts.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
)

@ -5236,7 +5240,7 @@ class WorkflowService:
# Check if the script is pinned — skip auto-review for pinned scripts.
# Query by script_id (not workflow_run_id) because pinning is applied
# at the cache_key_value level and may not be on this run's row.
if await app.DATABASE.is_script_pinned(
if await app.DATABASE.scripts.is_script_pinned(
organization_id=workflow.organization_id,
script_id=script_id,
):

@ -5357,7 +5361,7 @@ class WorkflowService:
) -> None:
"""Run the script reviewer inside a lock. Episodes are scoped to the script version."""
# Double-check: re-query episodes after acquiring lock (another process may have reviewed them)
episodes = await app.DATABASE.get_unreviewed_episodes(
episodes = await app.DATABASE.scripts.get_unreviewed_episodes(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
script_revision_id=script_revision_id,

@ -5375,7 +5379,7 @@ class WorkflowService:
# Query stale branches for TTL-based pruning
stale_branches: list = []
try:
stale_branches = await app.DATABASE.get_stale_branches(
stale_branches = await app.DATABASE.scripts.get_stale_branches(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
stale_days=90,

@ -5393,7 +5397,7 @@ class WorkflowService:
# Use the latest version as the base (not the potentially-stale run revision)
reviewer_base_revision_id = script_revision_id
try:
latest = await app.DATABASE.get_latest_script_version(
latest = await app.DATABASE.scripts.get_latest_script_version(
script_id=script_id,
organization_id=workflow.organization_id,
)

@ -5405,7 +5409,7 @@ class WorkflowService:
# Fetch historical (already-reviewed) episodes for cross-run context
historical_episodes: list = []
try:
historical_episodes = await app.DATABASE.get_recent_reviewed_episodes(
historical_episodes = await app.DATABASE.scripts.get_recent_reviewed_episodes(
workflow_permanent_id=workflow.workflow_permanent_id,
organization_id=workflow.organization_id,
limit=20,

@ -5457,7 +5461,7 @@ class WorkflowService:
# use context.parameters['recipient'] instead of a literal string).
run_parameter_values: dict[str, str] = {}
try:
run_param_tuples = await app.DATABASE.get_workflow_run_parameters(
run_param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
for wf_param, run_param in run_param_tuples:

@ -5513,7 +5517,7 @@ class WorkflowService:
)
# Still mark episodes as reviewed
for episode in episodes:
await app.DATABASE.mark_episode_reviewed(
await app.DATABASE.scripts.mark_episode_reviewed(
episode_id=episode.episode_id,
organization_id=workflow.organization_id,
reviewer_output=None,

@ -5523,7 +5527,7 @@ class WorkflowService:
# Get the base script to create a new version from
base_script = None
if script_revision_id:
base_script = await app.DATABASE.get_script_revision(
base_script = await app.DATABASE.scripts.get_script_revision(
script_revision_id=script_revision_id,
organization_id=workflow.organization_id,
)

@ -5558,7 +5562,7 @@ class WorkflowService:

# Mark all episodes as reviewed
for episode in episodes:
await app.DATABASE.mark_episode_reviewed(
await app.DATABASE.scripts.mark_episode_reviewed(
episode_id=episode.episode_id,
organization_id=workflow.organization_id,
reviewer_output=str(updated_blocks) if updated_blocks else None,
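Every hunk above follows the same mechanical rewrite: a legacy delegate call on the AgentDB facade becomes the identical call on a domain-scoped repository attribute. A minimal sketch of the calling convention, with hypothetical repository classes standing in for the real ones (which this excerpt of the diff does not show):

# Sketch only: the repository class names and bodies below are illustrative
# assumptions, not definitions taken from this commit.
class TasksRepository:
    async def get_task(self, task_id: str, organization_id: str | None = None):
        ...  # domain-scoped query against the task tables


class AgentDB:
    def __init__(self) -> None:
        # one typed attribute per domain, matching the names used above:
        # tasks, workflow_runs, workflow_params, scripts, observer, artifacts
        self.tasks = TasksRepository()


# old (legacy delegate):  task = await app.DATABASE.get_task(task_id)
# new (typed repository): task = await app.DATABASE.tasks.get_task(task_id)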
@ -18,7 +18,7 @@ async def get_action_history(
"""

# Get action results from the last history_window steps
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
steps = await app.DATABASE.tasks.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
# the last step is always the newly created one and it should be excluded from the history window
window_steps = steps[-1 - history_window : -1]
if current_step:
@ -53,12 +53,12 @@ async def check_running_tasks_or_workflows() -> tuple[bool, int, int]:
stale_threshold = settings.CLEANUP_STALE_TASK_THRESHOLD_HOURS

# Check tasks
active_tasks, stale_tasks = await app.DATABASE.get_running_tasks_info_globally(
active_tasks, stale_tasks = await app.DATABASE.tasks.get_running_tasks_info_globally(
stale_threshold_hours=stale_threshold
)

# Check workflow runs
active_workflows, stale_workflows = await app.DATABASE.get_running_workflow_runs_info_globally(
active_workflows, stale_workflows = await app.DATABASE.workflow_runs.get_running_workflow_runs_info_globally(
stale_threshold_hours=stale_threshold
)
@ -10,12 +10,12 @@ from skyvern.services import task_v1_service, task_v2_service, webhook_service,


async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
# try to see if it's a workflow run id for task v2
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(run_id, organization_id=organization_id)
if task_v2:
run = await app.DATABASE.get_run(task_v2.observer_cruise_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(task_v2.observer_cruise_id, organization_id=organization_id)

if not run:
return None

@ -73,7 +73,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
step_count=task_v1_response.step_count,
)
elif run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run.run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run.run_id, organization_id=organization_id)
if not task_v2:
return None
return await task_v2_service.build_task_v2_run_response(task_v2)

@ -83,7 +83,7 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R


async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=task_id)
task = await app.agent.update_task(task, status=TaskStatus.canceled)

@ -91,7 +91,7 @@ async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_k


async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(task_id, organization_id=organization_id)
if not task_v2:
raise TaskNotFound(task_id=task_id)
await task_v2_service.mark_task_v2_as_canceled(

@ -102,7 +102,7 @@ async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> No
async def cancel_workflow_run(
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
) -> None:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -110,7 +110,7 @@ async def cancel_workflow_run(
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)

# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
child_workflow_runs = await app.DATABASE.workflow_runs.get_workflow_runs_by_parent_workflow_run_id(
parent_workflow_run_id=workflow_run_id,
organization_id=organization_id,
)

@ -128,7 +128,7 @@ async def cancel_workflow_run(


async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

@ -156,7 +156,7 @@ async def retry_run_webhook(
) -> None:
"""Retry sending the webhook for a run."""

run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

@ -175,19 +175,19 @@ async def retry_run_webhook(
return

if run.task_run_type in [RunType.task_v1, RunType.openai_cua, RunType.anthropic_cua, RunType.ui_tars]:
task = await app.DATABASE.get_task(run_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(run_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=run_id)
latest_step = await app.DATABASE.get_latest_step(run_id, organization_id=organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(run_id, organization_id=organization_id)
if latest_step:
await app.agent.execute_task_webhook(task=task, api_key=api_key)
elif run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run_id, organization_id=organization_id)
if not task_v2:
raise TaskNotFound(task_id=run_id)
await task_v2_service.send_task_v2_webhook(task_v2)
elif run.task_run_type == RunType.workflow_run:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=run_id,
organization_id=organization_id,
)
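Both get_run_response above and the webhook payload builder at the end of this diff resolve a run id in the same order: the runs table first, then a fallback that treats the id as the workflow_run_id behind a task v2. A condensed sketch of that shared lookup order (error handling elided; the import path is the usual module-level handle these service files assume):

# Sketch of the shared lookup order; not a new helper introduced by this commit.
from skyvern.forge import app  # assumed module-level application handle


async def resolve_run(run_id: str, organization_id: str | None):
    run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
    if run:
        return run
    # the id may be the workflow_run_id backing a task v2
    task_v2 = await app.DATABASE.observer.get_task_v2_by_workflow_run_id(
        run_id, organization_id=organization_id
    )
    if task_v2:
        return await app.DATABASE.tasks.get_run(
            task_v2.observer_cruise_id, organization_id=organization_id
        )
    return None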
@ -184,7 +184,7 @@ class ScriptReviewer:

async def _load_run_params(run_id: str) -> tuple[str, dict[str, str]]:
try:
param_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=run_id)
param_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(workflow_run_id=run_id)
params = {
wf_param.key: str(run_param.value)
for wf_param, run_param in param_tuples

@ -213,7 +213,7 @@ class ScriptReviewer:
triaged_episodes.append(episode)
else:
# Mark as reviewed so we don't re-triage on every run
await app.DATABASE.mark_episode_reviewed(
await app.DATABASE.scripts.mark_episode_reviewed(
episode_id=episode.episode_id,
organization_id=organization_id,
reviewer_output="TRIAGE: not_code_fixable — skipped",

@ -1081,20 +1081,20 @@ class ScriptReviewer:
return None

try:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script_revision_id,
organization_id=organization_id,
)
for sb in script_blocks:
if sb.script_block_label == block_label and sb.script_file_id:
# Load the code from the script file
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script_revision_id,
file_id=sb.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id=script_file.artifact_id,
organization_id=organization_id,
)

@ -1124,7 +1124,7 @@ class ScriptReviewer:
if not script_revision_id:
return {}
try:
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script_revision_id,
organization_id=organization_id,
)

@ -1132,13 +1132,13 @@ class ScriptReviewer:
for sb in script_blocks:
if not sb.script_file_id:
continue
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script_revision_id,
file_id=sb.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
artifact_id=script_file.artifact_id,
organization_id=organization_id,
)
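The reviewer resolves a block's generated code through the same three-table chain that the regeneration path later in this diff also walks: script block row, then script file row, then the artifact that stores the bytes. Sketched end to end (a simplification of the two loops above, reusing only calls that appear in this diff):

# Sketch of the block -> file -> artifact resolution chain.
from skyvern.forge import app  # assumed module-level application handle


async def load_block_code(script_revision_id: str, block_label: str, organization_id: str):
    script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
        script_revision_id=script_revision_id, organization_id=organization_id
    )
    for sb in script_blocks:
        if sb.script_block_label != block_label or not sb.script_file_id:
            continue
        script_file = await app.DATABASE.scripts.get_script_file_by_id(
            script_revision_id=script_revision_id,
            file_id=sb.script_file_id,
            organization_id=organization_id,
        )
        if script_file and script_file.artifact_id:
            artifact = await app.DATABASE.artifacts.get_artifact_by_id(
                artifact_id=script_file.artifact_id, organization_id=organization_id
            )
            if artifact:
                # same retrieval helper the diff uses elsewhere
                return await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
    return None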
@ -128,7 +128,7 @@ async def build_file_tree(
try:
if pending:
# get the script file object
script_file = await app.DATABASE.get_script_file_by_path(
script_file = await app.DATABASE.scripts.get_script_file_by_path(
script_revision_id=script_revision_id,
file_path=file.path,
organization_id=organization_id,

@ -140,7 +140,7 @@ async def build_file_tree(
script_file_id=script_file.file_id,
)
continue
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
# override the actual file in the storage
asyncio.create_task(app.STORAGE.store_artifact(artifact, content_bytes))

@ -153,7 +153,7 @@ async def build_file_tree(
data=content_bytes,
)
# update the artifact_id in the script file
await app.DATABASE.update_script_file(
await app.DATABASE.scripts.update_script_file(
script_file_id=script_file.file_id,
organization_id=organization_id,
artifact_id=artifact_id,

@ -174,7 +174,7 @@ async def build_file_tree(
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,

@ -202,7 +202,7 @@ async def build_file_tree(
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,

@ -264,10 +264,10 @@ async def create_script(
)

try:
if run_id and not await app.DATABASE.get_run(run_id=run_id, organization_id=organization_id):
if run_id and not await app.DATABASE.tasks.get_run(run_id=run_id, organization_id=organization_id):
raise HTTPException(status_code=404, detail=f"Run_id {run_id} not found")

script = await app.DATABASE.create_script(
script = await app.DATABASE.scripts.create_script(
organization_id=organization_id,
run_id=run_id,
)

@ -306,7 +306,7 @@ async def load_scripts(
# retrieve the artifact
if not file.artifact_id:
continue
artifact = await app.DATABASE.get_artifact_by_id(file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(file.artifact_id, organization_id)
if not artifact:
LOG.error("Artifact not found", artifact_id=file.artifact_id, script_id=script.script_id)
continue

@ -345,7 +345,7 @@ async def execute_script(
background_tasks: BackgroundTasks | None = None,
) -> None:
# step 1: get the script revision
script = await app.DATABASE.get_script(
script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=organization_id,
)

@ -353,7 +353,7 @@ async def execute_script(
raise ScriptNotFound(script_id=script_id)

# step 2: get the script files
script_files = await app.DATABASE.get_script_files(
script_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=script.script_revision_id, organization_id=organization_id
)

@ -362,7 +362,7 @@ async def execute_script(

# step 4: execute the script
if workflow_run_id and not parameters:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
LOG.info("Script run Parameters is using workflow run parameters", parameters=parameters)
@ -453,7 +453,7 @@ async def _create_workflow_block_run_and_task(
current_value_str = str(cv) if cv is not None else None
current_index_val = context.loop_metadata.get("current_index")

workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_block = await app.DATABASE.observer.create_workflow_run_block(
workflow_run_id=workflow_run_id,
parent_workflow_run_block_id=context.parent_workflow_run_block_id,
organization_id=organization_id,

@ -486,7 +486,7 @@ async def _create_workflow_block_run_and_task(
# Without this, the LLM only sees a generic navigation_goal and can
# falsely mark login as complete when credentials were never entered.
task_complete_criterion = DEFAULT_LOGIN_COMPLETE_CRITERION if block_type == BlockType.LOGIN else None
task = await app.DATABASE.create_task(
task = await app.DATABASE.tasks.create_task(
# fix HACK: changed the type of url to str | None to support None url. url is not used in the script right now.
url=url or "",
title=f"Script {block_type.value} task",

@ -508,7 +508,7 @@ async def _create_workflow_block_run_and_task(
task_id = task.task_id

# create a single step for the task
step = await app.DATABASE.create_step(
step = await app.DATABASE.tasks.create_step(
task_id=task_id,
order=0,
retry_index=0,

@ -525,7 +525,7 @@ async def _create_workflow_block_run_and_task(
)

# Update workflow run block with task_id
await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
task_id=task_id,
organization_id=organization_id,

@ -616,7 +616,7 @@ async def _record_output_parameter_value(
parameter=output_parameter,
value=output,
)
await app.DATABASE.create_or_update_workflow_run_output_parameter(
await app.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=output_parameter.output_parameter_id,
value=output,

@ -672,7 +672,7 @@ async def _update_workflow_block(
errors=errors,
)

await app.DATABASE.update_step(
await app.DATABASE.tasks.update_step(
step_id=step_id,
task_id=task_id,
organization_id=context.organization_id,

@ -680,7 +680,7 @@ async def _update_workflow_block(
is_last=is_last,
output=step_output,
)
updated_task = await app.DATABASE.update_task(
updated_task = await app.DATABASE.tasks.update_task(
task_id=task_id,
organization_id=context.organization_id,
status=task_status,

@ -719,7 +719,7 @@ async def _update_workflow_block(
final_output = task_output.model_dump()
step_for_billing: Step | None = None
if step_id:
step_for_billing = await app.DATABASE.get_step(
step_for_billing = await app.DATABASE.tasks.get_step(
step_id=step_id,
organization_id=context.organization_id,
)

@ -746,7 +746,7 @@ async def _update_workflow_block(
# final_output is already set to `output` at line 596.
pass

await app.DATABASE.update_workflow_run_block(
await app.DATABASE.observer.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=context.organization_id if context else None,
status=status,

@ -855,7 +855,7 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
return

try:
script_block = await app.DATABASE.get_script_block_by_label(
script_block = await app.DATABASE.scripts.get_script_block_by_label(
organization_id=context.organization_id,
script_revision_id=context.script_revision_id,
script_block_label=cache_key,

@ -873,7 +873,7 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
return

try:
source_block = await app.DATABASE.get_workflow_run_block(
source_block = await app.DATABASE.observer.get_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=context.organization_id,
)
@ -886,7 +886,9 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_

try:
# actions are ordered by created_at
actions = await app.DATABASE.get_task_actions_hydrated(task_id=task_id, organization_id=context.organization_id)
actions = await app.DATABASE.tasks.get_task_actions_hydrated(
task_id=task_id, organization_id=context.organization_id
)
except Exception:
return
@ -938,7 +940,7 @@ async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_
)
step = None
if step_id:
step = await app.DATABASE.get_step(step_id=step_id, organization_id=context.organization_id)
step = await app.DATABASE.tasks.get_step(step_id=step_id, organization_id=context.organization_id)
llm_response = await app.SCRIPT_GENERATION_LLM_API_HANDLER(
prompt=merged_prompt,
prompt_name="merged-block-inputs",

@ -1102,7 +1104,7 @@ async def _fallback_to_ai_run(
workflow_run_id=workflow_run_id,
)
# 1. fail the previous step
previous_step = await app.DATABASE.update_step(
previous_step = await app.DATABASE.tasks.update_step(
step_id=script_step_id,
task_id=task_id,
organization_id=organization_id,

@ -1112,7 +1114,7 @@ async def _fallback_to_ai_run(
organization = await app.DATABASE.organizations.get_organization(organization_id=organization_id)
if not organization:
raise Exception(f"Organization is missing organization_id={organization_id}")
task = await app.DATABASE.get_task(task_id=context.task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=context.task_id, organization_id=organization_id)
if not task:
raise Exception(f"Task is missing task_id={context.task_id}")
workflow = await app.DATABASE.workflows.get_workflow(

@ -1120,7 +1122,7 @@ async def _fallback_to_ai_run(
)
if not workflow:
return
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:

@ -1157,7 +1159,7 @@ async def _fallback_to_ai_run(
if detected_errors:
task_errors = task.errors or []
task_errors.extend([error.model_dump() for error in detected_errors])
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task_id=task_id,
organization_id=organization_id,
errors=task_errors,

@ -1241,7 +1243,7 @@ async def _fallback_to_ai_run(
# _fallback_to_ai_run is only called for TaskBlock-style blocks (navigation,
# extraction, action, login, download), never for ConditionalBlock, so
# fallback_type is always "full_block" here.
episode = await app.DATABASE.create_fallback_episode(
episode = await app.DATABASE.scripts.create_fallback_episode(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,

@ -1262,7 +1264,7 @@ async def _fallback_to_ai_run(
)

# 2. create a new step for ai run
ai_step = await app.DATABASE.create_step(
ai_step = await app.DATABASE.tasks.create_step(
task_id=task_id,
organization_id=organization_id,
order=previous_step.order + 1,

@ -1323,7 +1325,7 @@ async def _fallback_to_ai_run(

# update workflow run to indicate that there's a script run
if workflow_run_id:
await app.DATABASE.update_workflow_run(
await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
ai_fallback_triggered=True,
)

@ -1332,7 +1334,7 @@ async def _fallback_to_ai_run(
if workflow_run_block_id:
# refresh the task
failure_reason = None
refreshed_task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
refreshed_task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if refreshed_task:
task = refreshed_task
if task.status in [TaskStatus.terminated, TaskStatus.failed]:

@ -1389,7 +1391,7 @@ async def _fallback_to_ai_run(
if not fallback_succeeded and task.failure_reason:
agent_actions_summary["failure_reason"] = str(task.failure_reason)[:2000]
try:
actions = await app.DATABASE.get_task_actions(
actions = await app.DATABASE.tasks.get_task_actions(
task_id=task_id,
organization_id=organization_id,
)

@ -1397,7 +1399,7 @@ async def _fallback_to_ai_run(
except Exception:
LOG.debug("Could not fetch actions for fallback episode", exc_info=True)

await app.DATABASE.update_fallback_episode(
await app.DATABASE.scripts.update_fallback_episode(
episode_id=fallback_episode_id,
organization_id=organization_id,
agent_actions=agent_actions_summary,
@ -1468,7 +1470,9 @@ async def _regenerate_script_block_after_ai_fallback(
cache_key_value = ""
if workflow.cache_key:
try:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run_id
)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
cache_key_value = jinja_sandbox_env.from_string(workflow.cache_key).render(parameters)
except Exception as e:

@ -1479,7 +1483,7 @@ async def _regenerate_script_block_after_ai_fallback(
if not cache_key_value:
cache_key_value = cache_key  # Fallback

existing_script, _is_pinned = await app.DATABASE.get_workflow_script_by_cache_key_value(
existing_script, _is_pinned = await app.DATABASE.scripts.get_workflow_script_by_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=cache_key_value,

@ -1501,7 +1505,7 @@ async def _regenerate_script_block_after_ai_fallback(
)

# Create a new script version
new_script = await app.DATABASE.create_script(
new_script = await app.DATABASE.scripts.create_script(
organization_id=organization_id,
run_id=workflow_run_id,
script_id=current_script.script_id,  # Use same script_id for versioning

@ -1509,14 +1513,14 @@ async def _regenerate_script_block_after_ai_fallback(
)

# deprecate the current workflow script
await app.DATABASE.delete_workflow_cache_key_value(
await app.DATABASE.scripts.delete_workflow_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=cache_key_value,
)

# Create workflow script mapping for the new version
await app.DATABASE.create_workflow_script(
await app.DATABASE.scripts.create_workflow_script(
organization_id=organization_id,
script_id=new_script.script_id,
workflow_permanent_id=workflow.workflow_permanent_id,

@ -1527,7 +1531,7 @@ async def _regenerate_script_block_after_ai_fallback(
)

# Get all existing script blocks from the previous version
existing_script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
existing_script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=current_script.script_revision_id,
organization_id=organization_id,
)

@ -1553,7 +1557,7 @@ async def _regenerate_script_block_after_ai_fallback(
# Copy the existing block to the new version
# Get the script file content for this block and copy a new script block for it
if existing_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=current_script.script_revision_id,
file_id=existing_block.script_file_id,
organization_id=organization_id,

@ -1561,7 +1565,9 @@ async def _regenerate_script_block_after_ai_fallback(

if script_file and script_file.artifact_id:
# Retrieve the artifact content
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(
script_file.artifact_id, organization_id
)
if artifact:
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if file_content:

@ -1650,7 +1656,7 @@ async def _get_block_definition_by_label(
if not final_dump:
return None

task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
task = await app.DATABASE.tasks.get_task(task_id=task_id, organization_id=organization_id)
if task:
task_dump = task.model_dump()
final_dump.update({k: v for k, v in task_dump.items() if k not in final_dump})

@ -1681,7 +1687,7 @@ async def _generate_block_code_from_task(
return ""
try:
# Now regenerate only the specific block that fell back to AI
task_actions = await app.DATABASE.get_task_actions_hydrated(
task_actions = await app.DATABASE.tasks.get_task_actions_hydrated(
task_id=task_id,
organization_id=organization_id,
)
@ -2505,13 +2511,13 @@ async def run_script(
context.script_revision_id = script_revision_id

if workflow_run_id and organization_id:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
# update workflow run to indicate that there's a script run
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.update_workflow_run(
workflow_run_id=workflow_run_id,
ai_fallback_triggered=False,
)
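The AI-fallback path is spread across many hunks above; compressed into one sequence it is easier to follow. A sketch of the database choreography when a cached script block fails and the agent takes over (argument lists abbreviated to the fields shown in the hunks; the surrounding retry, episode-summary, and billing logic is omitted):

# Condensed outline of _fallback_to_ai_run's writes, in order. Sketch only.
from skyvern.forge import app  # assumed module-level application handle


async def fallback_outline(task_id, script_step_id, workflow_run_id, organization_id):
    # 1. fail the step that was executing cached script code
    previous_step = await app.DATABASE.tasks.update_step(
        step_id=script_step_id, task_id=task_id, organization_id=organization_id
    )
    # record a fallback episode for the script reviewer to triage later
    episode = await app.DATABASE.scripts.create_fallback_episode(
        organization_id=organization_id, workflow_run_id=workflow_run_id
    )
    # 2. create a fresh step for the AI-driven run
    ai_step = await app.DATABASE.tasks.create_step(
        task_id=task_id, organization_id=organization_id, order=previous_step.order + 1
    )
    # flag the workflow run so later code knows the agent took over
    await app.DATABASE.workflow_runs.update_workflow_run(
        workflow_run_id=workflow_run_id, ai_fallback_triggered=True
    )
    return episode, ai_step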
@ -11,14 +11,14 @@ async def is_cua_task(

if task.workflow_run_id:
# it's a task based block, should look up the block run to see if it's a CUA task
block = await app.DATABASE.get_workflow_run_block_by_task_id(
block = await app.DATABASE.observer.get_workflow_run_block_by_task_id(
task_id=task.task_id,
organization_id=task.organization_id,
)
if block.engine is not None and block.engine in CUA_ENGINES:
return True

run = await app.DATABASE.get_run(
run = await app.DATABASE.tasks.get_run(
run_id=task.task_id,
organization_id=task.organization_id,
)
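CUA detection checks the workflow block's engine before consulting the run record. The excerpt cuts off before the run-type branch, so the final check below is an assumption and is marked as such:

# Sketch; CUA_RUN_TYPES is a hypothetical stand-in for whatever the
# truncated run-type check actually compares against.
async def sketch_is_cua_task(task) -> bool:
    if task.workflow_run_id:
        block = await app.DATABASE.observer.get_workflow_run_block_by_task_id(
            task_id=task.task_id, organization_id=task.organization_id
        )
        if block.engine is not None and block.engine in CUA_ENGINES:
            return True
    run = await app.DATABASE.tasks.get_run(
        run_id=task.task_id, organization_id=task.organization_id
    )
    return bool(run and run.task_run_type in CUA_RUN_TYPES)  # assumed branch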
@ -25,11 +25,11 @@ async def generate_task(user_prompt: str, organization: Organization) -> TaskGen
user_prompt_hash = hash_object.hexdigest()
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
existing_task_generation = await app.DATABASE.workflow_params.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
)
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
new_task_generation = await app.DATABASE.workflow_params.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,

@ -53,7 +53,7 @@ async def generate_task(user_prompt: str, organization: Organization) -> TaskGen
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)

# generate a TaskGenerationModel
task_generation = await app.DATABASE.create_task_generation(
task_generation = await app.DATABASE.workflow_params.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,

@ -94,7 +94,7 @@ async def run_task(
run_type = RunType.anthropic_cua
elif engine == RunEngine.ui_tars:
run_type = RunType.ui_tars
await app.DATABASE.create_task_run(
await app.DATABASE.tasks.create_task_run(
task_run_type=run_type,
organization_id=organization.organization_id,
run_id=created_task.task_id,

@ -123,15 +123,15 @@ async def run_task(


async def get_task_v1_response(task_id: str, organization_id: str | None = None) -> TaskResponse:
task_obj = await app.DATABASE.get_task(task_id, organization_id=organization_id)
task_obj = await app.DATABASE.tasks.get_task(task_id, organization_id=organization_id)
if not task_obj:
raise TaskNotFound(task_id=task_id)

# get step count efficiently via COUNT query
step_count = await app.DATABASE.get_task_step_count(task_id, organization_id)
step_count = await app.DATABASE.tasks.get_task_step_count(task_id, organization_id)

# get latest step
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(task_id, organization_id=organization_id)
if not latest_step:
return await app.agent.build_task_response(task=task_obj, step_count=step_count)
@ -119,14 +119,14 @@ async def _summarize_max_steps_failure_reason(

screenshots = await SkyvernFrame.take_split_screenshots(page=page, url=str(task_v2.url), draw_boxes=False)

run_blocks = await app.DATABASE.get_workflow_run_blocks(
run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
workflow_run_id=task_v2.workflow_run_id,
organization_id=organization_id,
)

history = [f"{idx + 1}. {block.description} -- {block.status}" for idx, block in enumerate(run_blocks[::-1])]

thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=task_v2.workflow_run_id,

@ -197,7 +197,7 @@ async def _handle_task_v2_termination(
)

# Create a dedicated termination thought for UI visibility
termination_thought = await app.DATABASE.create_thought(
termination_thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2_id,
organization_id=organization_id,
workflow_run_id=workflow_run_id,

@ -216,7 +216,7 @@ async def _handle_task_v2_termination(
if source:
output["source"] = source

await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=termination_thought.observer_thought_id,
organization_id=organization_id,
output=output,

@ -265,7 +265,7 @@ async def initialize_task_v2(
browser_address: str | None = None,
run_with: str | None = None,
) -> TaskV2:
task_v2 = await app.DATABASE.create_task_v2(
task_v2 = await app.DATABASE.observer.create_task_v2(
prompt=user_prompt,
url=user_url if user_url else None,
organization_id=organization.organization_id,

@ -329,7 +329,7 @@ async def initialize_task_v2(

# update observer cruise
try:
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=new_workflow.workflow_id,

@ -337,7 +337,7 @@ async def initialize_task_v2(
organization_id=organization.organization_id,
)
if create_task_run:
await app.DATABASE.create_task_run(
await app.DATABASE.tasks.create_task_run(
task_run_type=RunType.task_v2,
organization_id=organization.organization_id,
run_id=task_v2.observer_cruise_id,

@ -369,7 +369,7 @@ async def initialize_task_v2_metadata(
current_browser_url: str | None,
user_url: str | None,
) -> TaskV2:
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=organization.organization_id,
thought_type=ThoughtType.metadata,

@ -403,7 +403,7 @@ async def initialize_task_v2_metadata(
raise UrlGenerationFailure()

try:
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=organization.organization_id,
workflow_run_id=workflow_run.workflow_run_id,

@ -422,7 +422,7 @@ async def initialize_task_v2_metadata(
organization_id=organization.organization_id,
title=metadata.workflow_title,
)
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,

@ -430,11 +430,11 @@ async def initialize_task_v2_metadata(
url=metadata.url,
organization_id=organization.organization_id,
)
task_run = await app.DATABASE.get_run(
task_run = await app.DATABASE.tasks.get_run(
run_id=task_v2.observer_cruise_id, organization_id=organization.organization_id
)
if task_run:
await app.DATABASE.update_task_run(
await app.DATABASE.tasks.update_task_run(
organization_id=organization.organization_id,
run_id=task_v2.observer_cruise_id,
title=metadata.workflow_title,
@ -465,7 +465,7 @@ async def run_task_v2(
) -> TaskV2:
organization_id = organization.organization_id
try:
task_v2 = await app.DATABASE.get_task_v2(task_v2_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(task_v2_id, organization_id=organization_id)
except Exception:
LOG.error(
"Failed to get task v2",

@ -643,7 +643,7 @@ async def run_task_v2_helper(
)
)

task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2_id, organization_id=organization_id, status=TaskV2Status.running
)
await app.WORKFLOW_SERVICE.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)

@ -802,7 +802,7 @@ async def run_task_v2_helper(
task_history=task_history,
local_datetime=datetime.now(context.tz_info).isoformat(),
)
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2_id,
organization_id=organization_id,
workflow_run_id=workflow_run.workflow_run_id,

@ -837,7 +837,7 @@ async def run_task_v2_helper(
plan = task_v2_response.get("plan", "")
task_type = task_v2_response.get("task_type", "")
# Create and save task thought
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=organization_id,
thought=thoughts,

@ -1055,7 +1055,7 @@ async def run_task_v2_helper(
task_history=task_history,
local_datetime=datetime.now(context.tz_info).isoformat(),
)
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2_id,
organization_id=organization_id,
workflow_run_id=workflow_run_id,

@ -1082,7 +1082,7 @@ async def run_task_v2_helper(
termination_reason = completion_resp.get("termination_reason")
completion_failure_categories = completion_resp.get("failure_categories")
thought_content = completion_resp.get("thoughts", "")
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=organization_id,
thought=thought_content,

@ -1128,8 +1128,8 @@ async def run_task_v2_helper(
return workflow, workflow_run, task_v2

# total step number validation
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
total_step_count = await app.DATABASE.get_total_unique_step_order_count_by_task_ids(
workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
total_step_count = await app.DATABASE.tasks.get_total_unique_step_order_count_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks],
organization_id=organization_id,
)

@ -1314,7 +1314,7 @@ async def _generate_loop_task(
plan=plan,
)
data_extraction_thought = f"Going to generate a list of values to go through based on the plan: {plan}."
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=workflow_run_id,

@ -1383,7 +1383,7 @@ async def _generate_loop_task(
raise

# update the thought
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=task_v2.organization_id,
output=output_value_obj,
@ -1441,7 +1441,7 @@ async def _generate_loop_task(
is_link=is_loop_value_link,
loop_values=loop_values,
)
thought_task_in_loop = await app.DATABASE.create_thought(
thought_task_in_loop = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=workflow_run_id,

@ -1461,7 +1461,7 @@ async def _generate_loop_task(
data_extraction_goal = task_in_loop_metadata_response.get("data_extraction_goal")
data_extraction_schema = task_in_loop_metadata_response.get("data_schema")
thought_content = task_in_loop_metadata_response.get("thoughts")
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought_task_in_loop.observer_thought_id,
organization_id=task_v2.organization_id,
thought=thought_content,

@ -1675,7 +1675,7 @@ async def _generate_goto_url_task(


async def get_thought_timelines(*, task_v2_id: str, organization_id: str) -> list[WorkflowRunTimeline]:
thoughts = await app.DATABASE.get_thoughts(
thoughts = await app.DATABASE.observer.get_thoughts(
task_v2_id=task_v2_id,
organization_id=organization_id,
thought_types=[

@ -1695,7 +1695,7 @@ async def get_thought_timelines(*, task_v2_id: str, organization_id: str) -> lis


async def get_task_v2(task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
return await app.DATABASE.get_task_v2(task_v2_id, organization_id=organization_id)
return await app.DATABASE.observer.get_task_v2(task_v2_id, organization_id=organization_id)


async def _update_task_v2_status(

@ -1706,7 +1706,7 @@ async def _update_task_v2_status(
output: dict[str, Any] | None = None,
failure_category: list[dict] | None = None,
) -> TaskV2:
task_v2 = await app.DATABASE.update_task_v2(
task_v2 = await app.DATABASE.observer.update_task_v2(
task_v2_id,
organization_id=organization_id,
status=status,

@ -1939,7 +1939,7 @@ async def _summarize_task_v2(
context: SkyvernContext,
screenshots: list[bytes] | None = None,
) -> TaskV2:
thought = await app.DATABASE.create_thought(
thought = await app.DATABASE.observer.create_thought(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
workflow_run_id=task_v2.workflow_run_id,

@ -1966,7 +1966,7 @@ async def _summarize_task_v2(

summary_description = task_v2_summary_resp.get("description")
summarized_output = task_v2_summary_resp.get("output")
await app.DATABASE.update_thought(
await app.DATABASE.observer.update_thought(
thought_id=thought.observer_thought_id,
organization_id=task_v2.organization_id,
thought=summary_description,

@ -2081,7 +2081,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
webhook_failure_reason="",

@ -2094,7 +2094,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None:
resp_code=resp.status_code,
resp_text=resp.text,
)
await app.DATABASE.update_task_v2(
await app.DATABASE.observer.update_task_v2(
task_v2_id=task_v2.observer_cruise_id,
organization_id=task_v2.organization_id,
webhook_failure_reason=f"Webhook failed with status code {resp.status_code}, error message: {resp.text}",
@@ -267,13 +267,13 @@ async def replay_run_webhook(


async def _build_webhook_payload(organization_id: str, run_id: str) -> _WebhookPayload:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
run = await app.DATABASE.tasks.get_run(run_id, organization_id=organization_id)
if not run:
# Attempt to resolve task v2 runs that may not yet be in the runs table.
task_v2 = await app.DATABASE.get_task_v2(run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run_id, organization_id=organization_id)
if task_v2:
return await _build_task_v2_payload(task_v2)
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=run_id,
organization_id=organization_id,
)

@@ -300,7 +300,7 @@ async def _build_webhook_payload(organization_id: str, run_id: str) -> _WebhookP
run_type_str=run_type,
)
if run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(run.run_id, organization_id=organization_id)
task_v2 = await app.DATABASE.observer.get_task_v2(run.run_id, organization_id=organization_id)
if not task_v2:
raise SkyvernHTTPException(
f"Task v2 run {run_id} missing task record",

@@ -314,7 +314,7 @@ async def _build_webhook_payload(organization_id: str, run_id: str) -> _WebhookP


async def _build_task_payload(organization_id: str, run_id: str, run_type_str: str) -> _WebhookPayload:
task: Task | None = await app.DATABASE.get_task(run_id, organization_id=organization_id)
task: Task | None = await app.DATABASE.tasks.get_task(run_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=run_id)
if not task.status.is_final():

@@ -324,7 +324,7 @@ async def _build_task_payload(organization_id: str, run_id: str, run_type_str: s
status=task.status,
)
raise WebhookReplayError(f"Run {run_id} has not reached a terminal state (status={task.status}).")
latest_step = await app.DATABASE.get_latest_step(run_id, organization_id=organization_id)
latest_step = await app.DATABASE.tasks.get_latest_step(run_id, organization_id=organization_id)
task_response = await app.agent.build_task_response(task=task, last_step=latest_step)

payload_dict = json.loads(task_response.model_dump_json(exclude={"request"}))

@@ -382,7 +382,7 @@ async def _build_workflow_payload(
organization_id: str,
workflow_run_id: str,
) -> _WebhookPayload:
workflow_run: WorkflowRun | None = await app.DATABASE.get_workflow_run(
workflow_run: WorkflowRun | None = await app.DATABASE.workflow_runs.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@@ -147,10 +147,12 @@ async def generate_or_update_pending_workflow_script(
script_id = context.script_id
script = None
if script_id:
script = await app.DATABASE.get_script(script_id=script_id, organization_id=organization_id)
script = await app.DATABASE.scripts.get_script(script_id=script_id, organization_id=organization_id)

if not script:
script = await app.DATABASE.create_script(organization_id=organization_id, run_id=workflow_run.workflow_run_id)
script = await app.DATABASE.scripts.create_script(
organization_id=organization_id, run_id=workflow_run.workflow_run_id
)
if context:
context.script_id = script.script_id
context.script_revision_id = script.script_revision_id

@@ -184,7 +186,7 @@ async def get_workflow_script(
rendered_cache_key_value = ""

try:
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
parameter_tuples = await app.DATABASE.workflow_runs.get_workflow_run_parameters(
workflow_run_id=workflow_run.workflow_run_id,
)
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}

@@ -296,7 +298,7 @@ async def get_workflow_script_by_cache_key_value(
return _workflow_script_cache[cache_key_tuple]

# Cache miss - fetch from database
script, is_pinned = await app.DATABASE.get_workflow_script_by_cache_key_value(
script, is_pinned = await app.DATABASE.scripts.get_workflow_script_by_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=cache_key_value,

@@ -311,7 +313,7 @@ async def get_workflow_script_by_cache_key_value(

return script, is_pinned

return await app.DATABASE.get_workflow_script_by_cache_key_value(
return await app.DATABASE.scripts.get_workflow_script_by_cache_key_value(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
cache_key_value=cache_key_value,

@@ -331,7 +333,7 @@ async def get_latest_published_script(
variants), this returns the script with the highest version number to ensure
the most recently reviewed code is selected.
"""
workflow_scripts = await app.DATABASE.get_workflow_scripts_by_permanent_id(
workflow_scripts = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
statuses=[ScriptStatus.published],

@@ -344,7 +346,7 @@ async def get_latest_published_script(
# TODO: add a bulk get_latest_script_versions() if this becomes a bottleneck.
best: Script | None = None
for ws in workflow_scripts:
script = await app.DATABASE.get_latest_script_version(
script = await app.DATABASE.scripts.get_latest_script_version(
script_id=ws.script_id,
organization_id=organization_id,
)

@@ -362,7 +364,7 @@ async def _load_cached_script_block_sources(
"""
cached_blocks: dict[str, ScriptBlockSource] = {}

script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)

@@ -373,13 +375,13 @@ async def _load_cached_script_block_sources(

code_str: str | None = None
if script_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_file = await app.DATABASE.scripts.get_script_file_by_id(
script_revision_id=script.script_revision_id,
file_id=script_block.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if isinstance(file_content, bytes):

@@ -539,7 +541,7 @@ async def generate_workflow_script(
status = ScriptStatus.published
if pending:
status = ScriptStatus.pending
existing_pending_workflow_script = await app.DATABASE.get_workflow_script(
existing_pending_workflow_script = await app.DATABASE.scripts.get_workflow_script(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run.workflow_run_id,

@@ -547,7 +549,7 @@ async def generate_workflow_script(
)
if not existing_pending_workflow_script:
# Record the workflow->script mapping for cache lookup
await app.DATABASE.create_workflow_script(
await app.DATABASE.scripts.create_workflow_script(
organization_id=workflow.organization_id,
script_id=script.script_id,
workflow_permanent_id=workflow.workflow_permanent_id,

@@ -965,13 +967,13 @@ async def _find_main_py_content(script_id: str, organization_id: str, base_revis
"""

async def _load_main_py_from_revision(revision_id: str) -> str | None:
files = await app.DATABASE.get_script_files(
files = await app.DATABASE.scripts.get_script_files(
script_revision_id=revision_id,
organization_id=organization_id,
)
for f in files:
if f.file_path == "main.py" and f.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(f.artifact_id, organization_id)
artifact = await app.DATABASE.artifacts.get_artifact_by_id(f.artifact_id, organization_id)
if artifact:
content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if content:

@@ -984,7 +986,7 @@ async def _find_main_py_content(script_id: str, organization_id: str, base_revis
return result

# Fall back to v1 (bootstrapping: base was created before this fix)
v1_script = await app.DATABASE.get_script(
v1_script = await app.DATABASE.scripts.get_script(
script_id=script_id,
organization_id=organization_id,
version=1,

@@ -1022,7 +1024,7 @@ async def create_script_version_from_review(
# _trigger_script_reviewer() already gates on is_script_pinned(), but
# that check can be bypassed when the skyvern context is missing.
# Guard here so no code path can mutate a pinned script.
if await app.DATABASE.is_script_pinned(
if await app.DATABASE.scripts.is_script_pinned(
organization_id=organization_id,
script_id=base_script.script_id,
):

@@ -1035,7 +1037,7 @@ async def create_script_version_from_review(
return None

# Create a new script version
new_script = await app.DATABASE.create_script(
new_script = await app.DATABASE.scripts.create_script(
organization_id=organization_id,
script_id=base_script.script_id,
version=base_script.version + 1,

@@ -1043,7 +1045,7 @@ async def create_script_version_from_review(
)

# Copy existing script blocks from the base revision
existing_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
existing_blocks = await app.DATABASE.scripts.get_script_blocks_by_script_revision_id(
script_revision_id=base_script.script_revision_id,
organization_id=organization_id,
)

@@ -1066,7 +1068,7 @@ async def create_script_version_from_review(
file_path=file_path,
data=content_bytes,
)
new_file = await app.DATABASE.create_script_file(
new_file = await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,

@@ -1090,7 +1092,7 @@ async def create_script_version_from_review(
)

# Create script block entry pointing to the new file
await app.DATABASE.create_script_block(
await app.DATABASE.scripts.create_script_block(
organization_id=organization_id,
script_id=new_script.script_id,
script_revision_id=new_script.script_revision_id,

@@ -1103,7 +1105,7 @@ async def create_script_version_from_review(
)
else:
# Copy existing block as-is
await app.DATABASE.create_script_block(
await app.DATABASE.scripts.create_script_block(
organization_id=organization_id,
script_id=new_script.script_id,
script_revision_id=new_script.script_revision_id,

@@ -1136,7 +1138,7 @@ async def create_script_version_from_review(
file_path="main.py",
data=patched_bytes,
)
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,

@@ -1158,19 +1160,19 @@ async def create_script_version_from_review(

# Copy non-block files (e.g., .skyvern metadata) from the base revision
# or v1 — whichever has the full file set
source_files = await app.DATABASE.get_script_files(
source_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=base_script.script_revision_id,
organization_id=organization_id,
)
if not any(f.file_path != "main.py" and not f.file_path.startswith("blocks/") for f in source_files):
# Base revision has no non-block files, fall back to v1
v1_script = await app.DATABASE.get_script(
v1_script = await app.DATABASE.scripts.get_script(
script_id=base_script.script_id,
organization_id=organization_id,
version=1,
)
if v1_script:
source_files = await app.DATABASE.get_script_files(
source_files = await app.DATABASE.scripts.get_script_files(
script_revision_id=v1_script.script_revision_id,
organization_id=organization_id,
)

@@ -1180,7 +1182,7 @@ async def create_script_version_from_review(
# Skip main.py (already patched) and updated block files (already created)
if f.file_path == "main.py" or f.file_path in updated_block_file_paths:
continue
await app.DATABASE.create_script_file(
await app.DATABASE.scripts.create_script_file(
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,

@@ -1208,7 +1210,7 @@ async def create_script_version_from_review(
)
else:
# No workflow run — look up the existing cache key value from the base script
existing_ws = await app.DATABASE.get_workflow_scripts_by_permanent_id(
existing_ws = await app.DATABASE.scripts.get_workflow_scripts_by_permanent_id(
organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id,
statuses=[ScriptStatus.published],

@@ -1219,7 +1221,7 @@ async def create_script_version_from_review(
rendered_cache_key_value = ws.cache_key_value
break

await app.DATABASE.create_workflow_script(
await app.DATABASE.scripts.create_workflow_script(
organization_id=organization_id,
script_id=new_script.script_id,
workflow_permanent_id=workflow_permanent_id,
@@ -55,7 +55,7 @@ async def prepare_workflow(
version=version,
)

await app.DATABASE.create_task_run(
await app.DATABASE.tasks.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=organization.organization_id,
run_id=workflow_run.workflow_run_id,

@@ -120,7 +120,7 @@ async def run_workflow(
async def get_workflow_run_response(
workflow_run_id: str, organization_id: str | None = None
) -> WorkflowRunResponse | None:
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id, organization_id=organization_id)
workflow_run = await app.DATABASE.workflow_runs.get_workflow_run(workflow_run_id, organization_id=organization_id)
if not workflow_run:
return None
workflow_run_resp = await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id(
@@ -24,14 +24,14 @@ async def _retrieve_action_plan(task: Task, step: Step, scraped_page: ScrapedPag
# V0: use the previous action plan if there is a completed task with the same url and navigation goal
# get completed task with the same url and navigation goal
# TODO(kerem): don't use step_order, get all the previous actions instead
cached_actions = await app.DATABASE.retrieve_action_plan(task=task)
cached_actions = await app.DATABASE.workflow_params.retrieve_action_plan(task=task)
if not cached_actions:
LOG.info("No cached actions found for the task, fallback to no-cache mode")
return []

# Get the existing actions for this task from the database. Then find the actions that are already executed by looking at
# the source_action_id field for this task's actions.
previous_actions = await app.DATABASE.get_previous_actions_for_task(task_id=task.task_id)
previous_actions = await app.DATABASE.tasks.get_previous_actions_for_task(task_id=task.task_id)

executed_cached_actions = []
remaining_cached_actions = []
@@ -409,7 +409,7 @@ class ActionHandler:
page=page,
action=action,
)
persisted_action = await app.DATABASE.create_action(action=action)
persisted_action = await app.DATABASE.workflow_params.create_action(action=action)
action.action_id = persisted_action.action_id
return results

@@ -545,7 +545,7 @@ class ActionHandler:
exc_info=True,
)

persisted_action = await app.DATABASE.create_action(action=action)
persisted_action = await app.DATABASE.workflow_params.create_action(action=action)
action.action_id = persisted_action.action_id

@staticmethod
@@ -847,7 +847,7 @@ async def generate_cua_fallback_actions(
assistant_message=assistant_message,
reasoning=reasoning,
)
await app.DATABASE.update_task(
await app.DATABASE.tasks.update_task(
task.task_id,
organization_id=task.organization_id,
extracted_information=assistant_message,
@@ -45,7 +45,7 @@ async def validate_session_for_renewal(
Validate a specific browser session for renewal. Otherwise raise.
"""

browser_session = await database.get_persistent_browser_session(
browser_session = await database.browser_sessions.get_persistent_browser_session(
session_id=session_id,
organization_id=organization_id,
)

@@ -117,7 +117,7 @@ async def renew_session(database: AgentDB, session_id: str, organization_id: str
minutes_diff = floor((new_timeout_datetime - current_timeout_datetime).total_seconds() / 60)
new_timeout_minutes = current_timeout_minutes + minutes_diff

browser_session = await database.update_persistent_browser_session(
browser_session = await database.browser_sessions.update_persistent_browser_session(
session_id,
organization_id=organization_id,
timeout_minutes=new_timeout_minutes,

@@ -138,7 +138,7 @@ async def renew_session(database: AgentDB, session_id: str, organization_id: str
async def update_status(
db: AgentDB, session_id: str, organization_id: str, status: str
) -> PersistentBrowserSession | None:
persistent_browser_session = await db.get_persistent_browser_session(session_id, organization_id)
persistent_browser_session = await db.browser_sessions.get_persistent_browser_session(session_id, organization_id)

if not persistent_browser_session:
LOG.warning(

@@ -167,7 +167,7 @@ async def update_status(
)

completed_at = datetime.now(timezone.utc) if is_final_status(status) else None
persistent_browser_session = await db.update_persistent_browser_session(
persistent_browser_session = await db.browser_sessions.update_persistent_browser_session(
session_id,
status=status,
organization_id=organization_id,

@@ -215,7 +215,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):

LOG.info("Begin browser session", browser_session_id=browser_session_id)

persistent_browser_session = await self.database.get_persistent_browser_session(
persistent_browser_session = await self.database.browser_sessions.get_persistent_browser_session(
browser_session_id, organization_id
)

@@ -246,11 +246,13 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
self, runnable_id: str, organization_id: str
) -> PersistentBrowserSession | None:
"""Get a specific browser session by runnable ID."""
return await self.database.get_persistent_browser_session_by_runnable_id(runnable_id, organization_id)
return await self.database.browser_sessions.get_persistent_browser_session_by_runnable_id(
runnable_id, organization_id
)

async def get_active_sessions(self, organization_id: str) -> list[PersistentBrowserSession]:
"""Get all active sessions for an organization."""
return await self.database.get_active_persistent_browser_sessions(organization_id)
return await self.database.browser_sessions.get_active_persistent_browser_sessions(organization_id)

async def get_browser_state(self, session_id: str, organization_id: str | None = None) -> BrowserState | None:
"""Get a specific browser session's state by session ID."""

@@ -263,7 +265,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):

async def get_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession | None:
"""Get a specific browser session by session ID."""
return await self.database.get_persistent_browser_session(session_id, organization_id)
return await self.database.browser_sessions.get_persistent_browser_session(session_id, organization_id)

async def create_session(
self,

@@ -283,7 +285,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"Creating new browser session",
organization_id=organization_id,
)
session = await self.database.create_persistent_browser_session(
session = await self.database.browser_sessions.create_persistent_browser_session(
organization_id=organization_id,
runnable_type=runnable_type,
runnable_id=runnable_id,

@@ -350,7 +352,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
await browser_state.close()
return
# Set started_at so renewal knows the browser is live
await self.database.update_persistent_browser_session(
await self.database.browser_sessions.update_persistent_browser_session(
session_id,
organization_id=organization_id,
started_at=datetime.now(timezone.utc),

@@ -375,7 +377,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
organization_id: str,
) -> None:
"""Occupy a specific browser session."""
await self.database.occupy_persistent_browser_session(
await self.database.browser_sessions.occupy_persistent_browser_session(
session_id=session_id,
runnable_type=runnable_type,
runnable_id=runnable_id,

@@ -410,7 +412,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):

async def release_browser_session(self, session_id: str, organization_id: str) -> None:
"""Release a specific browser session."""
await self.database.release_persistent_browser_session(session_id, organization_id)
await self.database.browser_sessions.release_persistent_browser_session(session_id, organization_id)

async def close_session(self, organization_id: str, browser_session_id: str) -> None:
"""Close a specific browser session."""

@@ -491,13 +493,13 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
session_id=browser_session_id,
)

await self.database.close_persistent_browser_session(browser_session_id, organization_id)
await self.database.browser_sessions.close_persistent_browser_session(browser_session_id, organization_id)
if settings.BROWSER_STREAMING_MODE == "cdp":
await self.database.archive_browser_session_address(browser_session_id, organization_id)
await self.database.browser_sessions.archive_browser_session_address(browser_session_id, organization_id)

async def close_all_sessions(self, organization_id: str) -> None:
"""Close all browser sessions for an organization."""
browser_sessions = await self.database.get_active_persistent_browser_sessions(organization_id)
browser_sessions = await self.database.browser_sessions.get_active_persistent_browser_sessions(organization_id)
for browser_session in browser_sessions:
await self.close_session(organization_id, browser_session.persistent_browser_session_id)

@@ -505,17 +507,17 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"""Close sessions left active by a previous process."""
if settings.BROWSER_STREAMING_MODE != "cdp":
return
stale_sessions = await self.database.get_uncompleted_persistent_browser_sessions()
stale_sessions = await self.database.browser_sessions.get_uncompleted_persistent_browser_sessions()
for db_session in stale_sessions:
LOG.info(
"Closing stale browser session from previous run",
session_id=db_session.persistent_browser_session_id,
organization_id=db_session.organization_id,
)
await self.database.close_persistent_browser_session(
await self.database.browser_sessions.close_persistent_browser_session(
db_session.persistent_browser_session_id, db_session.organization_id
)
await self.database.archive_browser_session_address(
await self.database.browser_sessions.archive_browser_session_address(
db_session.persistent_browser_session_id, db_session.organization_id
)

@@ -524,7 +526,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"""Close all browser sessions across all organizations."""
LOG.info("Closing PersistentSessionsManager")
if cls.instance:
active_sessions = await cls.instance.database.get_all_active_persistent_browser_sessions()
active_sessions = await cls.instance.database.browser_sessions.get_all_active_persistent_browser_sessions()
for db_session in active_sessions:
await cls.instance.close_session(db_session.organization_id, db_session.persistent_browser_session_id)
LOG.info("PersistentSessionsManager is closed")
@@ -22,21 +22,21 @@ def create_forge_stub_app() -> ForgeApp:
fake_app_module.AGENT_FUNCTION.validate_block_execution = AsyncMock()
fake_app_module.AGENT_FUNCTION.validate_code_block = AsyncMock()
fake_app_module.agent = _LazyNamespace()
fake_app_module.DATABASE.update_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.get_last_task_for_workflow_run = AsyncMock()
fake_app_module.DATABASE.get_workflow_run = AsyncMock()
fake_app_module.DATABASE.get_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.get_task = AsyncMock()
fake_app_module.DATABASE.update_task = AsyncMock()
fake_app_module.DATABASE.update_task_v2 = AsyncMock()
fake_app_module.DATABASE.organizations = _LazyNamespace()
fake_app_module.DATABASE.workflows = _LazyNamespace()
fake_app_module.DATABASE.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.update_workflow_run = AsyncMock()
fake_app_module.DATABASE.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.update_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.observer.update_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.observer.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.tasks.get_last_task_for_workflow_run = AsyncMock()
fake_app_module.DATABASE.workflow_runs.get_workflow_run = AsyncMock()
fake_app_module.DATABASE.observer.get_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.tasks.get_task = AsyncMock()
fake_app_module.DATABASE.tasks.update_task = AsyncMock()
fake_app_module.DATABASE.observer.update_task_v2 = AsyncMock()
fake_app_module.DATABASE.organizations.get_organization = AsyncMock()
fake_app_module.DATABASE.workflows.get_workflow = AsyncMock()
fake_app_module.DATABASE.observer.create_workflow_run_block = AsyncMock()
fake_app_module.DATABASE.workflow_runs.update_workflow_run = AsyncMock()
fake_app_module.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter = AsyncMock()
fake_app_module.DATABASE.observer.update_workflow_run_block = AsyncMock()
fake_app_module.LLM_API_HANDLER = AsyncMock()
fake_app_module.SECONDARY_LLM_API_HANDLER = AsyncMock()
fake_app_module.AUTO_COMPLETION_LLM_API_HANDLER = AsyncMock()
@@ -196,7 +196,10 @@ def test_agent_db_has_typed_repo_attributes():
assert isinstance(db.tasks, TasksRepository)
assert isinstance(db.credentials, CredentialRepository)
assert hasattr(db, "get_task") # backward compat delegate
assert hasattr(db, "workflows")
# Migrated domains no longer have delegates on AgentDB:
assert not hasattr(db, "create_workflow")
assert not hasattr(db, "get_organization")
assert not hasattr(db, "get_credential")


def test_agent_db_delegates_route_to_repositories():
@@ -242,10 +242,10 @@ def setup_parallel_verification_mocks(
extract_action: Any | None = None,
) -> ParallelVerificationMocks:
create_step_mock = AsyncMock(return_value=next_step)
monkeypatch.setattr(app.DATABASE, "create_step", create_step_mock)
monkeypatch.setattr(app.DATABASE.tasks, "create_step", create_step_mock)

get_task_steps_mock = AsyncMock(return_value=[step])
monkeypatch.setattr(app.DATABASE, "get_task_steps", get_task_steps_mock)
monkeypatch.setattr(app.DATABASE.tasks, "get_task_steps", get_task_steps_mock)

sleep_mock = AsyncMock(return_value=None)
monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", sleep_mock)
@@ -53,7 +53,7 @@ def mock_app():
mock = MagicMock()
mock.SINGLE_CLICK_AGENT_LLM_API_HANDLER = AsyncMock(return_value={"actions": []})
mock.DATABASE = MagicMock()
mock.DATABASE.get_step = AsyncMock(return_value=MagicMock())
mock.DATABASE.tasks.get_step = AsyncMock(return_value=MagicMock())
return mock

@@ -41,8 +41,8 @@ async def test_cached_content_removed_from_non_extract_prompts() -> None:
# Ensure app dependencies referenced inside the handler resolve to async mocks.
forge_module.app.ARTIFACT_MANAGER = MagicMock()
forge_module.app.DATABASE = MagicMock()
forge_module.app.DATABASE.update_step = AsyncMock()
forge_module.app.DATABASE.update_thought = AsyncMock()
forge_module.app.DATABASE.tasks.update_step = AsyncMock()
forge_module.app.DATABASE.observer.update_thought = AsyncMock()

with (
patch("skyvern.forge.sdk.api.llm.api_handler_factory.LLMConfigRegistry.get_config", return_value=mock_config),
@@ -109,9 +109,9 @@ async def test_batch_actions_preserve_per_task_ordering() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block_a, run_block_b])
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task_a, mock_task_b])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=all_actions_descending)
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(return_value=[run_block_a, run_block_b])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task_a, mock_task_b])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=all_actions_descending)

result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")

@@ -187,9 +187,9 @@ async def test_batch_actions_without_reverse_would_be_wrong() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=actions_descending)
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=actions_descending)

result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")

@@ -250,9 +250,9 @@ async def test_batch_actions_preserve_none_element_id() -> None:
):
mock_get_wfr.return_value = mock_workflow_run_resp
mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[action_extract])
mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(return_value=[run_block])
mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=[action_extract])

result = await transform_workflow_run_to_code_gen_input(workflow_run_id="wr_test", organization_id="org_test")
@@ -99,7 +99,7 @@ class TestScreenshotAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager

await ScriptSkyvernPage._create_screenshot_after_execution()

@@ -132,7 +132,7 @@ class TestScreenshotAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager

await ScriptSkyvernPage._create_screenshot_after_execution()

@@ -170,7 +170,7 @@ class TestHtmlActionAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager

mock_frame = MagicMock()

@@ -210,7 +210,7 @@ class TestHtmlActionAfterExecutionBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager

mock_frame = MagicMock()

@@ -249,7 +249,7 @@ class TestFinalScreenshotBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager

await ScriptSkyvernPage._create_final_screenshot()

@@ -283,7 +283,7 @@ class TestFinalScreenshotBundling:
):
mock_ctx.ensure_context.return_value = context
mock_get_bs.return_value = mock_browser_state
mock_app.DATABASE.get_step = AsyncMock(return_value=step)
mock_app.DATABASE.tasks.get_step = AsyncMock(return_value=step)
mock_app.ARTIFACT_MANAGER = mock_manager

await ScriptSkyvernPage._create_final_screenshot()

@@ -317,9 +317,9 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_step = AsyncMock()
mock_app.DATABASE.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.tasks.update_step = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.STORAGE.get_downloaded_files = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_task_screenshot_artifacts = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_workflow_screenshot_artifacts = AsyncMock(return_value=[])

@@ -348,9 +348,9 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_step = AsyncMock()
mock_app.DATABASE.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.tasks.update_step = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.STORAGE.get_downloaded_files = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_task_screenshot_artifacts = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_workflow_screenshot_artifacts = AsyncMock(return_value=[])

@@ -379,7 +379,7 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.WORKFLOW_SERVICE.send_workflow_response = AsyncMock()
mock_run_ctx.get_run_context.return_value = None

@@ -407,9 +407,9 @@ class TestUpdateWorkflowBlockFlush:
):
mock_ctx.current.return_value = context
mock_app.ARTIFACT_MANAGER = mock_manager
mock_app.DATABASE.update_step = AsyncMock()
mock_app.DATABASE.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.update_workflow_run_block = AsyncMock()
mock_app.DATABASE.tasks.update_step = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=MagicMock(extracted_information=None))
mock_app.DATABASE.observer.update_workflow_run_block = AsyncMock()
mock_app.STORAGE.get_downloaded_files = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_task_screenshot_artifacts = AsyncMock(return_value=[])
mock_app.WORKFLOW_SERVICE.get_recent_workflow_screenshot_artifacts = AsyncMock(return_value=[])
@@ -1,11 +1,10 @@
"""Tests that update_credential() accepts user_context and save_browser_session_intent
on both CredentialRepository and CredentialsMixin."""
on CredentialRepository."""

from unittest.mock import MagicMock, patch

import pytest

from skyvern.forge.sdk.db.mixins.credentials import CredentialsMixin
from skyvern.forge.sdk.db.repositories.credentials import CredentialRepository
from tests.unit.conftest import MockAsyncSessionCtx, make_mock_session

@@ -15,13 +14,6 @@ def _make_credential_repo(mock_credential: MagicMock) -> CredentialRepository:
return CredentialRepository(session_factory=lambda: MockAsyncSessionCtx(mock_session))


def _make_credential_mixin(mock_credential: MagicMock) -> CredentialsMixin:
mock_session = make_mock_session(mock_credential)
mixin = CredentialsMixin.__new__(CredentialsMixin)
mixin.Session = lambda: MockAsyncSessionCtx(mock_session) # type: ignore[assignment]
return mixin


# --- CredentialRepository tests ---

@@ -75,40 +67,3 @@ async def test_repo_update_credential_unset_params_not_applied() -> None:

assert mock_credential.user_context == "existing"
assert mock_credential.save_browser_session_intent is True


# --- CredentialsMixin tests ---


@pytest.mark.asyncio
async def test_mixin_update_credential_accepts_user_context() -> None:
mock_credential = MagicMock()
mock_credential.name = "test"
mock_credential.user_context = None
mixin = _make_credential_mixin(mock_credential)

with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()):
await mixin.update_credential(
credential_id="cred_123",
organization_id="org_123",
user_context="Click SSO button first",
)

assert mock_credential.user_context == "Click SSO button first"


@pytest.mark.asyncio
async def test_mixin_update_credential_accepts_save_browser_session_intent() -> None:
mock_credential = MagicMock()
mock_credential.name = "test"
mock_credential.save_browser_session_intent = False
mixin = _make_credential_mixin(mock_credential)

with patch("skyvern.forge.sdk.schemas.credentials.Credential.model_validate", return_value=MagicMock()):
await mixin.update_credential(
credential_id="cred_123",
organization_id="org_123",
save_browser_session_intent=True,
)

assert mock_credential.save_browser_session_intent is True
@@ -444,7 +444,7 @@ async def test_handle_action_navigates_back_from_blank_page_after_download() ->

mock_app = MagicMock()
mock_app.BROWSER_MANAGER.get_for_task.return_value = browser_state
mock_app.DATABASE.create_action = AsyncMock(return_value=action)
mock_app.DATABASE.workflow_params.create_action = AsyncMock(return_value=action)
mock_app.STORAGE = MagicMock()

with (

@@ -517,7 +517,7 @@ async def test_handle_action_does_not_navigate_back_when_page_url_unchanged() ->

mock_app = MagicMock()
mock_app.BROWSER_MANAGER.get_for_task.return_value = browser_state
mock_app.DATABASE.create_action = AsyncMock(return_value=action)
mock_app.DATABASE.workflow_params.create_action = AsyncMock(return_value=action)
mock_app.STORAGE = MagicMock()

with (
@@ -84,7 +84,7 @@ async def test_navigate_failure_with_error_detection(agent, mock_browser_state):

with create_error_detection_mocks(detected_errors):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()

# Simulate FailedToNavigateToUrl scenario
from skyvern.exceptions import FailedToNavigateToUrl

@@ -103,8 +103,8 @@ async def test_navigate_failure_with_error_detection(agent, mock_browser_state):
assert result is True

# Verify errors were stored
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 1
assert call_kwargs["errors"][0]["error_code"] == "page_not_found"

@@ -138,9 +138,9 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):

with patch("skyvern.forge.agent.app") as mock_app:
mock_app.BROWSER_MANAGER.get_for_task.return_value = mock_browser_state
mock_app.DATABASE.get_task_steps = AsyncMock(return_value=[step, step, step])
mock_app.DATABASE.get_task = AsyncMock(return_value=task)
mock_app.DATABASE.update_task = AsyncMock(return_value=task)
mock_app.DATABASE.tasks.get_task_steps = AsyncMock(return_value=[step, step, step])
mock_app.DATABASE.tasks.get_task = AsyncMock(return_value=task)
mock_app.DATABASE.tasks.update_task = AsyncMock(return_value=task)
# create_step is awaited in handle_failed_step retry branch; avoid MagicMock in await
next_step = make_step(
now,

@@ -151,9 +151,9 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
retry_index=step.retry_index + 1,
output=None,
)
mock_app.DATABASE.create_step = AsyncMock(return_value=next_step)
mock_app.DATABASE.tasks.create_step = AsyncMock(return_value=next_step)

# Async mock that forwards to mock_app.DATABASE.update_task so we never await MagicMock inside real update_task
# Async mock that forwards to mock_app.DATABASE.tasks.update_task so we never await MagicMock inside real update_task
async def mock_update_task(
_self,
task,

@@ -173,7 +173,9 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
updates["errors"] = errors
if failure_category is not None:
updates["failure_category"] = failure_category
return await mock_app.DATABASE.update_task(task.task_id, organization_id=task.organization_id, **updates)
return await mock_app.DATABASE.tasks.update_task(
task.task_id, organization_id=task.organization_id, **updates
)

with patch.object(ForgeAgent, "summary_failure_reason_for_max_retries", mock_summary):
with patch.object(ForgeAgent, "update_task", mock_update_task):

@@ -182,8 +184,8 @@ async def test_max_retries_with_error_detection(agent, mock_browser_state):
assert result is None # No next step when max retries exceeded

# Verify errors include both system and user-defined errors
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]

errors = call_kwargs["errors"]
assert len(errors) == 2

@@ -220,7 +222,7 @@ async def test_scraping_failure_with_error_detection(agent, mock_browser_state):

with create_error_detection_mocks(detected_errors):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()

# Simulate ScrapingFailed scenario
from skyvern.exceptions import ScrapingFailed

@@ -233,8 +235,8 @@ async def test_scraping_failure_with_error_detection(agent, mock_browser_state):
assert result is True

# Verify errors were stored
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 1
assert call_kwargs["errors"][0]["error_code"] == "login_required"

@@ -268,12 +270,12 @@ async def test_multiple_failures_accumulate_errors(agent, mock_browser_state):

with create_error_detection_mocks(first_detected):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()

await agent.fail_task(task, step, "First failure", mock_browser_state)

# Only new errors are passed — DB handles appending to existing ones
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert len(call_kwargs["errors"]) == 1
assert call_kwargs["errors"][0]["error_code"] == "payment_failed"

@@ -304,15 +306,15 @@ async def test_error_detection_with_workflow_task(agent, mock_browser_state):

with create_error_detection_mocks(detected_errors):
with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()

result = await agent.fail_task(task, step, "Workflow task failed", mock_browser_state)

assert result is True

# Verify errors were stored for workflow task
mock_app.DATABASE.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
mock_app.DATABASE.tasks.update_task.assert_called_once()
call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
assert call_kwargs["task_id"] == task.task_id
# workflow_run_id is not passed in the update call, only task_id and errors
assert "workflow_run_id" not in call_kwargs

@@ -350,7 +352,7 @@ async def test_error_detection_performance_doesnt_block_failure(agent, mock_brow
mock_detect.side_effect = slow_detection

with patch("skyvern.forge.agent.app") as mock_app:
mock_app.DATABASE.update_task = AsyncMock()
mock_app.DATABASE.tasks.update_task = AsyncMock()

import time

@ -72,7 +72,7 @@ async def test_fail_task_with_error_code_mapping_detects_errors(agent, mock_brow
|
|||
mock_detect.return_value = detected_errors
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
mock_app.DATABASE.tasks.update_task = AsyncMock()
|
||||
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
|
|
@ -87,8 +87,8 @@ async def test_fail_task_with_error_code_mapping_detects_errors(agent, mock_brow
|
|||
)
|
||||
|
||||
 # Verify task errors were updated in database
-mock_app.DATABASE.update_task.assert_called_once()
-call_kwargs = mock_app.DATABASE.update_task.call_args[1]
+mock_app.DATABASE.tasks.update_task.assert_called_once()
+call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
 assert call_kwargs["task_id"] == task.task_id
 assert call_kwargs["organization_id"] == task.organization_id
 assert len(call_kwargs["errors"]) == 1

@@ -112,7 +112,7 @@ async def test_fail_task_without_error_code_mapping(agent, mock_browser_state):
 new_callable=AsyncMock,
 ) as mock_detect:
 with patch("skyvern.forge.agent.app") as mock_app:
-mock_app.DATABASE.update_task = AsyncMock()
+mock_app.DATABASE.tasks.update_task = AsyncMock()

 result = await agent.fail_task(task, step, "Task failed", mock_browser_state)

@@ -122,7 +122,7 @@ async def test_fail_task_without_error_code_mapping(agent, mock_browser_state):
 mock_detect.assert_not_called()

 # Verify database update was NOT called for errors
-mock_app.DATABASE.update_task.assert_not_called()
+mock_app.DATABASE.tasks.update_task.assert_not_called()


 @pytest.mark.asyncio

@@ -150,7 +150,7 @@ async def test_fail_task_without_browser_state(agent):
 mock_detect.return_value = []

 with patch("skyvern.forge.agent.app") as mock_app:
-mock_app.DATABASE.update_task = AsyncMock()
+mock_app.DATABASE.tasks.update_task = AsyncMock()

 # Call without browser_state
 result = await agent.fail_task(task, step, "Task failed", browser_state=None)

@@ -188,7 +188,7 @@ async def test_fail_task_without_step(agent, mock_browser_state):
 new_callable=AsyncMock,
 ) as mock_detect:
 with patch("skyvern.forge.agent.app") as mock_app:
-mock_app.DATABASE.update_task = AsyncMock()
+mock_app.DATABASE.tasks.update_task = AsyncMock()

 # Call without step
 result = await agent.fail_task(task, None, "Task failed", mock_browser_state)

@@ -228,7 +228,7 @@ async def test_fail_task_error_detection_fails_gracefully(agent, mock_browser_st
 mock_detect.side_effect = Exception("Detection failed")

 with patch("skyvern.forge.agent.app") as mock_app:
-mock_app.DATABASE.update_task = AsyncMock()
+mock_app.DATABASE.tasks.update_task = AsyncMock()

 # Should not raise exception
 result = await agent.fail_task(task, step, "Task failed", mock_browser_state)

@@ -270,14 +270,14 @@ async def test_fail_task_multiple_errors_detected(agent, mock_browser_state):
 mock_detect.return_value = detected_errors

 with patch("skyvern.forge.agent.app") as mock_app:
-mock_app.DATABASE.update_task = AsyncMock()
+mock_app.DATABASE.tasks.update_task = AsyncMock()

 result = await agent.fail_task(task, step, "Task failed", mock_browser_state)

 assert result is True

 # Verify only new errors were passed (DB handles appending to existing errors)
-call_kwargs = mock_app.DATABASE.update_task.call_args[1]
+call_kwargs = mock_app.DATABASE.tasks.update_task.call_args[1]
 assert len(call_kwargs["errors"]) == 2
 assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
 assert call_kwargs["errors"][1]["error_code"] == "address_invalid"

@@ -308,14 +308,14 @@ async def test_fail_task_no_errors_detected(agent, mock_browser_state):
 mock_detect.return_value = []

 with patch("skyvern.forge.agent.app") as mock_app:
-mock_app.DATABASE.update_task = AsyncMock()
+mock_app.DATABASE.tasks.update_task = AsyncMock()

 result = await agent.fail_task(task, step, "Task failed", mock_browser_state)

 assert result is True

 # Database update for errors should not be called
-mock_app.DATABASE.update_task.assert_not_called()
+mock_app.DATABASE.tasks.update_task.assert_not_called()


 @pytest.mark.asyncio
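
Every hunk above applies the same mechanical rewrite: the test stops mocking and asserting on the legacy AgentDB delegate (mock_app.DATABASE.update_task) and targets the typed repository attribute (mock_app.DATABASE.tasks.update_task) instead. A runnable reduction of the pattern, with only the method name taken from these hunks (the surrounding test scaffolding is assumed):

import asyncio
from unittest.mock import AsyncMock, MagicMock

# The DB facade is a MagicMock; the AsyncMock hangs off the typed repo attribute.
mock_database = MagicMock()
mock_database.tasks.update_task = AsyncMock()

# Stand-in for the code under test awaiting the repository method.
asyncio.run(mock_database.tasks.update_task(task_id="tsk_1", errors=[{"error_code": "payment_failed"}]))

# Assertions also go through the repo attribute, never the facade.
mock_database.tasks.update_task.assert_called_once()
call_kwargs = mock_database.tasks.update_task.call_args[1]
assert call_kwargs["task_id"] == "tsk_1"
assert len(call_kwargs["errors"]) == 1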

@@ -336,7 +336,7 @@ async def test_transform_forloop_block_with_mocked_db() -> None:
 ):
 mock_get_wfr.return_value = mock_workflow_run_resp
 mock_app.WORKFLOW_SERVICE.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
-mock_app.DATABASE.get_workflow_run_blocks = AsyncMock(
+mock_app.DATABASE.observer.get_workflow_run_blocks = AsyncMock(
 return_value=[
 mock_forloop_run_block,
 mock_child_run_block,

@@ -345,8 +345,8 @@ async def test_transform_forloop_block_with_mocked_db() -> None:
 # B1 optimization: Mock batch methods instead of individual queries
 mock_task.task_id = "task_extraction_789"
 mock_action.task_id = "task_extraction_789"
-mock_app.DATABASE.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
-mock_app.DATABASE.get_tasks_actions = AsyncMock(return_value=[mock_action])
+mock_app.DATABASE.tasks.get_tasks_by_ids = AsyncMock(return_value=[mock_task])
+mock_app.DATABASE.tasks.get_tasks_actions = AsyncMock(return_value=[mock_action])

 # Call the transformation
 result = await transform_workflow_run_to_code_gen_input(

tests/unit/test_no_direct_db_delegates.py (new file, +15)
@@ -0,0 +1,15 @@
+"""
+Verify the no-direct-db-delegates hook script stays consistent.
+"""
+
+import subprocess
+
+
+def test_hook_passes_on_current_codebase() -> None:
+    """The hook should pass cleanly — all legacy files are in the allowlist."""
+    result = subprocess.run(
+        ["./scripts/check_no_direct_db_delegates.sh"],
+        capture_output=True,
+        text=True,
+    )
+    assert result.returncode == 0, f"Hook unexpectedly failed. New direct delegate calls found:\n{result.stdout}"

@@ -348,7 +348,7 @@ async def test_agent_step_persists_artifacts_when_using_speculative_plan(
 monkeypatch.setattr("skyvern.forge.agent.app.AGENT_FUNCTION.post_action_execution", AsyncMock())
 monkeypatch.setattr("skyvern.forge.agent.asyncio.sleep", AsyncMock(return_value=None))
 monkeypatch.setattr("skyvern.forge.agent.random.uniform", lambda *_args, **_kwargs: 0)
-monkeypatch.setattr("skyvern.forge.agent.app.DATABASE.create_action", AsyncMock())
+monkeypatch.setattr("skyvern.forge.agent.app.DATABASE.workflow_params.create_action", AsyncMock())
 monkeypatch.setattr(
 "skyvern.forge.agent.app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached",
 AsyncMock(return_value=False),

@@ -9,7 +9,7 @@ from skyvern.forge.sdk.routes import agent_protocol

 @pytest.mark.asyncio
 async def test_get_runs_v2_serializes_mapping_rows_from_database(monkeypatch: pytest.MonkeyPatch) -> None:
-mock_database = SimpleNamespace(
+mock_workflow_runs = SimpleNamespace(
 get_all_runs_v2=AsyncMock(
 return_value=[
 {

@@ -28,6 +28,7 @@ async def test_get_runs_v2_serializes_mapping_rows_from_database(monkeypatch: py
 ]
 )
 )
+mock_database = SimpleNamespace(workflow_runs=mock_workflow_runs)
 monkeypatch.setattr(agent_protocol.app, "DATABASE", mock_database)

 response = await agent_protocol.get_runs_v2(

@@ -37,7 +38,7 @@ async def test_get_runs_v2_serializes_mapping_rows_from_database(monkeypatch: py
 search_key="abc",
 )

-mock_database.get_all_runs_v2.assert_awaited_once_with(
+mock_workflow_runs.get_all_runs_v2.assert_awaited_once_with(
 "org_123",
 page=2,
 page_size=5,
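
The hunks above restructure the SimpleNamespace mock the same way: the route now reads app.DATABASE.workflow_runs.get_all_runs_v2, so the test builds the repo namespace first, nests it under the database namespace, and keeps a direct handle for the final assertion. A runnable sketch, with the return value reduced to an empty list (the real test returns mapping rows):

import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock

# Repo namespace first, then nested under the database namespace.
mock_workflow_runs = SimpleNamespace(get_all_runs_v2=AsyncMock(return_value=[]))
mock_database = SimpleNamespace(workflow_runs=mock_workflow_runs)

# Stand-in for the route calling through the nested attribute.
rows = asyncio.run(mock_database.workflow_runs.get_all_runs_v2("org_123", page=2, page_size=5))
assert rows == []

# The assertion targets the repo namespace handle, not the database namespace.
mock_workflow_runs.get_all_runs_v2.assert_awaited_once_with("org_123", page=2, page_size=5)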

@@ -163,7 +163,7 @@ class TestTriggerScriptReviewerCap:
 self.mock_cache.get_lock = MagicMock(return_value=mock_lock)
 # Mock is_script_pinned to return False (not pinned) so tests reach the cap logic
 self._pin_patcher = patch(
-"skyvern.forge.sdk.workflow.service.app.DATABASE.is_script_pinned",
+"skyvern.forge.sdk.workflow.service.app.DATABASE.scripts.is_script_pinned",
 new_callable=AsyncMock,
 return_value=False,
 )

@@ -421,12 +421,12 @@ async def test_terminate_calls_handler_and_raises(mock_scraped_page, mock_ai):
 return_value=mock_context,
 ),
 patch(
-"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_task",
+"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_task",
 new_callable=AsyncMock,
 return_value=mock_task,
 ),
 patch(
-"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_step",
+"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_step",
 new_callable=AsyncMock,
 return_value=mock_step,
 ),

@@ -481,12 +481,12 @@ async def test_terminate_raises_even_when_task_not_found(mock_scraped_page, mock
 return_value=mock_context,
 ),
 patch(
-"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_task",
+"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_task",
 new_callable=AsyncMock,
 return_value=None,
 ),
 patch(
-"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_step",
+"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_step",
 new_callable=AsyncMock,
 return_value=None,
 ),

@@ -538,12 +538,12 @@ async def test_terminate_raises_even_when_handler_fails(mock_scraped_page, mock_
 return_value=mock_context,
 ),
 patch(
-"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_task",
+"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_task",
 new_callable=AsyncMock,
 return_value=mock_task,
 ),
 patch(
-"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.get_step",
+"skyvern.core.script_generations.script_skyvern_page.app.DATABASE.tasks.get_step",
 new_callable=AsyncMock,
 return_value=mock_step,
 ),
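
These hunks only move the patch target one attribute deeper: unittest.mock.patch resolves a dotted string by importing the module and walking the remaining attributes, so extending the path from app.DATABASE.get_task to app.DATABASE.tasks.get_task is enough. A self-contained sketch against a throwaway module (fake_page_module and its contents are invented for illustration; the real target is the script_skyvern_page path shown above):

import asyncio
import sys
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

# Throwaway module standing in for skyvern.core.script_generations.script_skyvern_page.
fake_mod = types.ModuleType("fake_page_module")
fake_mod.app = SimpleNamespace(DATABASE=SimpleNamespace(tasks=SimpleNamespace(get_task=None)))
sys.modules["fake_page_module"] = fake_mod

# patch() walks the dotted path down to the repo method and swaps in an AsyncMock.
with patch("fake_page_module.app.DATABASE.tasks.get_task", new_callable=AsyncMock, return_value={"task_id": "tsk_1"}):
    task = asyncio.run(fake_mod.app.DATABASE.tasks.get_task("tsk_1"))
    assert task == {"task_id": "tsk_1"}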

@@ -526,8 +526,9 @@ def _make_app_mocks() -> tuple[MagicMock, MagicMock]:
 mock_storage.store_artifact = AsyncMock()

 mock_database = MagicMock()
-mock_database.bulk_create_artifacts = AsyncMock()
-mock_database.update_action_screenshot_artifact_id = AsyncMock()
+mock_database.artifacts = MagicMock()
+mock_database.artifacts.bulk_create_artifacts = AsyncMock()
+mock_database.artifacts.update_action_screenshot_artifact_id = AsyncMock()
 return mock_storage, mock_database

@@ -557,9 +558,9 @@ class TestFlushStepArchive:
 await manager.flush_step_archive(step.step_id)

 mock_storage.store_artifact.assert_awaited_once()
-mock_database.bulk_create_artifacts.assert_awaited_once()
+mock_database.artifacts.bulk_create_artifacts.assert_awaited_once()
 # The artifact list should include the parent + 6 member rows (scrape produces 6 entries)
-call_args = mock_database.bulk_create_artifacts.call_args[0][0]
+call_args = mock_database.artifacts.bulk_create_artifacts.call_args[0][0]
 assert len(call_args) == 7 # 1 parent + 6 members

 @pytest.mark.asyncio

@@ -612,7 +613,7 @@ class TestFlushStepArchive:

 # store_artifact and bulk_create_artifacts should only be called once
 assert mock_storage.store_artifact.await_count == 1
-assert mock_database.bulk_create_artifacts.await_count == 1
+assert mock_database.artifacts.bulk_create_artifacts.await_count == 1

 @pytest.mark.asyncio
 async def test_flush_nonexistent_step_id_is_noop(self) -> None:

@@ -626,7 +627,7 @@ class TestFlushStepArchive:
 await manager.flush_step_archive("nonexistent_step_id")

 mock_storage.store_artifact.assert_not_awaited()
-mock_database.bulk_create_artifacts.assert_not_awaited()
+mock_database.artifacts.bulk_create_artifacts.assert_not_awaited()

 @pytest.mark.asyncio
 async def test_flush_applies_pending_screenshot_updates(self) -> None:

@@ -652,7 +653,7 @@ class TestFlushStepArchive:
 mock_app.DATABASE = mock_database
 await manager.flush_step_archive(step.step_id)

-mock_database.update_action_screenshot_artifact_id.assert_awaited_once_with(
+mock_database.artifacts.update_action_screenshot_artifact_id.assert_awaited_once_with(
 organization_id="org_1",
 action_id="action_1",
 screenshot_artifact_id="art_1",

@@ -682,10 +683,10 @@ class TestFlushStepArchive:
 await manager.flush_step_archive(step.step_id)
 # Reset call counts to detect any additional calls from wait_for_upload_aiotasks
 mock_storage.store_artifact.reset_mock()
-mock_database.bulk_create_artifacts.reset_mock()
+mock_database.artifacts.bulk_create_artifacts.reset_mock()
 # Simulate the end-of-task flush fallback
 await manager.wait_for_upload_aiotasks([step.task_id])

 # The fallback should find nothing to flush — no extra uploads
 mock_storage.store_artifact.assert_not_awaited()
-mock_database.bulk_create_artifacts.assert_not_awaited()
+mock_database.artifacts.bulk_create_artifacts.assert_not_awaited()

@@ -13,7 +13,7 @@ async def test_initialize_task_v2_populates_task_run_url_when_user_url_is_known(
 organization = SimpleNamespace(organization_id="org_123")
 user_url = "https://example.com"

-app.DATABASE.create_task_v2.return_value = SimpleNamespace(
+app.DATABASE.observer.create_task_v2.return_value = SimpleNamespace(
 observer_cruise_id="tsk_123",
 workflow_run_id=None,
 url=user_url,

@@ -24,14 +24,14 @@ async def test_initialize_task_v2_populates_task_run_url_when_user_url_is_known(
 title=DEFAULT_WORKFLOW_TITLE,
 )
 app.WORKFLOW_SERVICE.setup_workflow_run.return_value = SimpleNamespace(workflow_run_id="wr_123")
-app.DATABASE.update_task_v2.return_value = SimpleNamespace(
+app.DATABASE.observer.update_task_v2.return_value = SimpleNamespace(
 observer_cruise_id="tsk_123",
 workflow_run_id="wr_123",
 workflow_id="wf_123",
 workflow_permanent_id="wpid_123",
 url=user_url,
 )
-app.DATABASE.create_task_run.return_value = SimpleNamespace(run_id="tsk_123")
+app.DATABASE.tasks.create_task_run.return_value = SimpleNamespace(run_id="tsk_123")

 await initialize_task_v2(
 organization=organization,

@@ -40,7 +40,7 @@ async def test_initialize_task_v2_populates_task_run_url_when_user_url_is_known(
 create_task_run=True,
 )

-app.DATABASE.create_task_run.assert_awaited_once_with(
+app.DATABASE.tasks.create_task_run.assert_awaited_once_with(
 task_run_type=RunType.task_v2,
 organization_id="org_123",
 run_id="tsk_123",

@@ -20,7 +20,7 @@ async def test_org_email_bitwarden_auth_falls_back_to_global_credentials(
 )
 )

-class FakeOrgRepo:
+class FakeOrganizationsRepo:
 async def get_valid_org_auth_token(self, organization_id: str, token_type: str) -> object:
 assert organization_id == "org-1"
 assert token_type == "bitwarden_credential"

@@ -28,7 +28,7 @@ async def test_org_email_bitwarden_auth_falls_back_to_global_credentials(

 class FakeDatabase:
 def __init__(self) -> None:
-self.organizations = FakeOrgRepo()
+self.organizations = FakeOrganizationsRepo()

 fake_app = SimpleNamespace(DATABASE=FakeDatabase())
 monkeypatch.setattr(cm, "app", fake_app)
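
The rename here tracks the real attribute: the fake database exposes the repo under .organizations, and the code under test calls through it unchanged. A runnable reduction (the None return and its fall-back meaning are assumptions based on the test name; the diff does not show the return value):

import asyncio
from types import SimpleNamespace

class FakeOrganizationsRepo:
    async def get_valid_org_auth_token(self, organization_id: str, token_type: str) -> object:
        assert organization_id == "org-1"
        assert token_type == "bitwarden_credential"
        return None  # assumed: no org-level token, so the caller falls back to global credentials

class FakeDatabase:
    def __init__(self) -> None:
        self.organizations = FakeOrganizationsRepo()

fake_app = SimpleNamespace(DATABASE=FakeDatabase())
token = asyncio.run(fake_app.DATABASE.organizations.get_valid_org_auth_token("org-1", "bitwarden_credential"))
assert token is None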

@@ -186,8 +186,8 @@ async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: p
 ctx.values["flag"] = True
 monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: ctx)

-app.DATABASE.update_workflow_run_block.reset_mock()
-app.DATABASE.create_or_update_workflow_run_output_parameter.reset_mock()
+app.DATABASE.observer.update_workflow_run_block.reset_mock()
+app.DATABASE.workflow_runs.create_or_update_workflow_run_output_parameter.reset_mock()

 result = await block.execute(
 workflow_run_id="run-1",

@@ -202,7 +202,7 @@ async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: p
 assert ctx.blocks_metadata["cond"]["branch_taken"] == "next"

 # Get the actual call arguments
-call_args = app.DATABASE.update_workflow_run_block.call_args
+call_args = app.DATABASE.observer.update_workflow_run_block.call_args
 assert call_args.kwargs["workflow_run_block_id"] == "wrb-1"
 assert call_args.kwargs["output"] == metadata
 assert call_args.kwargs["status"] == BlockStatus.completed

@@ -246,7 +246,7 @@ class TestBuildBlockResultPassesErrorCodes:

 from skyvern.forge import app

-app.DATABASE.update_workflow_run_block.reset_mock()
+app.DATABASE.observer.update_workflow_run_block.reset_mock()

 result = await block.build_block_result(
 success=False,

@@ -257,8 +257,8 @@ class TestBuildBlockResultPassesErrorCodes:
 error_codes=["FILE_PARSER_ERROR"],
 )

-app.DATABASE.update_workflow_run_block.assert_called_once()
-call_kwargs = app.DATABASE.update_workflow_run_block.call_args[1]
+app.DATABASE.observer.update_workflow_run_block.assert_called_once()
+call_kwargs = app.DATABASE.observer.update_workflow_run_block.call_args[1]
 assert call_kwargs["error_codes"] == ["FILE_PARSER_ERROR"]
 assert result.error_codes == ["FILE_PARSER_ERROR"]

@@ -268,7 +268,7 @@ class TestBuildBlockResultPassesErrorCodes:

 from skyvern.forge import app

-app.DATABASE.update_workflow_run_block.reset_mock()
+app.DATABASE.observer.update_workflow_run_block.reset_mock()

 result = await block.build_block_result(
 success=True,

@@ -278,7 +278,7 @@ class TestBuildBlockResultPassesErrorCodes:
 organization_id="org_test",
 )

-app.DATABASE.update_workflow_run_block.assert_called_once()
-call_kwargs = app.DATABASE.update_workflow_run_block.call_args[1]
+app.DATABASE.observer.update_workflow_run_block.assert_called_once()
+call_kwargs = app.DATABASE.observer.update_workflow_run_block.call_args[1]
 assert call_kwargs["error_codes"] is None
 assert result.error_codes == []

@@ -145,7 +145,7 @@ class TestCheckTriggerDepth:
 mock_run = MagicMock()
 mock_run.parent_workflow_run_id = None
 with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
-mock_app.DATABASE.get_workflow_run = AsyncMock(return_value=mock_run)
+mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(return_value=mock_run)
 depth = await block._check_trigger_depth("wr_current")
 assert depth == 0

@@ -158,7 +158,7 @@ class TestCheckTriggerDepth:
 run_no_parent.parent_workflow_run_id = None

 with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
-mock_app.DATABASE.get_workflow_run = AsyncMock(side_effect=[run_with_parent, run_no_parent])
+mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(side_effect=[run_with_parent, run_no_parent])
 depth = await block._check_trigger_depth("wr_current")
 assert depth == 1

@@ -172,7 +172,7 @@ class TestCheckTriggerDepth:
 runs.append(run)

 with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
-mock_app.DATABASE.get_workflow_run = AsyncMock(side_effect=runs)
+mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(side_effect=runs)
 with pytest.raises(InvalidWorkflowDefinition, match="depth exceeds maximum"):
 await block._check_trigger_depth("wr_current")

@@ -186,7 +186,7 @@ class TestCheckTriggerDepth:
 runs.append(run)

 with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
-mock_app.DATABASE.get_workflow_run = AsyncMock(side_effect=runs)
+mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(side_effect=runs)
 depth = await block._check_trigger_depth("wr_current")
 assert depth == block.MAX_TRIGGER_DEPTH - 1

@@ -194,7 +194,7 @@ class TestCheckTriggerDepth:
 async def test_run_not_found_returns_zero(self) -> None:
 block = _make_block()
 with patch("skyvern.forge.sdk.workflow.models.block.app") as mock_app:
-mock_app.DATABASE.get_workflow_run = AsyncMock(return_value=None)
+mock_app.DATABASE.workflow_runs.get_workflow_run = AsyncMock(return_value=None)
 depth = await block._check_trigger_depth("wr_nonexistent")
 assert depth == 0
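
The TestCheckTriggerDepth hunks all feed a side_effect sequence into the repo method: each awaited call to workflow_runs.get_workflow_run returns the next run in the chain until parent_workflow_run_id is None. A runnable reduction of that walk (check_depth is an assumed stand-in for block._check_trigger_depth):

import asyncio
from unittest.mock import AsyncMock, MagicMock

run_child, run_root = MagicMock(), MagicMock()
run_child.parent_workflow_run_id = "wr_root"
run_root.parent_workflow_run_id = None

# Each await pops the next run from side_effect, mirroring the hunks above.
get_workflow_run = AsyncMock(side_effect=[run_child, run_root])

async def check_depth(workflow_run_id: str) -> int:
    depth = 0
    run = await get_workflow_run(workflow_run_id)
    while run is not None and run.parent_workflow_run_id is not None:
        depth += 1
        run = await get_workflow_run(run.parent_workflow_run_id)
    return depth

assert asyncio.run(check_depth("wr_current")) == 1  # one parent above the current run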

@@ -43,7 +43,7 @@ def _make_session(status: str) -> PersistentBrowserSession:
 async def test_rejects_update_when_already_final(desired_status: str):
 """A finalized session must not accept any status update."""
 db = AsyncMock()
-db.get_persistent_browser_session.return_value = _make_session(
+db.browser_sessions.get_persistent_browser_session.return_value = _make_session(
 PersistentBrowserSessionStatus.completed,
 )