shu/add_workflow_runs_api ()

This commit is contained in:
Shuchang Zheng 2025-04-01 15:52:35 -04:00 committed by GitHub
parent f774135049
commit e26b816f67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 166 additions and 51 deletions
evaluation/core
skyvern

View file

@ -14,7 +14,7 @@ from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.files import create_folder_if_not_exist
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request
from skyvern.forge.sdk.schemas.tasks import TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponse, WorkflowRunStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponseBase, WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation
@ -71,7 +71,7 @@ class SkyvernClient:
assert response.status_code == 200, f"Expected to get task response status 200, but got {response.status_code}"
return TaskResponse(**response.json())
async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponse:
async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponseBase:
url = f"{self.base_url}/workflows/{workflow_pid}/runs/{workflow_run_id}"
headers = {"x-api-key": self.credentials}
async with httpx.AsyncClient() as client:
@ -79,7 +79,7 @@ class SkyvernClient:
assert response.status_code == 200, (
f"Expected to get workflow run response status 200, but got {response.status_code}"
)
return WorkflowRunResponse(**response.json())
return WorkflowRunResponseBase(**response.json())
class Evaluator:

View file

@ -15,6 +15,6 @@ setup_logger()
from skyvern.forge import app # noqa: E402, F401
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponse"]
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"]

View file

