diff --git a/integrations/langchain/skyvern_langchain/agent.py b/integrations/langchain/skyvern_langchain/agent.py
index 0f9f2c3d..bea0364b 100644
--- a/integrations/langchain/skyvern_langchain/agent.py
+++ b/integrations/langchain/skyvern_langchain/agent.py
@@ -46,7 +46,7 @@ class RunTask(SkyvernTaskBaseTool):
         if url is not None:
             task_request.url = url
 
-        return await self.agent.run_task(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
+        return await self.agent.run_task_v1(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
 
     async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2:
         task_request = TaskV2Request(user_prompt=user_prompt, url=url)
@@ -72,7 +72,7 @@ class DispatchTask(SkyvernTaskBaseTool):
         if url is not None:
             task_request.url = url
 
-        return await self.agent.create_task(task_request=task_request)
+        return await self.agent.create_task_v1(task_request=task_request)
 
     async def _arun_task_v2(self, user_prompt: str, url: str | None = None) -> TaskV2:
         task_request = TaskV2Request(user_prompt=user_prompt, url=url)
diff --git a/integrations/llama_index/skyvern_llamaindex/agent.py b/integrations/llama_index/skyvern_llamaindex/agent.py
index c79ec3f4..1c6ecffd 100644
--- a/integrations/llama_index/skyvern_llamaindex/agent.py
+++ b/integrations/llama_index/skyvern_llamaindex/agent.py
@@ -104,7 +104,7 @@ class SkyvernTaskToolSpec(BaseToolSpec):
         if url is not None:
             task_request.url = url
 
-        return await self.agent.run_task(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
+        return await self.agent.run_task_v1(task_request=task_request, timeout_seconds=self.run_task_timeout_seconds)
 
     async def dispatch_task_v1(self, user_prompt: str, url: Optional[str] = None) -> CreateTaskResponse:
         task_generation = await self._generate_v1_task_request(user_prompt=user_prompt)
@@ -112,7 +112,7 @@ class SkyvernTaskToolSpec(BaseToolSpec):
         if url is not None:
             task_request.url = url
 
-        return await self.agent.create_task(task_request=task_request)
+        return await self.agent.create_task_v1(task_request=task_request)
 
     async def get_task_v1(self, task_id: str) -> TaskResponse | None:
         return await self.agent.get_task(task_id=task_id)
diff --git a/skyvern/agent/agent.py b/skyvern/agent/agent.py
index 144b97fa..ebcbebca 100644
--- a/skyvern/agent/agent.py
+++ b/skyvern/agent/agent.py
@@ -1,9 +1,16 @@
 import asyncio
+import os
+import subprocess
+from typing import Any, cast
 
 from dotenv import load_dotenv
 
+from skyvern.agent.client import SkyvernClient
+from skyvern.agent.constants import DEFAULT_AGENT_HEARTBEAT_INTERVAL, DEFAULT_AGENT_TIMEOUT
+from skyvern.config import settings
 from skyvern.forge import app
 from skyvern.forge.sdk.core import security, skyvern_context
+from skyvern.forge.sdk.core.hashing import generate_url_hash
 from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
 from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
 from skyvern.forge.sdk.schemas.organizations import Organization
@@ -11,14 +18,69 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Request, TaskV2Statu
 from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
 from skyvern.forge.sdk.services.org_auth_token_service import API_KEY_LIFETIME
 from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
-from skyvern.services import task_v2_service
+from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse
+from skyvern.services import run_service, task_v1_service, task_v2_service
 from skyvern.utils import migrate_db
 
 
 class SkyvernAgent:
-    def __init__(self) -> None:
-        load_dotenv(".env")
-        migrate_db()
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        cdp_url: str | None = None,
+        browser_path: str | None = None,
+        browser_type: str | None = None,
+    ) -> None:
+        """Set up the agent for local execution or as a client of a remote Skyvern API.
+
+        With neither base_url nor api_key, the agent runs locally: ".env" is loaded
+        and migrate_db() runs. With both (and no browser_path), a SkyvernClient is created.
+        If browser_path is given, that browser is started with remote debugging on
+        port 9222 and the settings are pointed at it via CDP.
+
+        Raises:
+            ValueError: unsupported browser_path, or only one of base_url/api_key given.
+            Exception: the browser cannot be started, or no browser type is configured.
+        """
+        self.client: SkyvernClient | None = None
+        if base_url is None and api_key is None:
+            # TODO: run at the root wherever the code is initiated
+            load_dotenv(".env")
+            migrate_db()
+            # TODO: will this change the already imported settings?
+            # TODO: maybe refresh the settings
+
+        self.cdp_url = cdp_url
+        if browser_path:
+            # TODO validate browser_path
+            # Supported browsers (matched by name in the path): Google Chrome, Brave Browser, Microsoft Edge
+            if "Chrome" in browser_path or "Brave" in browser_path or "Edge" in browser_path:
+                try:
+                    # Popen, not run(): run() would block here until the browser exits.
+                    self._browser_process = subprocess.Popen([browser_path, "--remote-debugging-port=9222"])
+                except OSError as e:
+                    raise Exception(f"Failed to open browser. browser_path: {browser_path}") from e
+
+                self.cdp_url = "http://127.0.0.1:9222"
+                settings.BROWSER_TYPE = "cdp-connect"
+                settings.BROWSER_REMOTE_DEBUGGING_URL = self.cdp_url
+            else:
+                raise ValueError(
+                    f"Unsupported browser or invalid path: {browser_path}. "
+                    "Here's a list of supported browsers Skyvern can connect to: Google Chrome, Brave Browser, Microsoft Edge."
+                )
+        elif base_url is None and api_key is None:
+            if not browser_type:
+                if "BROWSER_TYPE" not in os.environ:
+                    raise Exception("browser type is missing")
+                browser_type = os.environ["BROWSER_TYPE"]
+
+            settings.BROWSER_TYPE = browser_type
+        elif base_url and api_key:
+            self.client = SkyvernClient(base_url=base_url, api_key=api_key)
+        else:
+            raise ValueError("base_url and api_key must be both provided")
 
     async def _get_organization(self) -> Organization:
         organization = await app.DATABASE.get_organization_by_domain("skyvern.local")
@@ -41,7 +103,7 @@ class SkyvernAgent:
             )
         return organization
 
-    async def _run_task(self, organization: Organization, task: Task) -> None:
+    async def _run_task(self, organization: Organization, task: Task, max_steps: int | None = None) -> None:
         org_auth_token = await app.DATABASE.get_valid_org_auth_token(
             organization_id=organization.organization_id,
             token_type=OrganizationAuthTokenType.api,
@@ -58,13 +120,23 @@ class SkyvernAgent:
             status=TaskStatus.running,
             organization_id=organization.organization_id,
         )
+        try:
+            skyvern_context.set(
+                SkyvernContext(
+                    organization_id=organization.organization_id,
+                    task_id=task.task_id,
+                    max_steps_override=max_steps,
+                )
+            )
 
-        step, _, _ = await app.agent.execute_step(
-            organization=organization,
-            task=updated_task,
-            step=step,
-            api_key=org_auth_token.token if org_auth_token else None,
-        )
+            step, _, _ = await app.agent.execute_step(
+                organization=organization,
+                task=updated_task,
+                step=step,
+                api_key=org_auth_token.token if org_auth_token else None,
+            )
+        finally:
+            skyvern_context.reset()
 
     async def _run_task_v2(self, organization: Organization, task_v2: TaskV2) -> None:
         # mark task v2 as queued
@@ -85,22 +157,15 @@ class SkyvernAgent:
             task_v2_id=task_v2.observer_cruise_id,
         )
 
