diff --git a/skyvern/forge/sdk/executor/async_executor.py b/skyvern/forge/sdk/executor/async_executor.py index 89412395..6be9ee80 100644 --- a/skyvern/forge/sdk/executor/async_executor.py +++ b/skyvern/forge/sdk/executor/async_executor.py @@ -3,11 +3,11 @@ import abc import structlog from fastapi import BackgroundTasks +from skyvern.exceptions import OrganizationNotFound from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext -from skyvern.forge.sdk.models import Organization -from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus +from skyvern.forge.sdk.schemas.tasks import TaskStatus LOG = structlog.get_logger() @@ -17,10 +17,11 @@ class AsyncExecutor(abc.ABC): async def execute_task( self, background_tasks: BackgroundTasks, - task: Task, - organization: Organization, + task_id: str, + organization_id: str, max_steps_override: int | None, api_key: str | None, + **kwargs: dict, ) -> None: pass @@ -28,11 +29,12 @@ class AsyncExecutor(abc.ABC): async def execute_workflow( self, background_tasks: BackgroundTasks, - organization: Organization, + organization_id: str, workflow_id: str, workflow_run_id: str, max_steps_override: int | None, api_key: str | None, + **kwargs: dict, ) -> None: pass @@ -41,28 +43,34 @@ class BackgroundTaskExecutor(AsyncExecutor): async def execute_task( self, background_tasks: BackgroundTasks, - task: Task, - organization: Organization, + task_id: str, + organization_id: str, max_steps_override: int | None, api_key: str | None, + **kwargs: dict, ) -> None: - LOG.info("Executing task using background task executor", task_id=task.task_id) + LOG.info("Executing task using background task executor", task_id=task_id) + + organization = await app.DATABASE.get_organization(organization_id) + if organization is None: + raise OrganizationNotFound(organization_id) + step = await app.DATABASE.create_step( - task.task_id, + task_id, order=0, retry_index=0, - organization_id=organization.organization_id, + organization_id=organization_id, ) task = await app.DATABASE.update_task( - task.task_id, + task_id, status=TaskStatus.running, - organization_id=organization.organization_id, + organization_id=organization_id, ) context: SkyvernContext = skyvern_context.ensure_context() context.task_id = task.task_id - context.organization_id = organization.organization_id + context.organization_id = organization_id context.max_steps_override = max_steps_override background_tasks.add_task( @@ -76,11 +84,12 @@ class BackgroundTaskExecutor(AsyncExecutor): async def execute_workflow( self, background_tasks: BackgroundTasks, - organization: Organization, + organization_id: str, workflow_id: str, workflow_run_id: str, max_steps_override: int | None, api_key: str | None, + **kwargs: dict, ) -> None: LOG.info("Executing workflow using background task executor", workflow_run_id=workflow_run_id) background_tasks.add_task( diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 69471c40..831d9b6e 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -96,8 +96,8 @@ async def create_agent_task( LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) await AsyncExecutorFactory.get_executor().execute_task( background_tasks=background_tasks, - task=created_task, - organization=current_org, + task_id=created_task.task_id, + organization_id=current_org.organization_id, max_steps_override=x_max_steps_override, api_key=x_api_key, ) @@ -422,7 +422,7 @@ async def execute_workflow( LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) await AsyncExecutorFactory.get_executor().execute_workflow( background_tasks=background_tasks, - organization=current_org, + organization_id=current_org.organization_id, workflow_id=workflow_id, workflow_run_id=workflow_run.workflow_run_id, max_steps_override=x_max_steps_override,