TaskRunType -> RunType (#2041)

This commit is contained in:
Shuchang Zheng 2025-03-30 18:34:48 -07:00 committed by GitHub
parent d54c2af544
commit 05e28931bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 45 additions and 41 deletions

View file

@ -33,7 +33,8 @@ repos:
language_version: python3.11 language_version: python3.11
exclude: | exclude: |
(?x)( (?x)(
^skyvern/client/.* ^skyvern/client/.*|
^skyvern/__init__.py
) )
- repo: https://github.com/pre-commit/pygrep-hooks - repo: https://github.com/pre-commit/pygrep-hooks

View file

@ -13,6 +13,7 @@ tracer.configure(
setup_logger() setup_logger()
from skyvern.forge import app # noqa: E402, F401
from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402 from skyvern.agent import SkyvernAgent, SkyvernClient # noqa: E402
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402 from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunResponse # noqa: E402

View file

@ -5,7 +5,7 @@ import httpx
from skyvern.config import settings from skyvern.config import settings
from skyvern.exceptions import SkyvernClientException from skyvern.exceptions import SkyvernClientException
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
class SkyvernClient: class SkyvernClient:
@ -29,11 +29,11 @@ class SkyvernClient:
error_code_mapping: dict[str, str] | None = None, error_code_mapping: dict[str, str] | None = None,
proxy_location: ProxyLocation | None = None, proxy_location: ProxyLocation | None = None,
max_steps: int | None = None, max_steps: int | None = None,
) -> TaskRunResponse: ) -> RunResponse:
if engine == RunEngine.skyvern_v1: if engine == RunEngine.skyvern_v1:
return TaskRunResponse() return RunResponse()
elif engine == RunEngine.skyvern_v2: elif engine == RunEngine.skyvern_v2:
return TaskRunResponse() return RunResponse()
raise ValueError(f"Invalid engine: {engine}") raise ValueError(f"Invalid engine: {engine}")
async def run_workflow( async def run_workflow(
@ -69,8 +69,8 @@ class SkyvernClient:
async def get_run( async def get_run(
self, self,
run_id: str, run_id: str,
) -> TaskRunResponse: ) -> RunResponse:
return TaskRunResponse() return RunResponse()
async def get_workflow_run( async def get_workflow_run(
self, self,

View file

@ -66,8 +66,8 @@ from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
from skyvern.forge.sdk.schemas.runs import TaskRun
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Status, Thought, ThoughtType
from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import OrderBy, SortDirection, Task, TaskStatus
from skyvern.forge.sdk.schemas.totp_codes import TOTPCode from skyvern.forge.sdk.schemas.totp_codes import TOTPCode
@ -91,7 +91,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
WorkflowStatus, WorkflowStatus,
) )
from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.runs import ProxyLocation, RunType
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
from skyvern.webeye.actions.models import AgentStepOutput from skyvern.webeye.actions.models import AgentStepOutput
@ -2783,7 +2783,7 @@ class AgentDB:
async def create_task_run( async def create_task_run(
self, self,
task_run_type: TaskRunType, task_run_type: RunType,
organization_id: str, organization_id: str,
run_id: str, run_id: str,
title: str | None = None, title: str | None = None,
@ -2931,7 +2931,7 @@ class AgentDB:
raise NotFoundError(f"TaskRun {run_id} not found") raise NotFoundError(f"TaskRun {run_id} not found")
async def get_cached_task_run( async def get_cached_task_run(
self, task_run_type: TaskRunType, url_hash: str | None = None, organization_id: str | None = None self, task_run_type: RunType, url_hash: str | None = None, organization_id: str | None = None
) -> TaskRun | None: ) -> TaskRun | None:
async with self.Session() as session: async with self.Session() as session:
query = select(TaskRunModel) query = select(TaskRunModel)

View file

@ -41,7 +41,6 @@ from skyvern.forge.sdk.schemas.organizations import (
OrganizationUpdate, OrganizationUpdate,
) )
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
from skyvern.forge.sdk.schemas.tasks import ( from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse, CreateTaskResponse,
@ -71,7 +70,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowStatus, WorkflowStatus,
) )
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest
from skyvern.services import task_v1_service, task_v2_service from skyvern.services import task_v1_service, task_v2_service
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
from skyvern.webeye.schemas import BrowserSessionResponse from skyvern.webeye.schemas import BrowserSessionResponse
@ -445,7 +444,7 @@ async def get_runs(
@base_router.get( @base_router.get(
"/runs/{run_id}", "/runs/{run_id}",
tags=["agent"], tags=["agent"],
response_model=TaskRunResponse, response_model=RunResponse,
openapi_extra={ openapi_extra={
"x-fern-sdk-group-name": "agent", "x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "get_run", "x-fern-sdk-method-name": "get_run",
@ -453,13 +452,13 @@ async def get_runs(
) )
@base_router.get( @base_router.get(
"/runs/{run_id}/", "/runs/{run_id}/",
response_model=TaskRunResponse, response_model=RunResponse,
include_in_schema=False, include_in_schema=False,
) )
async def get_run( async def get_run(
run_id: str, run_id: str,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskRunResponse: ) -> RunResponse:
task_run_response = await task_run_service.get_task_run_response( task_run_response = await task_run_service.get_task_run_response(
run_id, organization_id=current_org.organization_id run_id, organization_id=current_org.organization_id
) )
@ -683,7 +682,7 @@ async def run_workflow(
version=version, version=version,
) )
await app.DATABASE.create_task_run( await app.DATABASE.create_task_run(
task_run_type=TaskRunType.workflow_run, task_run_type=RunType.workflow_run,
organization_id=current_org.organization_id, organization_id=current_org.organization_id,
run_id=workflow_run.workflow_run_id, run_id=workflow_run.workflow_run_id,
title=workflow.title, title=workflow.title,
@ -1512,7 +1511,7 @@ async def run_task(
run_request: TaskRunRequest, run_request: TaskRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None, x_api_key: Annotated[str | None, Header()] = None,
) -> TaskRunResponse: ) -> RunResponse:
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url}) analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
@ -1555,7 +1554,7 @@ async def run_task(
background_tasks=background_tasks, background_tasks=background_tasks,
) )
# build the task run response # build the task run response
return TaskRunResponse( return RunResponse(
run_id=task_v1_response.task_id, run_id=task_v1_response.task_id,
title=task_v1_response.title, title=task_v1_response.title,
status=str(task_v1_response.status), status=str(task_v1_response.status),
@ -1603,7 +1602,7 @@ async def run_task(
max_steps_override=run_request.max_steps, max_steps_override=run_request.max_steps,
browser_session_id=run_request.browser_session_id, browser_session_id=run_request.browser_session_id,
) )
return TaskRunResponse( return RunResponse(
run_id=task_v2.observer_cruise_id, run_id=task_v2.observer_cruise_id,
title=run_request.title, title=run_request.title,
status=str(task_v2.status), status=str(task_v2.status),

View file

@ -1,20 +1,15 @@
from datetime import datetime from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from skyvern.schemas.runs import RunType
class TaskRunType(StrEnum):
task_v1 = "task_v1"
task_v2 = "task_v2"
workflow_run = "workflow_run"
class TaskRun(BaseModel): class TaskRun(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
task_run_id: str task_run_id: str
task_run_type: TaskRunType task_run_type: RunType
run_id: str run_id: str
organization_id: str | None = None organization_id: str | None = None
title: str | None = None title: str | None = None

View file

@ -1,23 +1,23 @@
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType from skyvern.forge.sdk.schemas.runs import TaskRun
from skyvern.schemas.runs import RunEngine, TaskRunResponse from skyvern.schemas.runs import RunEngine, RunResponse, RunType
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None: async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
return await app.DATABASE.get_task_run(run_id, organization_id=organization_id) return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)
async def get_task_run_response(run_id: str, organization_id: str | None = None) -> TaskRunResponse | None: async def get_task_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
task_run = await get_task_run(run_id, organization_id=organization_id) task_run = await get_task_run(run_id, organization_id=organization_id)
if not task_run: if not task_run:
return None return None
if task_run.task_run_type == TaskRunType.task_v1: if task_run.task_run_type == RunType.task_v1:
# fetch task v1 from db and transform to task run response # fetch task v1 from db and transform to task run response
task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id) task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id)
if not task_v1: if not task_v1:
return None return None
return TaskRunResponse( return RunResponse(
run_id=task_run.run_id, run_id=task_run.run_id,
engine=RunEngine.skyvern_v1, engine=RunEngine.skyvern_v1,
status=task_v1.status, status=task_v1.status,
@ -32,11 +32,11 @@ async def get_task_run_response(run_id: str, organization_id: str | None = None)
created_at=task_v1.created_at, created_at=task_v1.created_at,
modified_at=task_v1.modified_at, modified_at=task_v1.modified_at,
) )
elif task_run.task_run_type == TaskRunType.task_v2: elif task_run.task_run_type == RunType.task_v2:
task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id) task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id)
if not task_v2: if not task_v2:
return None return None
return TaskRunResponse( return RunResponse(
run_id=task_run.run_id, run_id=task_run.run_id,
engine=RunEngine.skyvern_v2, engine=RunEngine.skyvern_v2,
status=task_v2.status, status=task_v2.status,

View file

@ -86,6 +86,12 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None:
return None return None
class RunType(StrEnum):
task_v1 = "task_v1"
task_v2 = "task_v2"
workflow_run = "workflow_run"
class RunEngine(StrEnum): class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0" skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0" skyvern_v2 = "skyvern-2.0"
@ -101,12 +107,15 @@ class TaskRunStatus(StrEnum):
completed = "completed" completed = "completed"
canceled = "canceled" canceled = "canceled"
def is_final(self) -> bool:
return self in [self.failed, self.terminated, self.canceled, self.timed_out, self.completed]
class TaskRunRequest(BaseModel): class TaskRunRequest(BaseModel):
goal: str goal: str
url: str | None = None url: str | None = None
title: str | None = None title: str | None = None
engine: RunEngine = RunEngine.skyvern_v1 engine: RunEngine = RunEngine.skyvern_v2
proxy_location: ProxyLocation | None = None proxy_location: ProxyLocation | None = None
data_extraction_schema: dict | list | str | None = None data_extraction_schema: dict | list | str | None = None
error_code_mapping: dict[str, str] | None = None error_code_mapping: dict[str, str] | None = None
@ -126,7 +135,7 @@ class TaskRunRequest(BaseModel):
return validate_url(url) return validate_url(url)
class TaskRunResponse(BaseModel): class RunResponse(BaseModel):
run_id: str run_id: str
engine: RunEngine = RunEngine.skyvern_v1 engine: RunEngine = RunEngine.skyvern_v1
status: TaskRunStatus status: TaskRunStatus

View file

@ -12,8 +12,8 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest
from skyvern.schemas.runs import RunType
LOG = structlog.get_logger() LOG = structlog.get_logger()
@ -84,7 +84,7 @@ async def run_task(
created_task = await app.agent.create_task(task, organization.organization_id) created_task = await app.agent.create_task(task, organization.organization_id)
url_hash = generate_url_hash(task.url) url_hash = generate_url_hash(task.url)
await app.DATABASE.create_task_run( await app.DATABASE.create_task_run(
task_run_type=TaskRunType.task_v1, task_run_type=RunType.task_v1,
organization_id=organization.organization_id, organization_id=organization.organization_id,
run_id=created_task.task_id, run_id=created_task.task_id,
title=task.title, title=task.title,

View file

@ -20,7 +20,6 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline, WorkflowRunTimelineType from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline, WorkflowRunTimelineType
from skyvern.forge.sdk.workflow.models.block import ( from skyvern.forge.sdk.workflow.models.block import (
@ -53,7 +52,7 @@ from skyvern.forge.sdk.workflow.models.yaml import (
WorkflowCreateYAMLRequest, WorkflowCreateYAMLRequest,
WorkflowDefinitionYAML, WorkflowDefinitionYAML,
) )
from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.runs import ProxyLocation, RunType
from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
from skyvern.webeye.utils.page import SkyvernFrame from skyvern.webeye.utils.page import SkyvernFrame
@ -196,7 +195,7 @@ async def initialize_task_v2(
) )
if create_task_run: if create_task_run:
await app.DATABASE.create_task_run( await app.DATABASE.create_task_run(
task_run_type=TaskRunType.task_v2, task_run_type=RunType.task_v2,
organization_id=organization.organization_id, organization_id=organization.organization_id,
run_id=task_v2.observer_cruise_id, run_id=task_v2.observer_cruise_id,
title=new_workflow.title, title=new_workflow.title,