-    async def create_task(
+    async def create_task_v1(
         self,
         task_request: TaskRequest,
     ) -> CreateTaskResponse:
         organization = await self._get_organization()
 
         created_task = await app.agent.create_task(task_request, organization.organization_id)
-        skyvern_context.set(
-            SkyvernContext(
-                organization_id=organization.organization_id,
-                task_id=created_task.task_id,
-                max_steps_override=created_task.max_steps_per_run,
-            )
-        )
 
-        asyncio.create_task(self._run_task(organization, created_task))
+        asyncio.create_task(self._run_task(organization, created_task, max_steps=task_request.max_steps_per_run))
         return CreateTaskResponse(task_id=created_task.task_id)
 
     async def get_task(
@@ -138,12 +203,12 @@ class SkyvernAgent:
             task=task, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
         )
 
-    async def run_task(
+    async def run_task_v1(
         self,
         task_request: TaskRequest,
         timeout_seconds: int = 600,
     ) -> TaskResponse:
-        created_task = await self.create_task(task_request)
+        created_task = await self.create_task_v1(task_request)
 
         async with asyncio.timeout(timeout_seconds):
             while True:
@@ -187,3 +252,161 @@ class SkyvernAgent:
                 if refreshed_task_v2.status.is_final():
                     return refreshed_task_v2
                 await asyncio.sleep(1)
+
+    ############### officially supported interfaces ###############
+    async def get_run(self, run_id: str) -> RunResponse | None:
+        """Fetch a run by id: via run_service in local mode, otherwise via the remote client."""
+        if not self.client:
+            organization = await self._get_organization()
+            return await run_service.get_run_response(run_id, organization_id=organization.organization_id)
+
+        return await self.client.get_run(run_id)
+
+    async def run_task(
+        self,
+        prompt: str,
+        engine: RunEngine = RunEngine.skyvern_v1,
+        url: str | None = None,
+        webhook_url: str | None = None,
+        totp_identifier: str | None = None,
+        totp_url: str | None = None,
+        title: str | None = None,
+        error_code_mapping: dict[str, str] | None = None,
+        data_extraction_schema: dict[str, Any] | None = None,
+        proxy_location: ProxyLocation | None = None,
+        max_steps: int | None = None,
+        wait_for_completion: bool = True,
+        timeout: float = DEFAULT_AGENT_TIMEOUT,
+        browser_session_id: str | None = None,
+    ) -> TaskRunResponse:
+        """Run a task described by a natural-language prompt and return its run.
+
+        Without a remote client the task is executed in this process using the
+        selected engine. With a remote client the task is submitted through it and,
+        if wait_for_completion is set, polled until its status is final.
+
+        Raises:
+            ValueError: local mode with an engine other than skyvern_v1 or skyvern_v2.
+            TimeoutError: the remote run is not final within `timeout` seconds
+                (asyncio.timeout).
+        """
+        if not self.client:
+            if engine == RunEngine.skyvern_v1:
+                data_extraction_goal = None
+                navigation_goal = prompt
+                navigation_payload = None
+                organization = await self._get_organization()
+                if not url:
+                    task_generation = await task_v1_service.generate_task(
+                        user_prompt=prompt,
+                        organization=organization,
+                    )
+                    url = task_generation.url
+                    navigation_goal = task_generation.navigation_goal or prompt
+                    navigation_payload = task_generation.navigation_payload
+                    data_extraction_goal = task_generation.data_extraction_goal
+                    data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
+
+                task_request = TaskRequest(
+                    title=title,
+                    url=url,
+                    navigation_goal=navigation_goal,
+                    navigation_payload=navigation_payload,
+                    data_extraction_goal=data_extraction_goal,
+                    extracted_information_schema=data_extraction_schema,
+                    error_code_mapping=error_code_mapping,
+                    proxy_location=proxy_location,
+                    max_steps_per_run=max_steps,
+                )
+
+                if wait_for_completion:
+                    created_task = await app.agent.create_task(task_request, organization.organization_id)
+                    url_hash = generate_url_hash(task_request.url)
+                    await app.DATABASE.create_task_run(
+                        task_run_type=RunType.task_v1,
+                        organization_id=organization.organization_id,
+                        run_id=created_task.task_id,
+                        title=task_request.title,
+                        url=task_request.url,
+                        url_hash=url_hash,
+                    )
+                    try:
+                        await self._run_task(organization, created_task, max_steps=max_steps)
+                        run_obj = await self.get_run(run_id=created_task.task_id)
+                        return cast(TaskRunResponse, run_obj)
+                    except Exception:
+                        # TODO: better error handling and logging
+                        run_obj = await self.get_run(run_id=created_task.task_id)
+                        return cast(TaskRunResponse, run_obj)
+                else:
+                    create_task_resp = await self.create_task_v1(task_request)
+                    run_obj = await self.get_run(run_id=create_task_resp.task_id)
+                    return cast(TaskRunResponse, run_obj)
+            elif engine == RunEngine.skyvern_v2:
+                # initialize task v2
+                organization = await self._get_organization()
+
+                task_v2 = await task_v2_service.initialize_task_v2(
+                    organization=organization,
+                    user_prompt=prompt,
+                    user_url=url,
+                    totp_identifier=totp_identifier,
+                    totp_verification_url=totp_url,
+                    webhook_callback_url=webhook_url,
+                    proxy_location=proxy_location,
+                    publish_workflow=False,
+                    extracted_information_schema=data_extraction_schema,
+                    error_code_mapping=error_code_mapping,
+                    create_task_run=True,
+                )
+
+                if wait_for_completion:
+                    await self._run_task_v2(organization, task_v2)
+                    run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
+                    return cast(TaskRunResponse, run_obj)
+                else:
+                    asyncio.create_task(self._run_task_v2(organization, task_v2))
+                    run_obj = await self.get_run(run_id=task_v2.observer_cruise_id)
+                    return cast(TaskRunResponse, run_obj)
+            else:
+                raise ValueError("Local mode is not supported for this method")
+
+        task_run = await self.client.run_task(
+            prompt=prompt,
+            engine=engine,
+            url=url,
+            webhook_url=webhook_url,
+            totp_identifier=totp_identifier,
+            totp_url=totp_url,
+            title=title,
+            error_code_mapping=error_code_mapping,
+            proxy_location=proxy_location,
+            max_steps=max_steps,
+        )
+
+        if wait_for_completion:
+            async with asyncio.timeout(timeout):
+                while True:
+                    task_run = await self.client.get_run(task_run.run_id)
+                    if task_run.status.is_final():
+                        return task_run
+                    await asyncio.sleep(DEFAULT_AGENT_HEARTBEAT_INTERVAL)
+        return task_run
+
+    async def run_workflow(
+        self,
+        workflow_id: str,
+        parameters: dict[str, Any],
+        webhook_url: str | None = None,
+        totp_identifier: str | None = None,
+        totp_url: str | None = None,
+        title: str | None = None,
+        error_code_mapping: dict[str, str] | None = None,
+        proxy_location: ProxyLocation | None = None,
+        max_steps: int | None = None,
+        wait_for_completion: bool = True,
+        timeout: float = DEFAULT_AGENT_TIMEOUT,
+        browser_session_id: str | None = None,
+    ) -> None:
+        """Placeholder for running workflows; always raises NotImplementedError."""
+        raise NotImplementedError("Running workflows is currently not supported with skyvern SDK.")
diff --git a/skyvern/agent/constants.py b/skyvern/agent/constants.py
new file mode 100644
index 00000000..7a0e6ab3
--- /dev/null
+++ b/skyvern/agent/constants.py
@@ -0,0 +1,2 @@
+DEFAULT_AGENT_TIMEOUT = 1800  # 30 minutes
+DEFAULT_AGENT_HEARTBEAT_INTERVAL = 10  # 10 seconds