fix: teach script reviewer to detect hardcoded per-run values in selectors (#5325)
This commit is contained in: parent 3a25fe553e, commit eae6622d15
20 changed files with 569 additions and 505 deletions
@@ -426,7 +426,7 @@ export type WorkflowRunApiResponse = {
   finished_at: string | null; // ISO 8601
   modified_at: string;
   proxy_location: ProxyLocation | null;
-  script_run: { ai_fallback_triggered: boolean } | null;
+  script_run: boolean | null;
   status: Status;
   title?: string;
   trigger_type?: TriggerType | null;
@@ -3,13 +3,19 @@ import { useCredentialGetter } from "@/hooks/useCredentialGetter";
 import { useQuery } from "@tanstack/react-query";
 import { Status, Task, TriggerType, WorkflowRunApiResponse } from "@/api/types";

+type QueryReturnType = Array<Task | WorkflowRunApiResponse>;
+type UseQueryOptions = Omit<
+  Parameters<typeof useQuery<QueryReturnType>>[0],
+  "queryKey" | "queryFn"
+>;
+
 type Props = {
   page?: number;
   pageSize?: number;
   statusFilters?: Array<Status>;
   triggerTypeFilters?: Array<TriggerType>;
   search?: string;
-};
+} & UseQueryOptions;

 function useRunsQuery({
   page = 1,

@@ -17,11 +23,9 @@ function useRunsQuery({
   statusFilters,
   triggerTypeFilters,
   search,
+  ...queryOptions
 }: Props) {
   const credentialGetter = useCredentialGetter();
   return useQuery<Array<Task | WorkflowRunApiResponse>>({
+    ...queryOptions,
     queryKey: [
       "runs",
       { statusFilters, triggerTypeFilters },
@@ -279,7 +279,7 @@ function RunHistory() {
   const workflowTitle = (
     <div className="flex items-center gap-2">
       <span className="truncate">{run.workflow_title ?? ""}</span>
-      {run.script_run != null && (
+      {run.script_run === true && (
        <Tip content="Ran with code">
          <LightningBoltIcon className="text-[gold]" />
        </Tip>
@@ -212,7 +212,7 @@ function WorkflowPage() {
   ) : (
     workflowRuns?.map((workflowRun) => {
       const workflowRunId =
-        workflowRun.script_run != null ? (
+        workflowRun.script_run === true ? (
           <div className="flex items-center gap-2">
             <Tip content="Ran with code">
               <LightningBoltIcon className="text-[gold]" />
@@ -34,6 +34,14 @@ These are the KNOWN parameter names for `context.parameters[...]`:
 {% endfor %}
 For fields not covered by these parameters, use `ai='proactive'` with a descriptive prompt (see Rule 8b).
 {% endif %}
+{% if run_parameter_values %}
+
+## Current Run Parameter Values
+These are the ACTUAL values for this run. If any of these appear (even partially) in a selector or hardcoded string, that value is per-run and MUST be made dynamic — use `context.parameters['key']` or `ai='proactive'` with a prompt referencing the parameter.
+{% for key, value in run_parameter_values.items() %}
+- `{{ key }}` = `{{ value[:80] }}{% if value|length > 80 %}…{% endif %}`
+{% endfor %}
+{% endif %}

 ## Existing Cached Code (DO NOT REMOVE ANY EXISTING PATHS)
 ```python
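The new block is plain Jinja. As a quick illustration of what the reviewer model actually sees, here is a minimal rendering sketch; jinja2 is assumed, the keys and values are made up, and in the real code the values come from `get_workflow_run_parameters`:

```python
# Minimal rendering sketch for the new run-parameter prompt block.
from jinja2 import Template

block = Template(
    "{% for key, value in run_parameter_values.items() %}"
    "- `{{ key }}` = `{{ value[:80] }}{% if value|length > 80 %}…{% endif %}`\n"
    "{% endfor %}"
)

# Hypothetical per-run values; real ones are loaded from the workflow run.
print(block.render(run_parameter_values={"full_name": "John H. Smith", "member_id": "A123"}))
# - `full_name` = `John H. Smith`
# - `member_id` = `A123`
```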
@@ -93,9 +101,11 @@ The following failures occurred in PREVIOUS runs. Learn from these — do NOT re
 - **Error**: {{ h.error_message or "N/A" }}
 - **Reviewer fix applied**: {{ h.reviewer_output or "No fix recorded" }}
 - **Did the AI fallback succeed?**: {{ "Yes" if h.fallback_succeeded else ("No" if h.fallback_succeeded is sameas false else "Unknown") }}
+{% if h.run_parameters %}- **Run parameters**: {% for k, v in h.run_parameters.items() %}`{{ k }}={{ v[:60] }}{% if v|length > 60 %}…{% endif %}`{% if not loop.last %}, {% endif %}{% endfor %}
+{% endif %}
 {% endfor %}

-**Key takeaway:** If the same error appears above multiple times, the previous fixes did NOT solve it. Try a fundamentally different approach.
+**Key takeaway:** If the same error appears above multiple times, the previous fixes did NOT solve it. Try a fundamentally different approach. **If different past episodes show different values for the same parameter**, the block handles per-run data — any hardcoded reference to those values must be dynamic.
 {% endif %}

 {% if user_instructions %}
@@ -40,6 +40,14 @@ These are the KNOWN parameter names for `context.parameters[...]`:
 {% endfor %}
 For fields not covered by these parameters, use `"ai": "proactive"` with a descriptive prompt (see Rule 9b).
 {% endif %}
+{% if run_parameter_values %}
+
+## Current Run Parameter Values
+These are the ACTUAL values for this run. If any of these appear (even partially) in a FIELD_MAP value or hardcoded string, that value is per-run and MUST use `context.parameters['key']` instead.
+{% for key, value in run_parameter_values.items() %}
+- `{{ key }}` = `{{ value[:80] }}{% if value|length > 80 %}…{% endif %}`
+{% endfor %}
+{% endif %}

 ## Existing Cached Code
 ```python
@@ -96,9 +104,11 @@ The following failures occurred in PREVIOUS runs. Learn from these — do NOT re
 - **Error**: {{ h.error_message or "N/A" }}
 - **Reviewer fix applied**: {{ h.reviewer_output or "No fix recorded" }}
 - **Did the AI fallback succeed?**: {{ "Yes" if h.fallback_succeeded else ("No" if h.fallback_succeeded is sameas false else "Unknown") }}
+{% if h.run_parameters %}- **Run parameters**: {% for k, v in h.run_parameters.items() %}`{{ k }}={{ v[:60] }}{% if v|length > 60 %}…{% endif %}`{% if not loop.last %}, {% endif %}{% endfor %}
+{% endif %}
 {% endfor %}

-**Key takeaway:** If the same error appears above multiple times, the previous fixes did NOT solve it. Try a fundamentally different approach.
+**Key takeaway:** If the same error appears above multiple times, the previous fixes did NOT solve it. Try a fundamentally different approach. **If different past episodes show different values for the same parameter**, the block handles per-run data — any hardcoded reference to those values must use `context.parameters['key']` instead.
 {% endif %}

 {% if user_instructions %}
@@ -30,6 +30,21 @@ You are a script reviewer for a browser automation system. Your job is to update
 11. **FOR-LOOP BLOCKS**: When code runs inside a for-loop (e.g. by iterating over extracted items), every `page.click()`, `page.goto()`, or `page.download_file()` call MUST use per-iteration data to target the specific item. Never use a static selector or generic prompt that would resolve to the same element on every iteration.
     - **File download blocks**: Use `context.download_selector()` which builds a CSS selector (e.g. `a[href*="filename.pdf"]`) from URL-like fields in the current loop value. Combine with `context.prompt` for AI fallback: `await page.click(selector=context.download_selector(), prompt=context.prompt, ai="fallback")`. If `download_selector()` returns `None` (no URL in loop value), fall back to a text-based selector using `context.loop_value`.
     - **Other blocks**: Use `context.loop_value` (dict) to build selectors. Prefer `page.goto(context.loop_value["url"])` for direct URLs, or include the item's title/identifier in the selector: `selector=f'a:has-text("{context.loop_value["title"]}")'`.
+12. **NEVER HARDCODE PER-RUN VALUES IN SELECTORS OR PROMPTS**. If a selector or prompt contains a value that varies per run (a person's name, an ID, an address, a company name), it MUST be dynamic — never a literal string. This applies even when the page shows an abbreviated form (e.g., "J. Smith") that doesn't exactly match the parameter ("John H. Smith").
+    - **How to detect**: Check the **Current Run Parameter Values** and **Past Episode Parameters** sections below. If different episodes show different values for the same parameter (e.g., `full_name: "John Smith"` in one run, `full_name: "Jane Doe"` in another), any selector referencing that name is per-run and must be dynamic.
+    - **Wrong** (hardcodes a name that changes every run):
+      ```python
+      await page.click(selector='a[aria-label="John Smith select for details"]', ai='fallback', prompt='Click provider')
+      ```
+    - **Correct — option A** (dynamic selector via f-string, with AI fallback for fuzzy matching):
+      ```python
+      await page.click(selector=f'a:has-text("{context.parameters["full_name"]}")', ai='fallback', prompt=f"Click on the entry for {context.parameters['full_name']}")
+      ```
+    - **Correct — option B** (AI-only, best when the page abbreviates or reformats the name):
+      ```python
+      await page.click(ai='proactive', prompt=f"Click on the provider entry that matches {context.parameters['full_name']}")
+      ```
+    - Use **option A** when the page shows the exact parameter value (fast, no LLM call if selector hits). Use **option B** when the page reformats/abbreviates names — the AI handles fuzzy matching that a static selector cannot.

 ## Block Navigation Goal
 ```
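Rule 12's detection step is essentially a substring check between the run's parameter values and the cached code. A rough sketch of that check, under our own simplifications: the helper name `find_hardcoded_values` is ours, not Skyvern's, and it matches exact substrings only, whereas the prompt also asks the reviewer to catch partial matches.

```python
# Hypothetical helper: flag run-parameter values that leak into cached code verbatim.
def find_hardcoded_values(code: str, run_parameter_values: dict[str, str]) -> dict[str, str]:
    """Return {param_key: value} for every run value that appears literally in the code."""
    leaks: dict[str, str] = {}
    for key, value in run_parameter_values.items():
        # Skip trivially short values ("1", "CA") that would match by accident.
        if len(value) >= 4 and value in code:
            leaks[key] = value
    return leaks

code = 'await page.click(selector=\'a[aria-label="John Smith select for details"]\')'
print(find_hardcoded_values(code, {"full_name": "John Smith"}))  # {'full_name': 'John Smith'}
```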
@@ -44,6 +59,14 @@ These are the KNOWN parameter names for `context.parameters[...]`:
 {% endfor %}
 For fields not covered by these parameters: use Python code if the value is deterministic (see "Deterministic Logic"), or `ai='proactive'` if it requires subjective judgment (see Rule 9c).
 {% endif %}
+{% if run_parameter_values %}
+
+## Current Run Parameter Values
+These are the ACTUAL values for this run. If any of these appear (even partially) in a selector or hardcoded string, that value is per-run and MUST be made dynamic (see Rule 12).
+{% for key, value in run_parameter_values.items() %}
+- `{{ key }}` = `{{ value[:80] }}{% if value|length > 80 %}…{% endif %}`
+{% endfor %}
+{% endif %}

 ## Existing Cached Code
 ```python
@@ -110,9 +133,11 @@ The following failures occurred in PREVIOUS runs and were already reviewed. Lear
 - **Error**: {{ h.error_message or "N/A" }}
 - **Reviewer fix applied**: {{ h.reviewer_output or "No fix recorded" }}
 - **Did the AI fallback succeed?**: {{ "Yes" if h.fallback_succeeded else ("No" if h.fallback_succeeded is sameas false else "Unknown") }}
+{% if h.run_parameters %}- **Run parameters**: {% for k, v in h.run_parameters.items() %}`{{ k }}={{ v[:60] }}{% if v|length > 60 %}…{% endif %}`{% if not loop.last %}, {% endif %}{% endfor %}
+{% endif %}
 {% endfor %}

-**Key takeaway:** If the same error appears above multiple times, the previous fixes did NOT solve it. Try a fundamentally different approach instead of repeating the same pattern.
+**Key takeaway:** If the same error appears above multiple times, the previous fixes did NOT solve it. Try a fundamentally different approach instead of repeating the same pattern. **If different past episodes show different values for the same parameter** (e.g., different names, different IDs), the block handles per-run data — any selector or hardcoded string referencing those values must be dynamic (see Rule 12).
 {% endif %}

 {% if user_instructions %}
@@ -109,8 +109,6 @@ __all__ = ["AgentDB", "ScheduleLimitExceededError"]


 class AgentDB(BaseAlchemyDB):
-    _background_tasks: set[asyncio.Task] = set()  # noqa: RUF012
-
     def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
         super().__init__(db_engine or _build_engine(database_string))
         self.debug_enabled = debug_enabled
@@ -233,25 +231,7 @@ class AgentDB(BaseAlchemyDB):
         return await self.tasks.clear_task_failure_reason(*args, **kwargs)

     async def update_task(self, *args: Any, **kwargs: Any) -> Any:
-        updated_task = await self.tasks.update_task(*args, **kwargs)
-
-        # Best-effort fire-and-forget write-through to task_runs.
-        # Mirrors the WorkflowService pattern — cron catches any missed syncs.
-        status = kwargs.get("status")
-        if status is not None:
-            task = asyncio.create_task(
-                self.workflow_params.sync_task_run_status(
-                    organization_id=updated_task.organization_id or "",
-                    run_id=updated_task.task_id,
-                    status=status.value,
-                    started_at=updated_task.started_at,
-                    finished_at=updated_task.finished_at,
-                ),
-            )
-            self._background_tasks.add(task)
-            task.add_done_callback(self._background_tasks.discard)
-
-        return updated_task
+        return await self.tasks.update_task(*args, **kwargs)

     async def update_task_2fa_state(self, *args: Any, **kwargs: Any) -> Any:
         return await self.tasks.update_task_2fa_state(*args, **kwargs)
@@ -470,25 +450,25 @@ class AgentDB(BaseAlchemyDB):
         return await self.workflow_params.retrieve_action_plan(*args, **kwargs)

     async def create_task_run(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.create_task_run(*args, **kwargs)
+        return await self.tasks.create_task_run(*args, **kwargs)

     async def update_task_run(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.update_task_run(*args, **kwargs)
+        return await self.tasks.update_task_run(*args, **kwargs)

     async def sync_task_run_status(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.sync_task_run_status(*args, **kwargs)
+        return await self.tasks.sync_task_run_status(*args, **kwargs)

     async def update_job_run_compute_cost(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.update_job_run_compute_cost(*args, **kwargs)
+        return await self.tasks.update_job_run_compute_cost(*args, **kwargs)

     async def cache_task_run(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.cache_task_run(*args, **kwargs)
+        return await self.tasks.cache_task_run(*args, **kwargs)

     async def get_cached_task_run(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.get_cached_task_run(*args, **kwargs)
+        return await self.tasks.get_cached_task_run(*args, **kwargs)

     async def get_run(self, *args: Any, **kwargs: Any) -> Any:
-        return await self.workflow_params.get_run(*args, **kwargs)
+        return await self.tasks.get_run(*args, **kwargs)

     # -- Artifact delegates --
@@ -14,13 +14,15 @@ from skyvern.forge.sdk.db.models import (
     ActionModel,
     StepModel,
     TaskModel,
+    TaskRunModel,
     WorkflowRunModel,
 )
 from skyvern.forge.sdk.db.utils import convert_to_step, convert_to_task, hydrate_action, serialize_proxy_location
 from skyvern.forge.sdk.models import Step, StepStatus
+from skyvern.forge.sdk.schemas.runs import Run
 from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
 from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
-from skyvern.schemas.runs import ProxyLocationInput
+from skyvern.schemas.runs import ProxyLocationInput, RunStatus, RunType
 from skyvern.schemas.steps import AgentStepOutput
 from skyvern.webeye.actions.actions import Action

@@ -737,3 +739,197 @@ class TasksMixin:
             )
             await session.execute(stmt)
             await session.commit()
+
+    async def sync_task_run_status(
+        self,
+        organization_id: str,
+        run_id: str,
+        status: str,
+        started_at: datetime | None = None,
+        finished_at: datetime | None = None,
+    ) -> None:
+        """Best-effort write-through: propagate status from source table to task_runs.
+
+        Does NOT raise if the task_runs row is missing (race at creation time).
+        """
+        try:
+            async with self.Session() as session:
+                vals: dict[str, Any] = {"status": status}
+                if started_at is not None:
+                    vals["started_at"] = started_at
+                if finished_at is not None:
+                    vals["finished_at"] = finished_at
+                stmt = (
+                    update(TaskRunModel)
+                    .where(TaskRunModel.run_id == run_id)
+                    .where(TaskRunModel.organization_id == organization_id)
+                    .values(**vals)
+                )
+                await session.execute(stmt)
+                await session.commit()
+        except Exception:
+            LOG.warning(
+                "Best-effort task_run status sync failed",
+                run_id=run_id,
+                organization_id=organization_id,
+                status=status,
+                exc_info=True,
+            )
+
+    @db_operation("create_task_run")
+    async def create_task_run(
+        self,
+        task_run_type: RunType,
+        organization_id: str,
+        run_id: str,
+        title: str | None = None,
+        url: str | None = None,
+        url_hash: str | None = None,
+        status: RunStatus | None = None,
+        workflow_permanent_id: str | None = None,
+        parent_workflow_run_id: str | None = None,
+        debug_session_id: str | None = None,
+        # script_run, started_at, finished_at are intentionally omitted here —
+        # they are set via update_task_run() after the run starts/finishes (PRs 2-5).
+    ) -> Run:
+        searchable_text = " ".join(filter(None, [title, url]))
+        async with self.Session() as session:
+            task_run = TaskRunModel(
+                task_run_type=task_run_type,
+                organization_id=organization_id,
+                run_id=run_id,
+                title=title,
+                url=url,
+                url_hash=url_hash,
+                status=status,
+                workflow_permanent_id=workflow_permanent_id,
+                parent_workflow_run_id=parent_workflow_run_id,
+                debug_session_id=debug_session_id,
+                searchable_text=searchable_text or None,
+            )
+            session.add(task_run)
+            await session.commit()
+            await session.refresh(task_run)
+            return Run.model_validate(task_run)
+
+    @db_operation("update_task_run")
+    async def update_task_run(
+        self,
+        organization_id: str,
+        run_id: str,
+        title: str | None = None,
+        url: str | None = None,
+        url_hash: str | None = None,
+        status: str | None = None,
+        started_at: datetime | None = None,
+        finished_at: datetime | None = None,
+    ) -> None:
+        async with self.Session() as session:
+            task_run = (
+                await session.scalars(
+                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
+                )
+            ).first()
+            if not task_run:
+                raise NotFoundError(f"TaskRun {run_id} not found")
+
+            if title is not None:
+                task_run.title = title
+            if url is not None:
+                task_run.url = url
+            if url_hash is not None:
+                task_run.url_hash = url_hash
+            if status is not None:
+                task_run.status = status
+            if started_at is not None:
+                task_run.started_at = started_at
+            if finished_at is not None:
+                task_run.finished_at = finished_at
+
+            # Recompute searchable_text when title or url changes
+            if title is not None or url is not None:
+                task_run.searchable_text = " ".join(filter(None, [task_run.title, task_run.url])) or None
+
+            await session.commit()
+
+    @db_operation("update_job_run_compute_cost")
+    async def update_job_run_compute_cost(
+        self,
+        organization_id: str,
+        run_id: str,
+        instance_type: str | None = None,
+        vcpu_millicores: int | None = None,
+        memory_mb: int | None = None,
+        duration_ms: int | None = None,
+        compute_cost: float | None = None,
+    ) -> None:
+        """Update compute cost metrics for a job run."""
+        async with self.Session() as session:
+            task_run = (
+                await session.scalars(
+                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
+                )
+            ).first()
+            if not task_run:
+                LOG.warning(
+                    "TaskRun not found for compute cost update",
+                    run_id=run_id,
+                    organization_id=organization_id,
+                )
+                return
+
+            if instance_type is not None:
+                task_run.instance_type = instance_type
+            if vcpu_millicores is not None:
+                task_run.vcpu_millicores = vcpu_millicores
+            if memory_mb is not None:
+                task_run.memory_mb = memory_mb
+            if duration_ms is not None:
+                task_run.duration_ms = duration_ms
+            if compute_cost is not None:
+                task_run.compute_cost = compute_cost
+            await session.commit()
+
+    @db_operation("cache_task_run")
+    async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
+        async with self.Session() as session:
+            task_run = (
+                await session.scalars(
+                    select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
+                )
+            ).first()
+            if task_run:
+                task_run.cached = True
+                await session.commit()
+                await session.refresh(task_run)
+                return Run.model_validate(task_run)
+            raise NotFoundError(f"Run {run_id} not found")
+
+    @db_operation("get_cached_task_run")
+    async def get_cached_task_run(
+        self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
+    ) -> Run | None:
+        async with self.Session() as session:
+            query = select(TaskRunModel)
+            if task_run_type:
+                query = query.filter_by(task_run_type=task_run_type)
+            if url_hash:
+                query = query.filter_by(url_hash=url_hash)
+            if organization_id:
+                query = query.filter_by(organization_id=organization_id)
+            query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
+            task_run = (await session.scalars(query)).first()
+            return Run.model_validate(task_run) if task_run else None
+
+    @db_operation("get_run")
+    async def get_run(
+        self,
+        run_id: str,
+        organization_id: str | None = None,
+    ) -> Run | None:
+        async with self.Session() as session:
+            query = select(TaskRunModel).filter_by(run_id=run_id)
+            if organization_id:
+                query = query.filter_by(organization_id=organization_id)
+            task_run = (await session.scalars(query)).first()
+            return Run.model_validate(task_run) if task_run else None
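The `sync_task_run_status` method above deliberately swallows every exception: the task_runs row may not exist yet (a race at creation time), or the database may be briefly unavailable, and the cron reconciler repairs any missed sync later. A self-contained sketch of that best-effort shape, with a stub in place of the real write-through (not Skyvern code):

```python
import asyncio
import logging

LOG = logging.getLogger(__name__)

async def flaky_sync(run_id: str) -> None:
    """Simulated write-through that fails (row missing, DB down, ...)."""
    raise ConnectionError("db unavailable")

async def best_effort_sync(run_id: str) -> None:
    # Same shape as sync_task_run_status: never let a sync failure
    # propagate to the caller; the periodic cron sync repairs the row.
    try:
        await flaky_sync(run_id)
    except Exception:
        LOG.warning("Best-effort task_run status sync failed", exc_info=True)

asyncio.run(best_effort_sync("tsk_123"))  # logs a warning, does not raise
```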
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, Any

 import structlog
-from sqlalchemy import select, update
+from sqlalchemy import select

 from skyvern.config import settings
 from skyvern.forge.sdk.db._error_handling import db_operation
@@ -24,7 +24,6 @@ from skyvern.forge.sdk.db.models import (
     OutputParameterModel,
     TaskGenerationModel,
     TaskModel,
-    TaskRunModel,
     WorkflowCopilotChatMessageModel,
     WorkflowCopilotChatModel,
     WorkflowParameterModel,
@@ -37,7 +36,6 @@ from skyvern.forge.sdk.db.utils import (
     hydrate_action,
 )
 from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
-from skyvern.forge.sdk.schemas.runs import Run
 from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
 from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
 from skyvern.forge.sdk.schemas.workflow_copilot import (
@@ -59,7 +57,6 @@ from skyvern.forge.sdk.workflow.models.parameter import (
     WorkflowParameter,
     WorkflowParameterType,
 )
-from skyvern.schemas.runs import RunStatus, RunType
 from skyvern.webeye.actions.actions import Action

 if TYPE_CHECKING:
@@ -567,197 +564,3 @@ class WorkflowParametersMixin:

         actions = (await session.scalars(query)).all()
         return [Action.model_validate(action) for action in actions]
-
-    @db_operation("create_task_run")
-    async def create_task_run(
-        self,
-        task_run_type: RunType,
-        organization_id: str,
-        run_id: str,
-        title: str | None = None,
-        url: str | None = None,
-        url_hash: str | None = None,
-        status: RunStatus | None = None,
-        workflow_permanent_id: str | None = None,
-        parent_workflow_run_id: str | None = None,
-        debug_session_id: str | None = None,
-        # script_run, started_at, finished_at are intentionally omitted here —
-        # they are set via update_task_run() after the run starts/finishes (PRs 2-5).
-    ) -> Run:
-        searchable_text = " ".join(filter(None, [title, url]))
-        async with self.Session() as session:
-            task_run = TaskRunModel(
-                task_run_type=task_run_type,
-                organization_id=organization_id,
-                run_id=run_id,
-                title=title,
-                url=url,
-                url_hash=url_hash,
-                status=status,
-                workflow_permanent_id=workflow_permanent_id,
-                parent_workflow_run_id=parent_workflow_run_id,
-                debug_session_id=debug_session_id,
-                searchable_text=searchable_text or None,
-            )
-            session.add(task_run)
-            await session.commit()
-            await session.refresh(task_run)
-            return Run.model_validate(task_run)
-
-    @db_operation("update_task_run")
-    async def update_task_run(
-        self,
-        organization_id: str,
-        run_id: str,
-        title: str | None = None,
-        url: str | None = None,
-        url_hash: str | None = None,
-        status: str | None = None,
-        started_at: datetime | None = None,
-        finished_at: datetime | None = None,
-    ) -> None:
-        async with self.Session() as session:
-            task_run = (
-                await session.scalars(
-                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
-                )
-            ).first()
-            if not task_run:
-                raise NotFoundError(f"TaskRun {run_id} not found")
-
-            if title is not None:
-                task_run.title = title
-            if url is not None:
-                task_run.url = url
-            if url_hash is not None:
-                task_run.url_hash = url_hash
-            if status is not None:
-                task_run.status = status
-            if started_at is not None:
-                task_run.started_at = started_at
-            if finished_at is not None:
-                task_run.finished_at = finished_at
-
-            # Recompute searchable_text when title or url changes
-            if title is not None or url is not None:
-                task_run.searchable_text = " ".join(filter(None, [task_run.title, task_run.url])) or None
-
-            await session.commit()
-
-    async def sync_task_run_status(
-        self,
-        organization_id: str,
-        run_id: str,
-        status: str,
-        started_at: datetime | None = None,
-        finished_at: datetime | None = None,
-    ) -> None:
-        """Best-effort write-through: propagate status from source table to task_runs.
-
-        Does NOT raise if the task_runs row is missing (race at creation time).
-        """
-        try:
-            async with self.Session() as session:
-                vals: dict[str, Any] = {"status": status}
-                if started_at is not None:
-                    vals["started_at"] = started_at
-                if finished_at is not None:
-                    vals["finished_at"] = finished_at
-                stmt = (
-                    update(TaskRunModel)
-                    .where(TaskRunModel.run_id == run_id)
-                    .where(TaskRunModel.organization_id == organization_id)
-                    .values(**vals)
-                )
-                await session.execute(stmt)
-                await session.commit()
-        except Exception:
-            LOG.warning(
-                "Best-effort task_run status sync failed",
-                run_id=run_id,
-                organization_id=organization_id,
-                status=status,
-                exc_info=True,
-            )
-
-    @db_operation("update_job_run_compute_cost")
-    async def update_job_run_compute_cost(
-        self,
-        organization_id: str,
-        run_id: str,
-        instance_type: str | None = None,
-        vcpu_millicores: int | None = None,
-        memory_mb: int | None = None,
-        duration_ms: int | None = None,
-        compute_cost: float | None = None,
-    ) -> None:
-        """Update compute cost metrics for a job run."""
-        async with self.Session() as session:
-            task_run = (
-                await session.scalars(
-                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
-                )
-            ).first()
-            if not task_run:
-                LOG.warning(
-                    "TaskRun not found for compute cost update",
-                    run_id=run_id,
-                    organization_id=organization_id,
-                )
-                return
-
-            if instance_type is not None:
-                task_run.instance_type = instance_type
-            if vcpu_millicores is not None:
-                task_run.vcpu_millicores = vcpu_millicores
-            if memory_mb is not None:
-                task_run.memory_mb = memory_mb
-            if duration_ms is not None:
-                task_run.duration_ms = duration_ms
-            if compute_cost is not None:
-                task_run.compute_cost = compute_cost
-            await session.commit()
-
-    @db_operation("cache_task_run")
-    async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
-        async with self.Session() as session:
-            task_run = (
-                await session.scalars(
-                    select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
-                )
-            ).first()
-            if task_run:
-                task_run.cached = True
-                await session.commit()
-                await session.refresh(task_run)
-                return Run.model_validate(task_run)
-            raise NotFoundError(f"Run {run_id} not found")
-
-    @db_operation("get_cached_task_run")
-    async def get_cached_task_run(
-        self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
-    ) -> Run | None:
-        async with self.Session() as session:
-            query = select(TaskRunModel)
-            if task_run_type:
-                query = query.filter_by(task_run_type=task_run_type)
-            if url_hash:
-                query = query.filter_by(url_hash=url_hash)
-            if organization_id:
-                query = query.filter_by(organization_id=organization_id)
-            query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
-            task_run = (await session.scalars(query)).first()
-            return Run.model_validate(task_run) if task_run else None
-
-    @db_operation("get_run")
-    async def get_run(
-        self,
-        run_id: str,
-        organization_id: str | None = None,
-    ) -> Run | None:
-        async with self.Session() as session:
-            query = select(TaskRunModel).filter_by(run_id=run_id)
-            if organization_id:
-                query = query.filter_by(organization_id=organization_id)
-            task_run = (await session.scalars(query)).first()
-            return Run.model_validate(task_run) if task_run else None
@@ -298,7 +298,6 @@ class WorkflowRunsMixin:
         page: int = 1,
         page_size: int = 10,
         status: list[WorkflowRunStatus] | None = None,
-        trigger_type: list[WorkflowRunTriggerType] | None = None,
         include_debugger_runs: bool = False,
         search_key: str | None = None,
     ) -> list[WorkflowRun | Task]:
@@ -358,10 +357,6 @@ class WorkflowRunsMixin:

             if status:
                 workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status))
-            if trigger_type:
-                workflow_run_query = workflow_run_query.filter(
-                    WorkflowRunModel.trigger_type.in_([t.value for t in trigger_type])
-                )
             workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit)
             workflow_run_query_result = (await session.execute(workflow_run_query)).all()
             workflow_runs = [
@@ -369,20 +364,16 @@ class WorkflowRunsMixin:
                 for run, title in workflow_run_query_result
             ]

-            # Tasks don't have trigger_type — skip them when filtering by trigger type
-            if trigger_type:
-                tasks: list[Task] = []
-            else:
-                task_query = (
-                    select(TaskModel)
-                    .filter(TaskModel.organization_id == organization_id)
-                    .filter(TaskModel.workflow_run_id.is_(None))
-                )
-                if status:
-                    task_query = task_query.filter(TaskModel.status.in_(status))
-                task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
-                task_query_result = (await session.scalars(task_query)).all()
-                tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]
+            task_query = (
+                select(TaskModel)
+                .filter(TaskModel.organization_id == organization_id)
+                .filter(TaskModel.workflow_run_id.is_(None))
+            )
+            if status:
+                task_query = task_query.filter(TaskModel.status.in_(status))
+            task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
+            task_query_result = (await session.scalars(task_query)).all()
+            tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]

             runs = workflow_runs + tasks
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 from datetime import datetime, timedelta, timezone
 from typing import Any, Sequence

@@ -15,13 +16,15 @@ from skyvern.forge.sdk.db.models import (
     ActionModel,
     StepModel,
     TaskModel,
+    TaskRunModel,
     WorkflowRunModel,
 )
 from skyvern.forge.sdk.db.utils import convert_to_step, convert_to_task, hydrate_action, serialize_proxy_location
 from skyvern.forge.sdk.models import Step, StepStatus
+from skyvern.forge.sdk.schemas.runs import Run
 from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
 from skyvern.forge.sdk.utils.sanitization import sanitize_postgres_text
-from skyvern.schemas.runs import ProxyLocationInput
+from skyvern.schemas.runs import ProxyLocationInput, RunStatus, RunType
 from skyvern.schemas.steps import AgentStepOutput
 from skyvern.webeye.actions.actions import Action

@@ -29,6 +32,8 @@ LOG = structlog.get_logger()


 class TasksRepository(BaseRepository):
+    _background_tasks: set[asyncio.Task] = set()  # noqa: RUF012
+
     @db_operation("create_task")
     async def create_task(
         self,
@@ -470,6 +475,22 @@ class TasksRepository(BaseRepository):
         updated_task = await self.get_task(task_id, organization_id=organization_id)
         if not updated_task:
             raise NotFoundError("Task not found")
+
+        # Best-effort fire-and-forget write-through to task_runs.
+        # Mirrors the WorkflowService pattern — cron catches any missed syncs.
+        if status is not None:
+            bg = asyncio.create_task(
+                self.sync_task_run_status(
+                    organization_id=updated_task.organization_id or "",
+                    run_id=updated_task.task_id,
+                    status=status.value,
+                    started_at=updated_task.started_at,
+                    finished_at=updated_task.finished_at,
+                ),
+            )
+            self._background_tasks.add(bg)
+            bg.add_done_callback(self._background_tasks.discard)
+
+        return updated_task
-            return updated_task
-        else:
-            raise NotFoundError("Task not found")
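The `_background_tasks` set plus `add_done_callback(discard)` pairing guards against a known asyncio pitfall: the event loop holds only weak references to tasks, so an un-referenced `create_task` result can be garbage-collected before it ever runs. A self-contained illustration of the pattern with stub coroutines (not Skyvern code):

```python
import asyncio

background_tasks: set[asyncio.Task] = set()

async def sync_status(run_id: str) -> None:
    """Stand-in for the real task_runs write-through."""
    await asyncio.sleep(0)

async def update_task() -> str:
    # Fire-and-forget: hold a strong reference so the event loop cannot
    # garbage-collect the task mid-flight, then drop it once it finishes.
    bg = asyncio.create_task(sync_status("tsk_123"))
    background_tasks.add(bg)
    bg.add_done_callback(background_tasks.discard)
    return "updated"  # returns immediately; the sync runs in the background

async def main() -> None:
    print(await update_task())
    await asyncio.gather(*background_tasks)  # demo only: let the sync finish

asyncio.run(main())
```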
@@ -729,3 +750,197 @@ class TasksRepository(BaseRepository):
             )
             await session.execute(stmt)
             await session.commit()
+
+    async def sync_task_run_status(
+        self,
+        organization_id: str,
+        run_id: str,
+        status: str,
+        started_at: datetime | None = None,
+        finished_at: datetime | None = None,
+    ) -> None:
+        """Best-effort write-through: propagate status from source table to task_runs.
+
+        Does NOT raise if the task_runs row is missing (race at creation time).
+        """
+        try:
+            async with self.Session() as session:
+                vals: dict[str, Any] = {"status": status}
+                if started_at is not None:
+                    vals["started_at"] = started_at
+                if finished_at is not None:
+                    vals["finished_at"] = finished_at
+                stmt = (
+                    update(TaskRunModel)
+                    .where(TaskRunModel.run_id == run_id)
+                    .where(TaskRunModel.organization_id == organization_id)
+                    .values(**vals)
+                )
+                await session.execute(stmt)
+                await session.commit()
+        except Exception:
+            LOG.warning(
+                "Best-effort task_run status sync failed",
+                run_id=run_id,
+                organization_id=organization_id,
+                status=status,
+                exc_info=True,
+            )
+
+    @db_operation("create_task_run")
+    async def create_task_run(
+        self,
+        task_run_type: RunType,
+        organization_id: str,
+        run_id: str,
+        title: str | None = None,
+        url: str | None = None,
+        url_hash: str | None = None,
+        status: RunStatus | None = None,
+        workflow_permanent_id: str | None = None,
+        parent_workflow_run_id: str | None = None,
+        debug_session_id: str | None = None,
+        # script_run, started_at, finished_at are intentionally omitted here —
+        # they are set via update_task_run() after the run starts/finishes (PRs 2-5).
+    ) -> Run:
+        searchable_text = " ".join(filter(None, [title, url]))
+        async with self.Session() as session:
+            task_run = TaskRunModel(
+                task_run_type=task_run_type,
+                organization_id=organization_id,
+                run_id=run_id,
+                title=title,
+                url=url,
+                url_hash=url_hash,
+                status=status,
+                workflow_permanent_id=workflow_permanent_id,
+                parent_workflow_run_id=parent_workflow_run_id,
+                debug_session_id=debug_session_id,
+                searchable_text=searchable_text or None,
+            )
+            session.add(task_run)
+            await session.commit()
+            await session.refresh(task_run)
+            return Run.model_validate(task_run)
+
+    @db_operation("update_task_run")
+    async def update_task_run(
+        self,
+        organization_id: str,
+        run_id: str,
+        title: str | None = None,
+        url: str | None = None,
+        url_hash: str | None = None,
+        status: str | None = None,
+        started_at: datetime | None = None,
+        finished_at: datetime | None = None,
+    ) -> None:
+        async with self.Session() as session:
+            task_run = (
+                await session.scalars(
+                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
+                )
+            ).first()
+            if not task_run:
+                raise NotFoundError(f"TaskRun {run_id} not found")
+
+            if title is not None:
+                task_run.title = title
+            if url is not None:
+                task_run.url = url
+            if url_hash is not None:
+                task_run.url_hash = url_hash
+            if status is not None:
+                task_run.status = status
+            if started_at is not None:
+                task_run.started_at = started_at
+            if finished_at is not None:
+                task_run.finished_at = finished_at
+
+            # Recompute searchable_text when title or url changes
+            if title is not None or url is not None:
+                task_run.searchable_text = " ".join(filter(None, [task_run.title, task_run.url])) or None
+
+            await session.commit()
+
+    @db_operation("update_job_run_compute_cost")
+    async def update_job_run_compute_cost(
+        self,
+        organization_id: str,
+        run_id: str,
+        instance_type: str | None = None,
+        vcpu_millicores: int | None = None,
+        memory_mb: int | None = None,
+        duration_ms: int | None = None,
+        compute_cost: float | None = None,
+    ) -> None:
+        """Update compute cost metrics for a job run."""
+        async with self.Session() as session:
+            task_run = (
+                await session.scalars(
+                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
+                )
+            ).first()
+            if not task_run:
+                LOG.warning(
+                    "TaskRun not found for compute cost update",
+                    run_id=run_id,
+                    organization_id=organization_id,
+                )
+                return
+
+            if instance_type is not None:
+                task_run.instance_type = instance_type
+            if vcpu_millicores is not None:
+                task_run.vcpu_millicores = vcpu_millicores
+            if memory_mb is not None:
+                task_run.memory_mb = memory_mb
+            if duration_ms is not None:
+                task_run.duration_ms = duration_ms
+            if compute_cost is not None:
+                task_run.compute_cost = compute_cost
+            await session.commit()
+
+    @db_operation("cache_task_run")
+    async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
+        async with self.Session() as session:
+            task_run = (
+                await session.scalars(
+                    select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
+                )
+            ).first()
+            if task_run:
+                task_run.cached = True
+                await session.commit()
+                await session.refresh(task_run)
+                return Run.model_validate(task_run)
+            raise NotFoundError(f"Run {run_id} not found")
+
+    @db_operation("get_cached_task_run")
+    async def get_cached_task_run(
+        self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
+    ) -> Run | None:
+        async with self.Session() as session:
+            query = select(TaskRunModel)
+            if task_run_type:
+                query = query.filter_by(task_run_type=task_run_type)
+            if url_hash:
+                query = query.filter_by(url_hash=url_hash)
+            if organization_id:
+                query = query.filter_by(organization_id=organization_id)
+            query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
+            task_run = (await session.scalars(query)).first()
+            return Run.model_validate(task_run) if task_run else None
+
+    @db_operation("get_run")
+    async def get_run(
+        self,
+        run_id: str,
+        organization_id: str | None = None,
+    ) -> Run | None:
+        async with self.Session() as session:
+            query = select(TaskRunModel).filter_by(run_id=run_id)
+            if organization_id:
+                query = query.filter_by(organization_id=organization_id)
+            task_run = (await session.scalars(query)).first()
+            return Run.model_validate(task_run) if task_run else None
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta, timezone
 from typing import Any

 import structlog
-from sqlalchemy import select, update
+from sqlalchemy import select

 from skyvern.config import settings
 from skyvern.forge.sdk.db._error_handling import db_operation
@@ -26,7 +26,6 @@ from skyvern.forge.sdk.db.models import (
     OutputParameterModel,
     TaskGenerationModel,
     TaskModel,
-    TaskRunModel,
     WorkflowCopilotChatMessageModel,
     WorkflowCopilotChatModel,
     WorkflowParameterModel,
@@ -39,7 +38,6 @@ from skyvern.forge.sdk.db.utils import (
     hydrate_action,
 )
 from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
-from skyvern.forge.sdk.schemas.runs import Run
 from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
 from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
 from skyvern.forge.sdk.schemas.workflow_copilot import (
@@ -61,7 +59,6 @@ from skyvern.forge.sdk.workflow.models.parameter import (
     WorkflowParameter,
     WorkflowParameterType,
 )
-from skyvern.schemas.runs import RunStatus, RunType
 from skyvern.webeye.actions.actions import Action

 LOG = structlog.get_logger()
@@ -564,197 +561,3 @@ class WorkflowParametersRepository(BaseRepository):

         actions = (await session.scalars(query)).all()
         return [Action.model_validate(action) for action in actions]
-
-    @db_operation("create_task_run")
-    async def create_task_run(
-        self,
-        task_run_type: RunType,
-        organization_id: str,
-        run_id: str,
-        title: str | None = None,
-        url: str | None = None,
-        url_hash: str | None = None,
-        status: RunStatus | None = None,
-        workflow_permanent_id: str | None = None,
-        parent_workflow_run_id: str | None = None,
-        debug_session_id: str | None = None,
-        # script_run, started_at, finished_at are intentionally omitted here —
-        # they are set via update_task_run() after the run starts/finishes (PRs 2-5).
-    ) -> Run:
-        searchable_text = " ".join(filter(None, [title, url]))
-        async with self.Session() as session:
-            task_run = TaskRunModel(
-                task_run_type=task_run_type,
-                organization_id=organization_id,
-                run_id=run_id,
-                title=title,
-                url=url,
-                url_hash=url_hash,
-                status=status,
-                workflow_permanent_id=workflow_permanent_id,
-                parent_workflow_run_id=parent_workflow_run_id,
-                debug_session_id=debug_session_id,
-                searchable_text=searchable_text or None,
-            )
-            session.add(task_run)
-            await session.commit()
-            await session.refresh(task_run)
-            return Run.model_validate(task_run)
-
-    @db_operation("update_task_run")
-    async def update_task_run(
-        self,
-        organization_id: str,
-        run_id: str,
-        title: str | None = None,
-        url: str | None = None,
-        url_hash: str | None = None,
-        status: str | None = None,
-        started_at: datetime | None = None,
-        finished_at: datetime | None = None,
-    ) -> None:
-        async with self.Session() as session:
-            task_run = (
-                await session.scalars(
-                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
-                )
-            ).first()
-            if not task_run:
-                raise NotFoundError(f"TaskRun {run_id} not found")
-
-            if title is not None:
-                task_run.title = title
-            if url is not None:
-                task_run.url = url
-            if url_hash is not None:
-                task_run.url_hash = url_hash
-            if status is not None:
-                task_run.status = status
-            if started_at is not None:
-                task_run.started_at = started_at
-            if finished_at is not None:
-                task_run.finished_at = finished_at
-
-            # Recompute searchable_text when title or url changes
-            if title is not None or url is not None:
-                task_run.searchable_text = " ".join(filter(None, [task_run.title, task_run.url])) or None
-
-            await session.commit()
-
-    async def sync_task_run_status(
-        self,
-        organization_id: str,
-        run_id: str,
-        status: str,
-        started_at: datetime | None = None,
-        finished_at: datetime | None = None,
-    ) -> None:
-        """Best-effort write-through: propagate status from source table to task_runs.
-
-        Does NOT raise if the task_runs row is missing (race at creation time).
-        """
-        try:
-            async with self.Session() as session:
-                vals: dict[str, Any] = {"status": status}
-                if started_at is not None:
-                    vals["started_at"] = started_at
-                if finished_at is not None:
-                    vals["finished_at"] = finished_at
-                stmt = (
-                    update(TaskRunModel)
-                    .where(TaskRunModel.run_id == run_id)
-                    .where(TaskRunModel.organization_id == organization_id)
-                    .values(**vals)
-                )
-                await session.execute(stmt)
-                await session.commit()
-        except Exception:
-            LOG.warning(
-                "Best-effort task_run status sync failed",
-                run_id=run_id,
-                organization_id=organization_id,
-                status=status,
-                exc_info=True,
-            )
-
-    @db_operation("update_job_run_compute_cost")
-    async def update_job_run_compute_cost(
-        self,
-        organization_id: str,
-        run_id: str,
-        instance_type: str | None = None,
-        vcpu_millicores: int | None = None,
-        memory_mb: int | None = None,
-        duration_ms: int | None = None,
-        compute_cost: float | None = None,
-    ) -> None:
-        """Update compute cost metrics for a job run."""
-        async with self.Session() as session:
-            task_run = (
-                await session.scalars(
-                    select(TaskRunModel).filter_by(run_id=run_id).filter_by(organization_id=organization_id)
-                )
-            ).first()
-            if not task_run:
-                LOG.warning(
-                    "TaskRun not found for compute cost update",
-                    run_id=run_id,
-                    organization_id=organization_id,
-                )
-                return
-
-            if instance_type is not None:
-                task_run.instance_type = instance_type
-            if vcpu_millicores is not None:
-                task_run.vcpu_millicores = vcpu_millicores
-            if memory_mb is not None:
-                task_run.memory_mb = memory_mb
-            if duration_ms is not None:
-                task_run.duration_ms = duration_ms
-            if compute_cost is not None:
-                task_run.compute_cost = compute_cost
-            await session.commit()
-
-    @db_operation("cache_task_run")
-    async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> Run:
-        async with self.Session() as session:
-            task_run = (
-                await session.scalars(
-                    select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
-                )
-            ).first()
-            if task_run:
-                task_run.cached = True
-                await session.commit()
-                await session.refresh(task_run)
-                return Run.model_validate(task_run)
-            raise NotFoundError(f"Run {run_id} not found")
-
-    @db_operation("get_cached_task_run")
-    async def get_cached_task_run(
-        self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
-    ) -> Run | None:
-        async with self.Session() as session:
-            query = select(TaskRunModel)
-            if task_run_type:
-                query = query.filter_by(task_run_type=task_run_type)
-            if url_hash:
-                query = query.filter_by(url_hash=url_hash)
-            if organization_id:
-                query = query.filter_by(organization_id=organization_id)
-            query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
-            task_run = (await session.scalars(query)).first()
-            return Run.model_validate(task_run) if task_run else None
-
-    @db_operation("get_run")
-    async def get_run(
-        self,
-        run_id: str,
-        organization_id: str | None = None,
-    ) -> Run | None:
-        async with self.Session() as session:
-            query = select(TaskRunModel).filter_by(run_id=run_id)
-            if organization_id:
-                query = query.filter_by(organization_id=organization_id)
-            task_run = (await session.scalars(query)).first()
-            return Run.model_validate(task_run) if task_run else None
@@ -307,7 +307,6 @@ class WorkflowRunsRepository(BaseRepository):
         page: int = 1,
         page_size: int = 10,
         status: list[WorkflowRunStatus] | None = None,
-        trigger_type: list[WorkflowRunTriggerType] | None = None,
         include_debugger_runs: bool = False,
         search_key: str | None = None,
     ) -> list[WorkflowRun | Task]:
@@ -367,10 +366,6 @@ class WorkflowRunsRepository(BaseRepository):

             if status:
                 workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status))
-            if trigger_type:
-                workflow_run_query = workflow_run_query.filter(
-                    WorkflowRunModel.trigger_type.in_([t.value for t in trigger_type])
-                )
             workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit)
             workflow_run_query_result = (await session.execute(workflow_run_query)).all()
             workflow_runs = [
@@ -378,20 +373,16 @@ class WorkflowRunsRepository(BaseRepository):
                 for run, title in workflow_run_query_result
             ]

-            # Tasks don't have trigger_type — skip them when filtering by trigger type
-            if trigger_type:
-                tasks: list[Task] = []
-            else:
-                task_query = (
-                    select(TaskModel)
-                    .filter(TaskModel.organization_id == organization_id)
-                    .filter(TaskModel.workflow_run_id.is_(None))
-                )
-                if status:
-                    task_query = task_query.filter(TaskModel.status.in_(status))
-                task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
-                task_query_result = (await session.scalars(task_query)).all()
-                tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]
+            task_query = (
+                select(TaskModel)
+                .filter(TaskModel.organization_id == organization_id)
+                .filter(TaskModel.workflow_run_id.is_(None))
+            )
+            if status:
+                task_query = task_query.filter(TaskModel.status.in_(status))
+            task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
+            task_query_result = (await session.scalars(task_query)).all()
+            tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]

             runs = workflow_runs + tasks
@@ -39,7 +39,7 @@ from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.curl_converter import curl_to_http_request_block_params
 from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
 from skyvern.forge.sdk.core.security import generate_skyvern_signature
-from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, WorkflowRunTriggerType
+from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
 from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
 from skyvern.forge.sdk.models import Step
 from skyvern.forge.sdk.routes.code_samples import (
@@ -2120,7 +2120,6 @@ async def get_runs(
     page: int = Query(1, ge=1),
     page_size: int = Query(10, ge=1),
     status: Annotated[list[WorkflowRunStatus] | None, Query()] = None,
-    trigger_type: Annotated[list[WorkflowRunTriggerType] | None, Query()] = None,
     search_key: str | None = Query(
         None,
         description=(
@@ -2139,12 +2138,7 @@ async def get_runs(
         return []

     runs = await app.DATABASE.get_all_runs(
-        current_org.organization_id,
-        page=page,
-        page_size=page_size,
-        status=status,
-        trigger_type=trigger_type,
-        search_key=search_key,
+        current_org.organization_id, page=page, page_size=page_size, status=status, search_key=search_key
     )
     return ORJSONResponse([run.model_dump() for run in runs])
@@ -1330,8 +1330,12 @@ async def review_script_with_instructions(
             workflow_run_id=data.workflow_run_id,
         )
         for wf_param, run_param in run_param_tuples:
-            if isinstance(run_param.value, str) and run_param.value:
-                run_parameter_values[wf_param.key] = run_param.value
+            if (
+                run_param.value is not None
+                and str(run_param.value).strip()
+                and not wf_param.parameter_type.is_secret_or_credential()
+            ):
+                run_parameter_values[wf_param.key] = str(run_param.value)
     except Exception:
         LOG.warning("Failed to load run parameter values", exc_info=True)
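Both parameter-collection sites now accept non-string values by stringifying them, drop blank values, and filter out secrets so credentials are never echoed into a reviewer prompt. The shape of that filter in isolation, with a stub standing in for Skyvern's parameter-type objects (names here are ours):

```python
from dataclasses import dataclass

@dataclass
class ParameterType:
    secret: bool
    def is_secret_or_credential(self) -> bool:
        return self.secret

def collect_run_values(pairs: list[tuple[str, ParameterType, object]]) -> dict[str, str]:
    out: dict[str, str] = {}
    for key, param_type, value in pairs:
        # Stringify non-strings, drop None/blank values, never echo secrets.
        if value is not None and str(value).strip() and not param_type.is_secret_or_credential():
            out[key] = str(value)
    return out

pairs = [("full_name", ParameterType(False), "Jane Doe"),
         ("max_results", ParameterType(False), 25),
         ("api_key", ParameterType(True), "sk-secret")]
print(collect_run_values(pairs))  # {'full_name': 'Jane Doe', 'max_results': '25'}
```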
@@ -3655,12 +3655,11 @@ class WorkflowService:
             )
             # Best-effort fire-and-forget write-through to task_runs table.
             # Runs off the hot path so workflow status transitions stay fast.
             # Hold a reference in _background_tasks so Python doesn't GC the task.
-            task = asyncio.create_task(
+            bg = asyncio.create_task(
                 self._sync_task_run_from_workflow_run(workflow_run, workflow_run_id, status),
             )
-            self._background_tasks.add(task)
-            task.add_done_callback(self._background_tasks.discard)
+            self._background_tasks.add(bg)
+            bg.add_done_callback(self._background_tasks.discard)

         return workflow_run
@@ -5454,8 +5453,12 @@ class WorkflowService:
                 workflow_run_id=workflow_run.workflow_run_id,
             )
             for wf_param, run_param in run_param_tuples:
-                if isinstance(run_param.value, str) and run_param.value:
-                    run_parameter_values[wf_param.key] = run_param.value
+                if (
+                    run_param.value is not None
+                    and str(run_param.value).strip()
+                    and not wf_param.parameter_type.is_secret_or_credential()
+                ):
+                    run_parameter_values[wf_param.key] = str(run_param.value)
         except Exception:
             LOG.debug("Failed to load run parameter values for hardcoded-value check", exc_info=True)
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import json
 import re
 from typing import Literal, Sequence
@@ -113,6 +114,33 @@ class ScriptReviewer:
                 history_by_block[ep.block_label] = []
             history_by_block[ep.block_label].append(ep)

+        # Batch-load parameter values for historical episodes so the reviewer
+        # can detect per-run values (e.g., different provider names across runs).
+        # Passed explicitly to _review_block to avoid implicit instance state.
+        historical_run_params: dict[str, dict[str, str]] = {}
+        if historical_episodes:
+            unique_run_ids = list({ep.workflow_run_id for ep in historical_episodes if ep.workflow_run_id})[:20]
+
+            async def _load_run_params(run_id: str) -> tuple[str, dict[str, str]]:
+                try:
+                    param_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=run_id)
+                    params = {
+                        wf_param.key: str(run_param.value)
+                        for wf_param, run_param in param_tuples
+                        if run_param.value is not None
+                        and str(run_param.value).strip()
+                        and not wf_param.parameter_type.is_secret_or_credential()
+                    }
+                    return run_id, params
+                except Exception:
+                    LOG.warning(
+                        "Failed to load params for historical episode run", workflow_run_id=run_id, exc_info=True
+                    )
+                    return run_id, {}
+
+            results = await asyncio.gather(*[_load_run_params(rid) for rid in unique_run_ids])
+            historical_run_params = {rid: params for rid, params in results if params}
+
         # Triage failed episodes — skip non-code-fixable failures.
         # When user provides explicit instructions, skip triage entirely.
         if user_instructions:
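The batch loader caps the fan-out at 20 unique run ids and wraps each load so a single failing run contributes an empty dict instead of sinking the whole `asyncio.gather`. The concurrency shape in isolation, with a fake loader standing in for the DB call (not Skyvern code):

```python
import asyncio

async def load_params(run_id: str) -> tuple[str, dict[str, str]]:
    # Stand-in for app.DATABASE.get_workflow_run_parameters(...)
    if run_id == "bad":
        raise RuntimeError("db error")
    return run_id, {"full_name": f"name-for-{run_id}"}

async def safe_load(run_id: str) -> tuple[str, dict[str, str]]:
    try:
        return await load_params(run_id)
    except Exception:
        return run_id, {}  # a failed run contributes nothing, never an exception

async def main() -> None:
    results = await asyncio.gather(*[safe_load(r) for r in ["wr_1", "bad", "wr_2"]])
    print({rid: params for rid, params in results if params})
    # {'wr_1': {'full_name': 'name-for-wr_1'}, 'wr_2': {'full_name': 'name-for-wr_2'}}

asyncio.run(main())
```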
@@ -161,6 +189,7 @@ class ScriptReviewer:
                     historical_episodes=history_by_block.get(block_label),
                     run_parameter_values=run_parameter_values,
                     user_instructions=user_instructions,
+                    historical_run_params=historical_run_params,
                 )
                 if updated_code:
                     updated_blocks[block_label] = updated_code
@@ -529,6 +558,7 @@ class ScriptReviewer:
         run_parameter_values: dict[str, str] | None = None,
         user_instructions: str | None = None,
         preloaded_code: str | None = None,
+        historical_run_params: dict[str, dict[str, str]] | None = None,
     ) -> str | None:
         """Review a single block's fallback episodes and generate updated code."""
         LOG.info(
@@ -604,16 +634,20 @@ class ScriptReviewer:
         goal_param_keys = set(re.findall(r"\{\{\s*(\w+)\s*\}\}", navigation_goal))
         parameter_keys = sorted(goal_param_keys | set(all_parameter_keys or []))

-        # Build historical episode summaries for cross-run context
+        # Build historical episode summaries for cross-run context.
+        # Include per-run parameter values so the reviewer can detect that
+        # different runs had different names/IDs (→ selectors must be dynamic).
         history_summaries = []
         for ep in historical_episodes or []:
-            history_summaries.append(
-                {
-                    "error_message": ep.error_message,
-                    "reviewer_output": (ep.reviewer_output or "")[:500],
-                    "fallback_succeeded": ep.fallback_succeeded,
-                }
-            )
+            summary: dict[str, object] = {
+                "error_message": ep.error_message,
+                "reviewer_output": (ep.reviewer_output or "")[:500],
+                "fallback_succeeded": ep.fallback_succeeded,
+            }
+            ep_params = (historical_run_params or {}).get(ep.workflow_run_id)
+            if ep_params:
+                summary["run_parameters"] = ep_params
+            history_summaries.append(summary)

         # Build the reviewer prompt
         reviewer_prompt = prompt_engine.load_prompt(
@@ -636,6 +670,7 @@ class ScriptReviewer:
             stale_branches=stale_branch_info,
             parameter_keys=parameter_keys,
             historical_episodes=history_summaries,
+            run_parameter_values=run_parameter_values,
             user_instructions=user_instructions,
         )
@@ -19,14 +19,14 @@ def mock_session():
 def agent_db(mock_session):
     db = AgentDB.__new__(AgentDB)
     db.Session = MagicMock(return_value=mock_session)
-    # Set up workflow_params with the same session factory for delegation
-    from skyvern.forge.sdk.db.repositories.workflow_parameters import WorkflowParametersRepository
+    # Set up tasks repository (sync_task_run_status delegates to self.tasks)
+    from skyvern.forge.sdk.db.repositories.tasks import TasksRepository

-    wp = WorkflowParametersRepository.__new__(WorkflowParametersRepository)
-    wp.Session = MagicMock(return_value=mock_session)
-    wp.debug_enabled = False
-    wp._is_retryable_error_fn = None
-    db.workflow_params = wp
+    tasks = TasksRepository.__new__(TasksRepository)
+    tasks.Session = MagicMock(return_value=mock_session)
+    tasks.debug_enabled = False
+    tasks._is_retryable_error_fn = None
+    db.tasks = tasks
     return db

@@ -117,7 +117,7 @@ def _load_task_runs_sync_activity_module(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setitem(sys.modules, "temporalio", temporalio_module)
     monkeypatch.setitem(sys.modules, "structlog", structlog_module)

-    module_path = _SOURCE_FILE
+    module_path = _repo_root() / "workers" / "cron_worker" / "task_runs_sync_activity.py"
     spec = importlib.util.spec_from_file_location("test_task_runs_sync_activity_module", module_path)
     module = importlib.util.module_from_spec(spec)
     assert spec is not None and spec.loader is not None