mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-04-25 17:09:10 +00:00
shu/add_workflow_runs_api (#2063)
This commit is contained in:
parent
f774135049
commit
e26b816f67
7 changed files with 166 additions and 51 deletions
|
@ -14,7 +14,7 @@ from skyvern.forge.prompts import prompt_engine
|
|||
from skyvern.forge.sdk.api.files import create_folder_if_not_exist
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskRequest, TaskResponse, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponse, WorkflowRunStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRunResponseBase, WorkflowRunStatus
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
|
||||
|
||||
|
@ -71,7 +71,7 @@ class SkyvernClient:
|
|||
assert response.status_code == 200, f"Expected to get task response status 200, but got {response.status_code}"
|
||||
return TaskResponse(**response.json())
|
||||
|
||||
async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponse:
|
||||
async def get_workflow_run(self, workflow_pid: str, workflow_run_id: str) -> WorkflowRunResponseBase:
|
||||
url = f"{self.base_url}/workflows/{workflow_pid}/runs/{workflow_run_id}"
|
||||
headers = {"x-api-key": self.credentials}
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
@ -79,7 +79,7 @@ class SkyvernClient:
|
|||
assert response.status_code == 200, (
|
||||
f"Expected to get workflow run response status 200, but got {response.status_code}"
|
||||
)
|
||||
return WorkflowRunResponse(**response.json())
|
||||
return WorkflowRunResponseBase(**response.json())
|
||||
|
||||
|
||||
class Evaluator:
|
||||
|
|
|
@ -15,6 +15,6 @@ setup_logger()
|
|||
|
||||
from skyvern.forge import app # noqa: E402, F401
|
||||
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponseBase # noqa: E402
|
||||
|
||||
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponse"]
|
||||
__all__ = ["SkyvernAgent", "SkyvernClient", "WorkflowRunResponseBase"]
|
||||
|
|
|
@ -55,13 +55,21 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
|||
Workflow,
|
||||
WorkflowRequestBody,
|
||||
WorkflowRun,
|
||||
WorkflowRunResponse,
|
||||
WorkflowRunResponseBase,
|
||||
WorkflowRunStatus,
|
||||
WorkflowStatus,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
|
||||
from skyvern.services import run_service, task_v1_service, task_v2_service
|
||||
from skyvern.schemas.runs import (
|
||||
RunEngine,
|
||||
RunResponse,
|
||||
RunType,
|
||||
TaskRunRequest,
|
||||
TaskRunResponse,
|
||||
WorkflowRunRequest,
|
||||
WorkflowRunResponse,
|
||||
)
|
||||
from skyvern.services import run_service, task_v1_service, task_v2_service, workflow_service
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
@ -620,7 +628,7 @@ async def get_actions(
|
|||
tags=["agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
"x-fern-sdk-method-name": "run_workflow",
|
||||
"x-fern-sdk-method-name": "run_workflow_legacy",
|
||||
},
|
||||
)
|
||||
@legacy_base_router.post(
|
||||
|
@ -628,7 +636,7 @@ async def get_actions(
|
|||
response_model=RunWorkflowResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def run_workflow(
|
||||
async def run_workflow_legacy(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
workflow_id: str, # this is the workflow_permanent_id internally
|
||||
|
@ -647,42 +655,19 @@ async def run_workflow(
|
|||
browser_session_id=workflow_request.browser_session_id,
|
||||
)
|
||||
|
||||
if template:
|
||||
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
|
||||
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)
|
||||
|
||||
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
||||
request_id=request_id,
|
||||
workflow_run = await workflow_service.run_workflow(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=current_org.organization_id,
|
||||
workflow_request=workflow_request,
|
||||
workflow_permanent_id=workflow_id,
|
||||
organization_id=current_org.organization_id,
|
||||
template=template,
|
||||
version=version,
|
||||
max_steps_override=x_max_steps_override,
|
||||
is_template_workflow=template,
|
||||
)
|
||||
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
|
||||
workflow_permanent_id=workflow_id,
|
||||
organization_id=None if template else current_org.organization_id,
|
||||
version=version,
|
||||
)
|
||||
await app.DATABASE.create_task_run(
|
||||
task_run_type=RunType.workflow_run,
|
||||
organization_id=current_org.organization_id,
|
||||
run_id=workflow_run.workflow_run_id,
|
||||
title=workflow.title,
|
||||
)
|
||||
if x_max_steps_override:
|
||||
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
||||
await AsyncExecutorFactory.get_executor().execute_workflow(
|
||||
max_steps=x_max_steps_override,
|
||||
api_key=x_api_key,
|
||||
request_id=request_id,
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
organization_id=current_org.organization_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
max_steps_override=x_max_steps_override,
|
||||
browser_session_id=workflow_request.browser_session_id,
|
||||
api_key=x_api_key,
|
||||
)
|
||||
|
||||
return RunWorkflowResponse(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
|
@ -806,7 +791,7 @@ async def get_workflow_run_timeline(
|
|||
|
||||
@legacy_base_router.get(
|
||||
"/workflows/runs/{workflow_run_id}",
|
||||
response_model=WorkflowRunResponse,
|
||||
response_model=WorkflowRunResponseBase,
|
||||
tags=["agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
|
@ -815,13 +800,13 @@ async def get_workflow_run_timeline(
|
|||
)
|
||||
@legacy_base_router.get(
|
||||
"/workflows/runs/{workflow_run_id}/",
|
||||
response_model=WorkflowRunResponse,
|
||||
response_model=WorkflowRunResponseBase,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_workflow_run(
|
||||
workflow_run_id: str,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> WorkflowRunResponse:
|
||||
) -> WorkflowRunResponseBase:
|
||||
analytics.capture("skyvern-oss-agent-workflow-run-get")
|
||||
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response_by_workflow_id(
|
||||
workflow_run_id=workflow_run_id,
|
||||
|
@ -1385,7 +1370,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
|
|||
|
||||
|
||||
@base_router.post(
|
||||
"/tasks",
|
||||
"/tasks/run",
|
||||
tags=["Agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
|
@ -1398,7 +1383,7 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
|
|||
400: {"description": "Invalid agent engine"},
|
||||
},
|
||||
)
|
||||
@base_router.post("/tasks/", include_in_schema=False)
|
||||
@base_router.post("/tasks/run/", include_in_schema=False)
|
||||
async def run_task(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
|
@ -1523,3 +1508,69 @@ async def run_task(
|
|||
),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/workflows/run",
|
||||
tags=["Agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
"x-fern-sdk-method-name": "run_workflow",
|
||||
},
|
||||
description="Run a workflow",
|
||||
summary="Run a workflow",
|
||||
responses={
|
||||
200: {"description": "Successfully run workflow"},
|
||||
400: {"description": "Invalid workflow run request"},
|
||||
},
|
||||
)
|
||||
@base_router.post("/workflows/run/", include_in_schema=False)
|
||||
async def run_workflow(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
workflow_run_request: WorkflowRunRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
template: bool = Query(False),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
x_max_steps_override: Annotated[int | None, Header()] = None,
|
||||
) -> WorkflowRunResponse:
|
||||
analytics.capture("skyvern-oss-run-workflow")
|
||||
await PermissionCheckerFactory.get_instance().check(
|
||||
current_org, browser_session_id=workflow_run_request.browser_session_id
|
||||
)
|
||||
workflow_id = workflow_run_request.workflow_id
|
||||
context = skyvern_context.ensure_context()
|
||||
request_id = context.request_id
|
||||
legacy_workflow_request = WorkflowRequestBody(
|
||||
data=workflow_run_request.parameters,
|
||||
proxy_location=workflow_run_request.proxy_location,
|
||||
webhook_callback_url=workflow_run_request.webhook_url,
|
||||
totp_identifier=workflow_run_request.totp_identifier,
|
||||
totp_url=workflow_run_request.totp_url,
|
||||
browser_session_id=workflow_run_request.browser_session_id,
|
||||
)
|
||||
workflow_run = await workflow_service.run_workflow(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=current_org.organization_id,
|
||||
workflow_request=legacy_workflow_request,
|
||||
template=template,
|
||||
version=None,
|
||||
max_steps=x_max_steps_override,
|
||||
api_key=x_api_key,
|
||||
request_id=request_id,
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
)
|
||||
|
||||
return WorkflowRunResponse(
|
||||
run_id=workflow_run.workflow_run_id,
|
||||
run_type=RunType.workflow_run,
|
||||
status=str(workflow_run.status),
|
||||
output=workflow_run.output,
|
||||
failure_reason=workflow_run.failure_reason,
|
||||
created_at=workflow_run.created_at,
|
||||
modified_at=workflow_run.modified_at,
|
||||
run_request=workflow_run_request,
|
||||
downloaded_files=workflow_run.downloaded_files,
|
||||
recording_url=workflow_run.recording_url,
|
||||
)
|
||||
|
|
|
@ -133,7 +133,7 @@ class WorkflowRunOutputParameter(BaseModel):
|
|||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
class WorkflowRunResponseBase(BaseModel):
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
status: WorkflowRunStatus
|
||||
|
|
|
@ -80,7 +80,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
|||
WorkflowRun,
|
||||
WorkflowRunOutputParameter,
|
||||
WorkflowRunParameter,
|
||||
WorkflowRunResponse,
|
||||
WorkflowRunResponseBase,
|
||||
WorkflowRunStatus,
|
||||
WorkflowStatus,
|
||||
)
|
||||
|
@ -958,7 +958,7 @@ class WorkflowService:
|
|||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
include_cost: bool = False,
|
||||
) -> WorkflowRunResponse:
|
||||
) -> WorkflowRunResponseBase:
|
||||
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
if workflow_run is None:
|
||||
LOG.error(f"Workflow run {workflow_run_id} not found")
|
||||
|
@ -977,7 +977,7 @@ class WorkflowService:
|
|||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
include_cost: bool = False,
|
||||
) -> WorkflowRunResponse:
|
||||
) -> WorkflowRunResponseBase:
|
||||
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id)
|
||||
if workflow is None:
|
||||
LOG.error(f"Workflow {workflow_permanent_id} not found")
|
||||
|
@ -1073,7 +1073,7 @@ class WorkflowService:
|
|||
# successful steps are the ones that have a status of completed and the total count of unique step.order
|
||||
successful_steps = [step for step in workflow_run_steps if step.status == StepStatus.completed]
|
||||
total_cost = 0.1 * (len(successful_steps) + len(text_prompt_blocks))
|
||||
return WorkflowRunResponse(
|
||||
return WorkflowRunResponseBase(
|
||||
workflow_id=workflow.workflow_permanent_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
status=workflow_run.status,
|
||||
|
|
|
@ -5,6 +5,7 @@ from zoneinfo import ZoneInfo
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.utils.url_validators import validate_url
|
||||
|
||||
|
||||
|
@ -206,6 +207,8 @@ class BaseRunResponse(BaseModel):
|
|||
output: dict | list | str | None = Field(
|
||||
default=None, description="Output data from the run, if any. Format depends on the schema in the input"
|
||||
)
|
||||
downloaded_files: list[FileInfo] | None = Field(default=None, description="List of files downloaded during the run")
|
||||
recording_url: str | None = Field(default=None, description="URL to the recording of the run")
|
||||
failure_reason: str | None = Field(default=None, description="Reason for failure if the run failed")
|
||||
created_at: datetime = Field(description="Timestamp when this run was created")
|
||||
modified_at: datetime = Field(description="Timestamp when this run was last modified")
|
||||
|
|
61
skyvern/services/workflow_service.py
Normal file
61
skyvern/services/workflow_service.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
import structlog
|
||||
from fastapi import BackgroundTasks, Request
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.workflow.exceptions import InvalidTemplateWorkflowPermanentId
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody, WorkflowRun
|
||||
from skyvern.schemas.runs import RunType
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
async def run_workflow(
|
||||
workflow_id: str,
|
||||
organization_id: str,
|
||||
workflow_request: WorkflowRequestBody, # this is the deprecated workflow request body
|
||||
template: bool = False,
|
||||
version: int | None = None,
|
||||
max_steps: int | None = None,
|
||||
api_key: str | None = None,
|
||||
request_id: str | None = None,
|
||||
request: Request | None = None,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> WorkflowRun:
|
||||
if template:
|
||||
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
|
||||
raise InvalidTemplateWorkflowPermanentId(workflow_permanent_id=workflow_id)
|
||||
|
||||
workflow_run = await app.WORKFLOW_SERVICE.setup_workflow_run(
|
||||
request_id=request_id,
|
||||
workflow_request=workflow_request,
|
||||
workflow_permanent_id=workflow_id,
|
||||
organization_id=organization_id,
|
||||
version=version,
|
||||
max_steps_override=max_steps,
|
||||
is_template_workflow=template,
|
||||
)
|
||||
workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id(
|
||||
workflow_permanent_id=workflow_id,
|
||||
organization_id=None if template else organization_id,
|
||||
version=version,
|
||||
)
|
||||
await app.DATABASE.create_task_run(
|
||||
task_run_type=RunType.workflow_run,
|
||||
organization_id=organization_id,
|
||||
run_id=workflow_run.workflow_run_id,
|
||||
title=workflow.title,
|
||||
)
|
||||
if max_steps:
|
||||
LOG.info("Overriding max steps per run", max_steps_override=max_steps)
|
||||
await AsyncExecutorFactory.get_executor().execute_workflow(
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
organization_id=organization_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
max_steps_override=max_steps,
|
||||
browser_session_id=workflow_request.browser_session_id,
|
||||
api_key=api_key,
|
||||
)
|
||||
return workflow_run
|
Loading…
Add table
Reference in a new issue