@ -55,13 +55,21 @@ from skyvern.forge.sdk.workflow.models.workflow import (
Workflow,
WorkflowRequestBody,
WorkflowRun,
WorkflowRunResponse,
WorkflowRunResponseBase,
WorkflowRunStatus,
WorkflowStatus,
)
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.schemas.runs import (
RunEngine,
RunResponse,
RunType,
TaskRunRequest,
TaskRunResponse,
WorkflowRunRequest,
WorkflowRunResponse,
)
from skyvern.services import run_service, task_v1_service, task_v2_service, workflow_service
from skyvern.webeye.actions.actions import Action
LOG = structlog.get_logger()
@ -620,7 +628,7 @@ async def get_actions(
tags=["agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "run_workflow",
"x-fern-sdk-method-name": "run_workflow_legacy",
},
)
@legacy_base_router.post(
@ -628,7 +636,7 @@ async def get_actions(
response_model=RunWorkflowResponse,
include_in_schema=False,
)
async def run_workflow(
async def run_workflow_legacy(
request: Request,
background_tasks: BackgroundTasks,
workflow_id: str, # this is the workflow_permanent_id internally
@ -647,42 +655,19 @@ async def run_workflow(
browser_session_id=workflow_request.browser_session_id,
)
if template:
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_run = await workflow_service.run_workflow(
workflow_id=workflow_id,
organization_id=current_org.organization_id,
workflow_request=workflow_request,
workflow_permanent_id=workflow_id,
organization_id=current_org.organization_id,
template=template,
version=version,
max_steps_override=x_max_steps_override,
is_template_workflow=template,
)
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_id,
organization_id=None if template else current_org.organization_id,
version=version,
)
await app.DATABASE.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=current_org.organization_id,
run_id=workflow_run.workflow_run_id,
title=workflow.title,
)
if x_max_steps_override:
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_workflow(
max_steps=x_max_steps_override,
api_key=x_api_key,
request_id=request_id,
request=request,
background_tasks=background_tasks,
organization_id=current_org.organization_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=x_max_steps_override,
browser_session_id=workflow_request.browser_session_id,
api_key=x_api_key,
)
return RunWorkflowResponse(
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
@ -806,7 +791,7 @@ async def get_workflow_run_timeline(
@legacy_base_router.get(
"/workflows/runs/{workflow_run_id}",
response_model=WorkflowRunResponse,
response_model=WorkflowRunResponseBase,
tags=["agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
@ -815,13 +800,13 @@ async def get_workflow_run_timeline(
)
@legacy_base_router.get(
"/workflows/runs/{workflow_run_id}/",
response_model=WorkflowRunResponse,
response_model=WorkflowRunResponseBase,
include_in_schema=False,
)
async def get_workflow_run(
workflow_run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowRunResponse:
) -> WorkflowRunResponseBase:
analytics.capture("skyvern-oss-agent-workflow-run-get")
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id(
workflow_run_id=workflow_run_id,
@ -1385,7 +1370,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
@base_router.post(
"/tasks",
"/tasks/run",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
@ -1398,7 +1383,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
400: {"description": "Invalid agent engine"},
},
)
@base_router.post("/tasks/", include_in_schema=False)
@base_router.post("/tasks/run/", include_in_schema=False)
async def run_task(
request: Request,
background_tasks: BackgroundTasks,
@ -1523,3 +1508,69 @@ async def run_task(
),
)
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
@base_router.post(
"/workflows/run",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "run_workflow",
},
description="Run a workflow",
summary="Run a workflow",
responses={
200: {"description": "Successfully run workflow"},
400: {"description": "Invalid workflow run request"},
},
)
@base_router.post("/workflows/run/", include_in_schema=False)
async def run_workflow(
request: Request,
background_tasks: BackgroundTasks,
workflow_run_request: WorkflowRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
template: bool = Query(False),
x_api_key: Annotated[str | None, Header()] = None,
x_max_steps_override: Annotated[int | None, Header()] = None,
) -> WorkflowRunResponse:
analytics.capture("skyvern-oss-run-workflow")
await PermissionCheckerFactory.get_instance().check(
current_org, browser_session_id=workflow_run_request.browser_session_id
)
workflow_id = workflow_run_request.workflow_id
context = skyvern_context.ensure_context()
request_id = context.request_id
legacy_workflow_request = WorkflowRequestBody(
data=workflow_run_request.parameters,
proxy_location=workflow_run_request.proxy_location,
webhook_callback_url=workflow_run_request.webhook_url,
totp_identifier=workflow_run_request.totp_identifier,
totp_url=workflow_run_request.totp_url,
browser_session_id=workflow_run_request.browser_session_id,
)
workflow_run = await workflow_service.run_workflow(
workflow_id=workflow_id,
organization_id=current_org.organization_id,
workflow_request=legacy_workflow_request,
template=template,
version=None,
max_steps=x_max_steps_override,
api_key=x_api_key,
request_id=request_id,
request=request,
background_tasks=background_tasks,
)
return WorkflowRunResponse(
run_id=workflow_run.workflow_run_id,
run_type=RunType.workflow_run,
status=str(workflow_run.status),
output=workflow_run.output,
failure_reason=workflow_run.failure_reason,
created_at=workflow_run.created_at,
modified_at=workflow_run.modified_at,
run_request=workflow_run_request,
downloaded_files=workflow_run.downloaded_files,
recording_url=workflow_run.recording_url,
)

View file

@ -133,7 +133,7 @@ class WorkflowRunOutputParameter(BaseModel):
created_at: datetime
class WorkflowRunResponse(BaseModel):
class WorkflowRunResponseBase(BaseModel):
workflow_id: str
workflow_run_id: str
status: WorkflowRunStatus

View file

@ -80,7 +80,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRun,
WorkflowRunOutputParameter,
WorkflowRunParameter,
WorkflowRunResponse,
WorkflowRunResponseBase,
WorkflowRunStatus,
WorkflowStatus,
)
@ -958,7 +958,7 @@ class WorkflowService:
workflow_run_id: str,
organization_id: str,
include_cost: bool = False,
) -> WorkflowRunResponse:
) -> WorkflowRunResponseBase:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
if workflow_run is None:
LOG.error(f"Workflow run {workflow_run_id} not found")
@ -977,7 +977,7 @@ class WorkflowService:
workflow_run_id: str,
organization_id: str,
include_cost: bool = False,
) -> WorkflowRunResponse:
) -> WorkflowRunResponseBase:
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id)
if workflow is None:
LOG.error(f"Workflow {workflow_permanent_id} not found")
@ -1073,7 +1073,7 @@ class WorkflowService:
# successful steps are the ones that have a status of completed and the total count of unique step.order
successful_steps = [step for step in workflow_run_steps if step.status == StepStatus.completed]
total_cost = 0.1 * (len(successful_steps) + len(text_prompt_blocks))
return WorkflowRunResponse(
return WorkflowRunResponseBase(
workflow_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id,
status=workflow_run.status,

View file

@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo
from pydantic import BaseModel, Field, field_validator
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.utils.url_validators import validate_url
@ -206,6 +207,8 @@ class BaseRunResponse(BaseModel):
output: dict | list | str | None = Field(
default=None, description="Output data from the run, if any. Format depends on the schema in the input"
)
downloaded_files: list[FileInfo] | None = Field(default=None, description="List of files downloaded during the run")
recording_url: str | None = Field(default=None, description="URL to the recording of the run")
failure_reason: str | None = Field(default=None, description="Reason for failure if the run failed")
created_at: datetime = Field(description="Timestamp when this run was created")
modified_at: datetime = Field(description="Timestamp when this run was last modified")

View file

@ -0,0 +1,61 @@
import structlog
from fastapi import BackgroundTasks, Request
from skyvern.forge import app
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.workflow.exceptions import InvalidTemplateWorkflowPermanentId
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRun
from skyvern.schemas.runs import RunType
LOG = structlog.get_logger(__name__)
async def run_workflow(
workflow_id: str,
organization_id: str,
workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body
template: bool = False,
version: int | None = None,
max_steps: int | None = None,
api_key: str | None = None,
request_id: str | None = None,
request: Request | None = None,
background_tasks: BackgroundTasks | None = None,
) -> WorkflowRun:
if template:
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
request_id=request_id,
workflow_request=workflow_request,
workflow_permanent_id=workflow_id,
organization_id=organization_id,
version=version,
max_steps_override=max_steps,
is_template_workflow=template,
)
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
workflow_permanent_id=workflow_id,
organization_id=None if template else organization_id,
version=version,
)
await app.DATABASE.create_task_run(
task_run_type=RunType.workflow_run,
organization_id=organization_id,
run_id=workflow_run.workflow_run_id,
title=workflow.title,
)
if max_steps:
LOG.info("Overriding max steps per run", max_steps_override=max_steps)
await AsyncExecutorFactory.get_executor().execute_workflow(
request=request,
background_tasks=background_tasks,
organization_id=organization_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=max_steps,
browser_session_id=workflow_request.browser_session_id,
api_key=api_key,
)
return workflow_run