support get_run, run_task in SkyvernAgent (#2049)

This commit is contained in:
Shuchang Zheng 2025-03-31 08:26:53 -07:00 committed by GitHub
parent 83ad2adabd
commit 3bcd7db2bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 227 additions and 26 deletions

View file

@ -46,7 +46,7 @@ class RunTask(SkyvernTaskBaseTool):
if url is not None:
task_request.url = url
return await self.agent.run_task(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
return await self.agent.run_task_v1(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2:
task_request = TaskV2Request(user_prompt=user_prompt, url=url)
@ -72,7 +72,7 @@ class DispatchTask(SkyvernTaskBaseTool):
if url is not None:
task_request.url = url
return await self.agent.create_task(task_request=task_request)
return await self.agent.create_task_v1(task_request=task_request)
async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2:
task_request = TaskV2Request(user_prompt=user_prompt, url=url)

View file

@ -104,7 +104,7 @@ class SkyvernTaskToolSpec(BaseToolSpec):
if url is not None:
task_request.url = url
return await self.agent.run_task(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
return await self.agent.run_task_v1(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
async def dispatch_task_v1(self, user_prompt: str, url: Optional[str] = None) -> CreateTaskResponse:
task_generation = await self._generate_v1_task_request(user_prompt=user_prompt)
@ -112,7 +112,7 @@ class SkyvernTaskToolSpec(BaseToolSpec):
if url is not None:
task_request.url = url
return await self.agent.create_task(task_request=task_request)
return await self.agent.create_task_v1(task_request=task_request)
async def get_task_v1(self, task_id: str) -> TaskResponse | None:
    """Look up a v1 task by its id through the wrapped agent.

    Args:
        task_id: Identifier of the task to fetch.

    Returns:
        Whatever ``self.agent.get_task`` returns for ``task_id``
        (annotated as a ``TaskResponse`` or ``None``).
    """
    task_response = await self.agent.get_task(task_id=task_id)
    return task_response

View file

@ -1,9 +1,16 @@
import asyncio
import os
import subprocess
from typing import Any, cast
from dotenv import load_dotenv
from skyvern.agent.client import SkyvernClient
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context
from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.organizations import Organization
@ -11,14 +18,58 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.services import task_v2_service
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.utils import migrate_db
class SkyvernAgent:
def __init__(self) -> None:
load_dotenv(".env")
migrate_db()
def __init__(
    self,
    base_url: str | None = None,
    api_key: str | None = None,
    cdp_url: str | None = None,
    browser_path: str | None = None,
    browser_type: str | None = None,
) -> None:
    """Configure the agent for local or remote execution.

    Local mode is selected when neither ``base_url`` nor ``api_key`` is
    given: ``.env`` is loaded, database migrations are run and the browser
    settings are configured. Remote mode is selected when both are given:
    a ``SkyvernClient`` is created and stored on ``self.client``.

    Args:
        base_url: Base URL handed to ``SkyvernClient`` (remote mode).
        api_key: API key handed to ``SkyvernClient`` (remote mode).
        cdp_url: Stored on ``self.cdp_url``; overwritten with the local
            debugging endpoint when a browser is launched from
            ``browser_path``.
        browser_path: Path of a browser executable to launch with remote
            debugging enabled (local mode only). The path must contain
            "Chrome", "Brave" or "Edge".
        browser_type: Value written to ``settings.BROWSER_TYPE`` in local
            mode when no ``browser_path`` is given; falls back to the
            ``BROWSER_TYPE`` environment variable.

    Raises:
        ValueError: If only one of ``base_url`` / ``api_key`` is provided,
            or ``browser_path`` does not look like a supported browser.
        Exception: If the browser cannot be started, or no browser type can
            be determined in local mode.
    """
    # Every other method branches on `self.client`, so it must always exist;
    # it stays None in local mode.
    self.client: SkyvernClient | None = None
    self.cdp_url = cdp_url

    if base_url is None and api_key is None:
        # TODO: run at the root wherever the code is initiated
        load_dotenv(".env")
        migrate_db()
        # TODO: will this change the already imported settings?
        # TODO: maybe refresh the settings

        if browser_path:
            # TODO validate browser_path
            # Supported Browsers: Google Chrome, Brave Browser, Microsoft Edge, Firefox
            # NOTE(review): the message below advertises Firefox, but this check rejects it.
            if "Chrome" in browser_path or "Brave" in browser_path or "Edge" in browser_path:
                try:
                    # Popen (not run) so the constructor does not wait for the
                    # browser to exit; keep the handle so the process can be
                    # inspected or terminated later.
                    self.browser_process = subprocess.Popen([browser_path, "--remote-debugging-port=9222"])
                except OSError as err:
                    raise Exception(f"Failed to open browser. browser_path: {browser_path}") from err
                self.cdp_url = "http://127.0.0.1:9222"
                settings.BROWSER_TYPE = "cdp-connect"
                settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url
            else:
                raise ValueError(
                    f"Unsupported browser or invalid path: {browser_path}. "
                    "Here's a list of supported browsers Skyvern can connect to: Google Chrome, Brave Browser, Microsoft Edge, Firefox."
                )
        else:
            if not browser_type:
                if "BROWSER_TYPE" not in os.environ:
                    raise Exception("browser type is missing")
                browser_type = os.environ["BROWSER_TYPE"]
            settings.BROWSER_TYPE = browser_type
    elif base_url and api_key:
        self.client = SkyvernClient(base_url=base_url, api_key=api_key)
    else:
        raise ValueError("base_url and api_key must be both provided")
async def _get_organization(self) -> Organization:
organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
@ -41,7 +92,7 @@ class SkyvernAgent:
)
return organization
async def _run_task(self, organization: Organization, task: Task) -> None:
async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
@ -58,13 +109,23 @@ class SkyvernAgent:
status=TaskStatus.running,
organization_id=organization.organization_id,
)
try:
skyvern_context.set(
SkyvernContext(
organization_id=organization.organization_id,
task_id=task.task_id,
max_steps_override=max_steps,
)
)
step, _, _ = await app.agent.execute_step(
organization=organization,
task=updated_task,
step=step,
api_key=org_auth_token.token if org_auth_token else None,
)
step, _, _ = await app.agent.execute_step(
organization=organization,
task=updated_task,
step=step,
api_key=org_auth_token.token if org_auth_token else None,
)
finally:
skyvern_context.reset()
async def _run_task_v2(self, organization: Organization, task_v2: TaskV2) -> None:
# mark task v2 as queued
@ -85,22 +146,15 @@ class SkyvernAgent:
task_v2_id=task_v2.observer_cruise_id,
)
async def create_task(
async def create_task_v1(
self,
task_request: TaskRequest,
) -> CreateTaskResponse:
organization = await self._get_organization()
created_task = await app.agent.create_task(task_request, organization.organization_id)
skyvern_context.set(
SkyvernContext(
organization_id=organization.organization_id,
task_id=created_task.task_id,
max_steps_override=created_task.max_steps_per_run,
)
)
asyncio.create_task(self._run_task(organization, created_task))
asyncio.create_task(self._run_task(organization, created_task, max_steps=task_request.max_steps_per_run))
return CreateTaskResponse(task_id=created_task.task_id)
async def get_task(
@ -138,12 +192,12 @@ class SkyvernAgent:
task=task, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
)
async def run_task(
async def run_task_v1(
self,
task_request: TaskRequest,
timeout_seconds: int = 600,
) -> TaskResponse:
created_task = await self.create_task(task_request)
created_task = await self.create_task_v1(task_request)
async with asyncio.timeout(timeout_seconds):
while True:
@ -187,3 +241,148 @@ class SkyvernAgent:
if refreshed_task_v2.status.is_final():
return refreshed_task_v2
await asyncio.sleep(1)
############### officially supported interfaces ###############
async def get_run(self, run_id: str) -> RunResponse | None:
    """Fetch a run by id.

    When a client is configured the lookup is delegated to it; otherwise the
    run is resolved through ``run_service`` for the agent's organization.

    Args:
        run_id: Identifier of the run to fetch.

    Returns:
        The value returned by the client or by
        ``run_service.get_run_response`` (annotated as ``RunResponse`` or
        ``None``).
    """
    if self.client:
        return await self.client.get_run(run_id)

    org = await self._get_organization()
    local_run = await run_service.get_run_response(run_id, organization_id=org.organization_id)
    return local_run
async def run_task(
    self,
    prompt: str,
    engine: RunEngine = RunEngine.skyvern_v1,
    url: str | None = None,
    webhook_url: str | None = None,
    totp_identifier: str | None = None,
    totp_url: str | None = None,
    title: str | None = None,
    error_code_mapping: dict[str, str] | None = None,
    data_extraction_schema: dict[str, Any] | None = None,
    proxy_location: ProxyLocation | None = None,
    max_steps: int | None = None,
    wait_for_completion: bool = True,
    timeout: float = DEFAULT_AGENT_TIMEOUT,
    browser_session_id: str | None = None,
) -> TaskRunResponse:
    """Run a task from a natural-language prompt, locally or through the API.

    Without a configured client the task is executed in-process with the
    selected ``engine``; with a client the request is sent to the remote API
    and, if requested, polled until the run reaches a final status.

    Args:
        prompt: What the task should do. In local v1 mode it is used as the
            navigation goal unless a generated task supplies one.
        engine: Which engine executes the task.
        url: Starting URL. In local v1 mode, when omitted, a task (including
            its URL) is generated from ``prompt``.
        webhook_url: Forwarded to the v2 engine and to the remote API.
        totp_identifier: Forwarded to the v2 engine and to the remote API.
        totp_url: Forwarded to the v2 engine and to the remote API.
        title: Title recorded for the run (local v1 and remote).
        error_code_mapping: Forwarded to whichever backend runs the task.
        data_extraction_schema: Schema for extracted data (local modes).
        proxy_location: Forwarded to whichever backend runs the task.
        max_steps: Step limit, applied in local v1 mode and forwarded to the
            remote API.
        wait_for_completion: When True, return only after the run finished;
            otherwise return right after the run was started.
        timeout: Seconds to wait while polling the remote API. It is not
            applied to local runs.
        browser_session_id: Accepted for interface compatibility; this
            method does not currently use it.

    Returns:
        The run as reported by ``get_run`` (local) or by the client (remote).

    Raises:
        ValueError: If ``engine`` has no local implementation and no client
            is configured.
        TimeoutError: If a remote run does not finish within ``timeout``.
    """
    if not self.client:
        if engine == RunEngine.skyvern_v1:
            data_extraction_goal = None
            navigation_goal = prompt
            navigation_payload = None
            organization = await self._get_organization()
            if not url:
                # No URL given: let the task generator derive the full task from the prompt.
                task_generation = await task_v1_service.generate_task(
                    user_prompt=prompt,
                    organization=organization,
                )
                url = task_generation.url
                navigation_goal = task_generation.navigation_goal or prompt
                navigation_payload = task_generation.navigation_payload
                data_extraction_goal = task_generation.data_extraction_goal
                # An explicitly supplied schema wins over the generated one.
                data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema

            # NOTE(review): webhook_url / totp_* are not forwarded on this path.
            task_request = TaskRequest(
                title=title,
                url=url,
                navigation_goal=navigation_goal,
                navigation_payload=navigation_payload,
                data_extraction_goal=data_extraction_goal,
                extracted_information_schema=data_extraction_schema,
                error_code_mapping=error_code_mapping,
                proxy_location=proxy_location,
                # Carried on the request so the fire-and-forget path
                # (create_task_v1) picks the limit up as well.
                max_steps_per_run=max_steps,
            )

            if not wait_for_completion:
                create_task_resp = await self.create_task_v1(task_request)
                run_obj = await self.get_run(run_id=create_task_resp.task_id)
                return cast(TaskRunResponse, run_obj)

            created_task = await app.agent.create_task(task_request, organization.organization_id)
            url_hash = generate_url_hash(task_request.url)
            await app.DATABASE.create_task_run(
                task_run_type=RunType.task_v1,
                organization_id=organization.organization_id,
                run_id=created_task.task_id,
                title=task_request.title,
                url=task_request.url,
                url_hash=url_hash,
            )
            try:
                await self._run_task(organization, created_task, max_steps=max_steps)
            except Exception:
                # Best effort: the run record is returned below either way.
                # TODO: better error handling and logging
                pass
            run_obj = await self.get_run(run_id=created_task.task_id)
            return cast(TaskRunResponse, run_obj)

        if engine == RunEngine.skyvern_v2:
            # initialize task v2
            organization = await self._get_organization()
            # NOTE(review): max_steps is not forwarded to the v2 engine here.
            task_v2 = await task_v2_service.initialize_task_v2(
                organization=organization,
                user_prompt=prompt,
                user_url=url,
                totp_identifier=totp_identifier,
                totp_verification_url=totp_url,
                webhook_callback_url=webhook_url,
                proxy_location=proxy_location,
                publish_workflow=False,
                extracted_information_schema=data_extraction_schema,
                error_code_mapping=error_code_mapping,
                create_task_run=True,
            )

            if wait_for_completion:
                await self._run_task_v2(organization, task_v2)
            else:
                asyncio.create_task(self._run_task_v2(organization, task_v2))
            run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
            return cast(TaskRunResponse, run_obj)

        raise ValueError("Local mode is not supported for this method")

    # NOTE(review): data_extraction_schema is not forwarded to the remote API.
    task_run = await self.client.run_task(
        prompt=prompt,
        engine=engine,
        url=url,
        webhook_url=webhook_url,
        totp_identifier=totp_identifier,
        totp_url=totp_url,
        title=title,
        error_code_mapping=error_code_mapping,
        proxy_location=proxy_location,
        max_steps=max_steps,
    )

    if wait_for_completion:
        # Poll until the run reaches a final status or the timeout expires.
        async with asyncio.timeout(timeout):
            while True:
                task_run = await self.client.get_run(task_run.run_id)
                if task_run.status.is_final():
                    return task_run
                await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
    return task_run
async def run_workflow(
    self,
    workflow_id: str,
    parameters: dict[str, Any],
    webhook_url: str | None = None,
    totp_identifier: str | None = None,
    totp_url: str | None = None,
    title: str | None = None,
    error_code_mapping: dict[str, str] | None = None,
    proxy_location: ProxyLocation | None = None,
    max_steps: int | None = None,
    wait_for_completion: bool = True,
    timeout: float = DEFAULT_AGENT_TIMEOUT,
    browser_session_id: str | None = None,
) -> None:
    """Placeholder for running a workflow; not implemented yet.

    All arguments are accepted but ignored.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = "Running workflows is currently not supported with skyvern SDK."
    raise NotImplementedError(unsupported_message)

View file

@ -0,0 +1,2 @@
# Default agent timeout, in seconds (30 minutes).
DEFAULT_AGENT_TIMEOUT = 30 * 60
# Default agent heartbeat interval, in seconds.
DEFAULT_AGENT_HEARTBEAT_INTERVAL = 10