# Mirror of https://github.com/Skyvern-AI/skyvern.git (synced 2025-09-02 02:30:07 +00:00; 388 lines, 16 KiB, Python).
import asyncio
import logging
import os
import subprocess
from collections.abc import Coroutine
from typing import Any, cast

from dotenv import load_dotenv

from skyvern.agent.client import SkyvernClient
from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import security, skyvern_context
from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Status
from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
from skyvern.services import run_service, task_v1_service, task_v2_service
from skyvern.utils import migrate_db


class SkyvernAgent:
    """Run Skyvern tasks either in-process ("local mode") or through the API.

    Remote mode is selected when both ``base_url`` and ``api_key`` are given:
    calls are delegated to a ``SkyvernClient``. In every other valid
    configuration the agent works directly against ``app.DATABASE``,
    ``app.agent`` and the ``skyvern.services`` modules.
    """

    def __init__(
        self,
        base_url: str | None = None,
        api_key: str | None = None,
        cdp_url: str | None = None,
        browser_path: str | None = None,
        browser_type: str | None = None,
    ) -> None:
        """Configure the agent and, if requested, start a local browser.

        Args:
            base_url: API base URL; must be given together with ``api_key``.
            api_key: API key; must be given together with ``base_url``.
            cdp_url: Stored on ``self.cdp_url``; overwritten when a browser is
                launched from ``browser_path``.
            browser_path: Executable of a Chrome, Brave or Edge browser to
                launch with remote debugging on port 9222.
            browser_type: Value for ``settings.BROWSER_TYPE`` in local mode;
                falls back to the ``BROWSER_TYPE`` environment variable.

        Raises:
            ValueError: ``browser_path`` is not a recognised browser, or only
                one of ``base_url`` / ``api_key`` was provided.
            Exception: the browser could not be started, or no browser type
                could be determined in local mode.
        """
        # Every other method tests ``self.client`` to pick local vs. remote
        # mode, so it must exist even when no client is created.
        self.client: SkyvernClient | None = None
        # Not read anywhere in this class; kept so existing attribute access
        # on instances keeps working.
        self.skyvern_client: SkyvernClient | None = None
        # Strong references to fire-and-forget tasks: the event loop only
        # keeps weak references, so unreferenced tasks may be collected
        # before they finish.
        self._background_tasks: set[asyncio.Task[None]] = set()
        self._browser_process: subprocess.Popen[bytes] | None = None

        if base_url is None and api_key is None:
            # TODO: run at the root wherever the code is initiated
            load_dotenv(".env")
            migrate_db()
            # TODO: will this change the already imported settings?
            # TODO: maybe refresh the settings

        self.cdp_url = cdp_url
        if browser_path:
            # TODO validate browser_path
            # NOTE(review): the error message below also lists Firefox, but
            # only Chrome/Brave/Edge paths are accepted here.
            # NOTE(review): base_url/api_key are not turned into a client on
            # this branch - confirm that is intended.
            if "Chrome" in browser_path or "Brave" in browser_path or "Edge" in browser_path:
                launch_error = f"Failed to open browser. browser_path: {browser_path}"
                try:
                    # Popen (not run) so the constructor does not block until
                    # the browser window is closed.
                    browser_process = subprocess.Popen([browser_path, "--remote-debugging-port=9222"])
                except OSError as exc:
                    raise Exception(launch_error) from exc
                # A launcher that already exited with 0 is tolerated (it may
                # have handed off to a running instance); a non-zero exit is
                # a failed launch.
                if browser_process.poll() not in (None, 0):
                    raise Exception(launch_error)
                self._browser_process = browser_process

                self.cdp_url = "http://127.0.0.1:9222"
                settings.BROWSER_TYPE = "cdp-connect"
                settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url
            else:
                raise ValueError(
                    f"Unsupported browser or invalid path: {browser_path}. "
                    "Here's a list of supported browsers Skyvern can connect to: Google Chrome, Brave Browser, Microsoft Edge, Firefox."
                )
        elif base_url is None and api_key is None:
            if not browser_type:
                if "BROWSER_TYPE" not in os.environ:
                    raise Exception("browser type is missing")
                browser_type = os.environ["BROWSER_TYPE"]

            settings.BROWSER_TYPE = browser_type
        elif base_url and api_key:
            self.client = SkyvernClient(base_url=base_url, api_key=api_key)
        else:
            raise ValueError("base_url and api_key must be both provided")

    def _on_background_task_done(self, task: asyncio.Task[None]) -> None:
        """Drop a finished background task and log the exception it raised, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logging.getLogger(__name__).error("Background run failed", exc_info=error)

    def _spawn_background(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task[None]:
        """Schedule ``coro`` on the running loop and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    async def _get_organization(self) -> Organization:
        """Return the ``skyvern.local`` organization, creating it on first use.

        When the organization has to be created, an API auth token is
        generated and stored for it as well.
        """
        organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
        if organization:
            return organization

        organization = await app.DATABASE.create_organization(
            organization_name="Skyvern-local",
            domain="skyvern.local",
            max_steps_per_run=10,
            max_retries_per_step=3,
        )
        api_key = security.create_access_token(
            organization.organization_id,
            expires_delta=API_KEY_LIFETIME,
        )
        # generate OrganizationAuthToken
        await app.DATABASE.create_org_auth_token(
            organization_id=organization.organization_id,
            token=api_key,
            token_type=OrganizationAuthTokenType.api,
        )
        return organization

    async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None:
        """Execute a v1 task in-process.

        Creates the first step (order 0, retry 0), marks the task as running
        and hands it to ``app.agent.execute_step`` inside a ``SkyvernContext``
        that carries ``max_steps`` as ``max_steps_override``. The context is
        reset afterwards even if execution raises.
        """
        org_auth_token = await app.DATABASE.get_valid_org_auth_token(
            organization_id=organization.organization_id,
            token_type=OrganizationAuthTokenType.api,
        )

        step = await app.DATABASE.create_step(
            task.task_id,
            order=0,
            retry_index=0,
            organization_id=organization.organization_id,
        )
        updated_task = await app.DATABASE.update_task(
            task.task_id,
            status=TaskStatus.running,
            organization_id=organization.organization_id,
        )
        try:
            skyvern_context.set(
                SkyvernContext(
                    organization_id=organization.organization_id,
                    task_id=task.task_id,
                    max_steps_override=max_steps,
                )
            )

            step, _, _ = await app.agent.execute_step(
                organization=organization,
                task=updated_task,
                step=step,
                api_key=org_auth_token.token if org_auth_token else None,
            )
        finally:
            skyvern_context.reset()

    async def _run_task_v2(self, organization: Organization, task_v2: TaskV2) -> None:
        """Mark a v2 task and its workflow run as queued, then run it in-process."""
        # mark task v2 as queued
        await app.DATABASE.update_task_v2(
            task_v2_id=task_v2.observer_cruise_id,
            status=TaskV2Status.queued,
            organization_id=organization.organization_id,
        )

        assert task_v2.workflow_run_id
        await app.DATABASE.update_workflow_run(
            workflow_run_id=task_v2.workflow_run_id,
            status=WorkflowRunStatus.queued,
        )

        await task_v2_service.run_task_v2(
            organization=organization,
            task_v2_id=task_v2.observer_cruise_id,
        )

    async def create_task_v1(
        self,
        task_request: TaskRequest,
    ) -> CreateTaskResponse:
        """Create a v1 task and start it in the background without waiting.

        Returns:
            A response carrying the id of the created task.
        """
        organization = await self._get_organization()

        created_task = await app.agent.create_task(task_request, organization.organization_id)

        self._spawn_background(self._run_task(organization, created_task, max_steps=task_request.max_steps_per_run))
        return CreateTaskResponse(task_id=created_task.task_id)

    async def get_task(
        self,
        task_id: str,
    ) -> TaskResponse | None:
        """Build the response for a v1 task, or return ``None`` if it does not exist.

        For a failed task that has a failure reason, the reason is extended
        with one entry per action of the latest step whose last result was
        unsuccessful.
        """
        organization = await self._get_organization()
        task = await app.DATABASE.get_task(task_id, organization.organization_id)

        if task is None:
            return None

        latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization.organization_id)
        if not latest_step:
            return await app.agent.build_task_response(task=task)

        failure_reason: str | None = None
        if task.status == TaskStatus.failed and task.failure_reason:
            failure_reason = task.failure_reason
            step_output = latest_step.output
            if step_output is not None and step_output.actions_and_results is not None:
                failed_actions = [
                    f"{action.action_type} action failed."
                    for action, results in step_output.actions_and_results
                    if results and not results[-1].success
                ]
                if failed_actions:
                    failure_reason += "(Exceptions: " + str(failed_actions) + ")"

        return await app.agent.build_task_response(
            task=task, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
        )

    async def run_task_v1(
        self,
        task_request: TaskRequest,
        timeout_seconds: int = 600,
    ) -> TaskResponse:
        """Create a v1 task and poll once per second until it reaches a final status.

        Raises:
            TimeoutError: the task is not final within ``timeout_seconds``.
        """
        created_task = await self.create_task_v1(task_request)

        async with asyncio.timeout(timeout_seconds):
            while True:
                task_response = await self.get_task(created_task.task_id)
                assert task_response is not None
                if task_response.status.is_final():
                    return task_response
                await asyncio.sleep(1)

    async def observer_task_v_2(self, task_request: TaskV2Request) -> TaskV2:
        """Initialize a v2 task from ``task_request`` and start it in the background.

        Raises:
            Exception: the initialized task has no workflow run id.
        """
        organization = await self._get_organization()

        task_v2 = await task_v2_service.initialize_task_v2(
            organization=organization,
            user_prompt=task_request.user_prompt,
            user_url=str(task_request.url) if task_request.url else None,
            totp_identifier=task_request.totp_identifier,
            totp_verification_url=task_request.totp_verification_url,
            webhook_callback_url=task_request.webhook_callback_url,
            proxy_location=task_request.proxy_location,
            publish_workflow=task_request.publish_workflow,
        )

        if not task_v2.workflow_run_id:
            raise Exception("Task v2 missing workflow run id")

        self._spawn_background(self._run_task_v2(organization, task_v2))
        return task_v2

    async def get_observer_task_v_2(self, task_id: str) -> TaskV2 | None:
        """Fetch a v2 task of the local organization by id."""
        organization = await self._get_organization()
        return await app.DATABASE.get_task_v2(task_id, organization.organization_id)

    async def run_observer_task_v_2(self, task_request: TaskV2Request, timeout_seconds: int = 600) -> TaskV2:
        """Start a v2 task and poll once per second until it reaches a final status.

        Raises:
            TimeoutError: the task is not final within ``timeout_seconds``.
        """
        task_v2 = await self.observer_task_v_2(task_request)

        async with asyncio.timeout(timeout_seconds):
            while True:
                refreshed_task_v2 = await self.get_observer_task_v_2(task_v2.observer_cruise_id)
                assert refreshed_task_v2 is not None
                if refreshed_task_v2.status.is_final():
                    return refreshed_task_v2
                await asyncio.sleep(1)

    ############### officially supported interfaces ###############
    async def get_run(self, run_id: str) -> RunResponse | None:
        """Look up a run by id, locally via ``run_service`` or remotely via the client."""
        if not self.client:
            organization = await self._get_organization()
            return await run_service.get_run_response(run_id, organization_id=organization.organization_id)

        return await self.client.get_run(run_id)

    async def run_task(
        self,
        prompt: str,
        engine: RunEngine = RunEngine.skyvern_v1,
        url: str | None = None,
        webhook_url: str | None = None,
        totp_identifier: str | None = None,
        totp_url: str | None = None,
        title: str | None = None,
        error_code_mapping: dict[str, str] | None = None,
        data_extraction_schema: dict[str, Any] | None = None,
        proxy_location: ProxyLocation | None = None,
        max_steps: int | None = None,
        wait_for_completion: bool = True,
        timeout: float = DEFAULT_AGENT_TIMEOUT,
        browser_session_id: str | None = None,
    ) -> TaskRunResponse:
        """Run a task described by ``prompt`` and return its run response.

        Local mode (no client):
            * ``skyvern_v1``: if ``url`` is missing, the task definition is
              generated from the prompt. The task and its run record are
              created, then the task is either awaited or started in the
              background depending on ``wait_for_completion``. ``max_steps``
              is forwarded to the step executor. ``webhook_url``,
              ``totp_identifier`` and ``totp_url`` are not forwarded on this
              path.
            * ``skyvern_v2``: the task is initialized through
              ``task_v2_service`` and awaited or started in the background.
              ``title`` and ``max_steps`` are not forwarded on this path.
            * ``timeout`` is not applied in local mode.

        Remote mode: the request is sent through the client; with
        ``wait_for_completion`` the run is polled every
        ``DEFAULT_AGENT_HEARTBEAT_INTERVAL`` seconds until final, bounded by
        ``timeout``. ``data_extraction_schema`` is not forwarded on this path.

        ``browser_session_id`` is currently accepted but unused.

        Raises:
            ValueError: local mode with an engine other than v1/v2.
            TimeoutError: remote mode, run not final within ``timeout``.
        """
        if not self.client:
            if engine == RunEngine.skyvern_v1:
                data_extraction_goal = None
                navigation_goal = prompt
                navigation_payload = None
                organization = await self._get_organization()
                if not url:
                    task_generation = await task_v1_service.generate_task(
                        user_prompt=prompt,
                        organization=organization,
                    )
                    url = task_generation.url
                    navigation_goal = task_generation.navigation_goal or prompt
                    navigation_payload = task_generation.navigation_payload
                    data_extraction_goal = task_generation.data_extraction_goal
                    data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema

                task_request = TaskRequest(
                    title=title,
                    url=url,
                    navigation_goal=navigation_goal,
                    navigation_payload=navigation_payload,
                    data_extraction_goal=data_extraction_goal,
                    extracted_information_schema=data_extraction_schema,
                    error_code_mapping=error_code_mapping,
                    proxy_location=proxy_location,
                )

                # Create the task and its run record up front in both modes so
                # the run can be looked up whether or not we wait for it
                # (presumably get_run relies on this record - the v2 path
                # likewise always requests create_task_run=True).
                created_task = await app.agent.create_task(task_request, organization.organization_id)
                url_hash = generate_url_hash(task_request.url)
                await app.DATABASE.create_task_run(
                    task_run_type=RunType.task_v1,
                    organization_id=organization.organization_id,
                    run_id=created_task.task_id,
                    title=task_request.title,
                    url=task_request.url,
                    url_hash=url_hash,
                )

                if wait_for_completion:
                    try:
                        await self._run_task(organization, created_task, max_steps=max_steps)
                    except Exception:
                        # Best effort: the caller still gets the run's current
                        # state, but the failure is no longer silent.
                        logging.getLogger(__name__).exception(
                            "Local task run failed. task_id: %s", created_task.task_id
                        )
                else:
                    self._spawn_background(self._run_task(organization, created_task, max_steps=max_steps))

                run_obj = await self.get_run(run_id=created_task.task_id)
                return cast(TaskRunResponse, run_obj)

            if engine == RunEngine.skyvern_v2:
                # initialize task v2
                organization = await self._get_organization()

                task_v2 = await task_v2_service.initialize_task_v2(
                    organization=organization,
                    user_prompt=prompt,
                    user_url=url,
                    totp_identifier=totp_identifier,
                    totp_verification_url=totp_url,
                    webhook_callback_url=webhook_url,
                    proxy_location=proxy_location,
                    publish_workflow=False,
                    extracted_information_schema=data_extraction_schema,
                    error_code_mapping=error_code_mapping,
                    create_task_run=True,
                )

                if wait_for_completion:
                    await self._run_task_v2(organization, task_v2)
                else:
                    self._spawn_background(self._run_task_v2(organization, task_v2))

                run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
                return cast(TaskRunResponse, run_obj)

            raise ValueError("Local mode is not supported for this method")

        task_run = await self.client.run_task(
            prompt=prompt,
            engine=engine,
            url=url,
            webhook_url=webhook_url,
            totp_identifier=totp_identifier,
            totp_url=totp_url,
            title=title,
            error_code_mapping=error_code_mapping,
            proxy_location=proxy_location,
            max_steps=max_steps,
        )

        if wait_for_completion:
            async with asyncio.timeout(timeout):
                while True:
                    task_run = await self.client.get_run(task_run.run_id)
                    if task_run.status.is_final():
                        return task_run
                    await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
        return task_run

    async def run_workflow(
        self,
        workflow_id: str,
        parameters: dict[str, Any],
        webhook_url: str | None = None,
        totp_identifier: str | None = None,
        totp_url: str | None = None,
        title: str | None = None,
        error_code_mapping: dict[str, str] | None = None,
        proxy_location: ProxyLocation | None = None,
        max_steps: int | None = None,
        wait_for_completion: bool = True,
        timeout: float = DEFAULT_AGENT_TIMEOUT,
        browser_session_id: str | None = None,
    ) -> None:
        """Placeholder for running a workflow; not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")