Cleanup executor logic (#84)

This commit is contained in:
Kerem Yilmaz 2024-03-14 23:07:04 -07:00 committed by GitHub
parent ff4be0de9e
commit eda6e07d36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 17 deletions

View file

@ -3,11 +3,11 @@ import abc
import structlog import structlog
from fastapi import BackgroundTasks from fastapi import BackgroundTasks
from skyvern.exceptions import OrganizationNotFound
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Organization from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
LOG = structlog.get_logger() LOG = structlog.get_logger()
@ -17,10 +17,11 @@ class AsyncExecutor(abc.ABC):
async def execute_task( async def execute_task(
self, self,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
task: Task, task_id: str,
organization: Organization, organization_id: str,
max_steps_override: int | None, max_steps_override: int | None,
api_key: str | None, api_key: str | None,
**kwargs: dict,
) -> None: ) -> None:
pass pass
@ -28,11 +29,12 @@ class AsyncExecutor(abc.ABC):
async def execute_workflow( async def execute_workflow(
self, self,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
organization: Organization, organization_id: str,
workflow_id: str, workflow_id: str,
workflow_run_id: str, workflow_run_id: str,
max_steps_override: int | None, max_steps_override: int | None,
api_key: str | None, api_key: str | None,
**kwargs: dict,
) -> None: ) -> None:
pass pass
@ -41,28 +43,34 @@ class BackgroundTaskExecutor(AsyncExecutor):
async def execute_task( async def execute_task(
self, self,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
task: Task, task_id: str,
organization: Organization, organization_id: str,
max_steps_override: int | None, max_steps_override: int | None,
api_key: str | None, api_key: str | None,
**kwargs: dict,
) -> None: ) -> None:
LOG.info("Executing task using background task executor", task_id=task.task_id) LOG.info("Executing task using background task executor", task_id=task_id)
organization = await app.DATABASE.get_organization(organization_id)
if organization is None:
raise OrganizationNotFound(organization_id)
step = await app.DATABASE.create_step( step = await app.DATABASE.create_step(
task.task_id, task_id,
order=0, order=0,
retry_index=0, retry_index=0,
organization_id=organization.organization_id, organization_id=organization_id,
) )
task = await app.DATABASE.update_task( task = await app.DATABASE.update_task(
task.task_id, task_id,
status=TaskStatus.running, status=TaskStatus.running,
organization_id=organization.organization_id, organization_id=organization_id,
) )
context: SkyvernContext = skyvern_context.ensure_context() context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id context.task_id = task.task_id
context.organization_id = organization.organization_id context.organization_id = organization_id
context.max_steps_override = max_steps_override context.max_steps_override = max_steps_override
background_tasks.add_task( background_tasks.add_task(
@ -76,11 +84,12 @@ class BackgroundTaskExecutor(AsyncExecutor):
async def execute_workflow( async def execute_workflow(
self, self,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
organization: Organization, organization_id: str,
workflow_id: str, workflow_id: str,
workflow_run_id: str, workflow_run_id: str,
max_steps_override: int | None, max_steps_override: int | None,
api_key: str | None, api_key: str | None,
**kwargs: dict,
) -> None: ) -> None:
LOG.info("Executing workflow using background task executor", workflow_run_id=workflow_run_id) LOG.info("Executing workflow using background task executor", workflow_run_id=workflow_run_id)
background_tasks.add_task( background_tasks.add_task(

View file

@ -96,8 +96,8 @@ async def create_agent_task(
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_task( await AsyncExecutorFactory.get_executor().execute_task(
background_tasks=background_tasks, background_tasks=background_tasks,
task=created_task, task_id=created_task.task_id,
organization=current_org, organization_id=current_org.organization_id,
max_steps_override=x_max_steps_override, max_steps_override=x_max_steps_override,
api_key=x_api_key, api_key=x_api_key,
) )
@ -422,7 +422,7 @@ async def execute_workflow(
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override) LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
await AsyncExecutorFactory.get_executor().execute_workflow( await AsyncExecutorFactory.get_executor().execute_workflow(
background_tasks=background_tasks, background_tasks=background_tasks,
organization=current_org, organization_id=current_org.organization_id,
workflow_id=workflow_id, workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id, workflow_run_id=workflow_run.workflow_run_id,
max_steps_override=x_max_steps_override, max_steps_override=x_max_steps_override,