Introduce ActionFunction to make it easy to patch and do extra validations before step starts (#365)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-05-25 18:24:35 -07:00 committed by GitHub
parent d3d38e2647
commit 2b4829f87a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 76 additions and 36 deletions

View file

@ -262,3 +262,8 @@ class BitwardenLogoutError(BitwardenBaseError):
class UnknownElementTreeFormat(SkyvernException): class UnknownElementTreeFormat(SkyvernException):
def __init__(self, fmt: str) -> None: def __init__(self, fmt: str) -> None:
super().__init__(f"Unknown element tree format {fmt}") super().__init__(f"Unknown element tree format {fmt}")
class StepTerminationError(SkyvernException):
def __init__(self, step_id: str, reason: str) -> None:
super().__init__(f"Step {step_id} cannot be executed and task is terminated. Reason: {reason}")

View file

@ -17,6 +17,7 @@ from skyvern.exceptions import (
FailedToSendWebhook, FailedToSendWebhook,
InvalidWorkflowTaskURLState, InvalidWorkflowTaskURLState,
MissingBrowserStatePage, MissingBrowserStatePage,
StepTerminationError,
TaskNotFound, TaskNotFound,
) )
from skyvern.forge import app from skyvern.forge import app
@ -79,34 +80,6 @@ class ForgeAgent:
) )
self.async_operation_pool = AsyncOperationPool() self.async_operation_pool = AsyncOperationPool()
async def validate_step_execution(
self,
task: Task,
step: Step,
) -> None:
"""
Checks if the step can be executed.
:return: A tuple of whether the step can be executed and a list of reasons why it can't be executed.
"""
reasons = []
# can't execute if task status is not running
has_valid_task_status = task.status == TaskStatus.running
if not has_valid_task_status:
reasons.append(f"invalid_task_status:{task.status}")
# can't execute if the step is already running or completed
has_valid_step_status = step.status in [StepStatus.created, StepStatus.failed]
if not has_valid_step_status:
reasons.append(f"invalid_step_status:{step.status}")
# can't execute if the task has another step that is running
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
if not has_no_running_steps:
reasons.append(f"another_step_is_running_for_task:{task.task_id}")
can_execute = has_valid_task_status and has_valid_step_status and has_no_running_steps
if not can_execute:
raise Exception(f"Cannot execute step. Reasons: {reasons}, Step: {step}")
async def create_task_and_step_from_block( async def create_task_and_step_from_block(
self, self,
task_block: TaskBlock, task_block: TaskBlock,
@ -211,9 +184,7 @@ class ForgeAgent:
return task return task
def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None: def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None:
if not app.generate_async_operations: operations = app.AGENT_FUNCTION.generate_async_operations(organization, task, page)
return
operations = app.generate_async_operations(organization, task, page)
self.async_operation_pool.add_operations(task.task_id, operations) self.async_operation_pool.add_operations(task.task_id, operations)
async def execute_step( async def execute_step(
@ -229,7 +200,7 @@ class ForgeAgent:
detailed_output: DetailedAgentStepOutput | None = None detailed_output: DetailedAgentStepOutput | None = None
try: try:
# Check some conditions before executing the step, throw an exception if the step can't be executed # Check some conditions before executing the step, throw an exception if the step can't be executed
await self.validate_step_execution(task, step) await app.AGENT_FUNCTION.validate_step_execution(task, step)
( (
step, step,
browser_state, browser_state,
@ -323,6 +294,28 @@ class ForgeAgent:
return step, detailed_output, next_step return step, detailed_output, next_step
# TODO (kerem): Let's add other exceptions that we know about here as custom exceptions as well # TODO (kerem): Let's add other exceptions that we know about here as custom exceptions as well
except StepTerminationError as e:
LOG.error(
"Step cannot be executed. Task terminated",
task_id=task.task_id,
step_id=step.step_id,
)
await self.update_step(
step=step,
status=StepStatus.failed,
)
task = await self.update_task(
task,
status=TaskStatus.failed,
failure_reason=e.message,
)
await self.send_task_response(
task=task,
last_step=step,
api_key=api_key,
close_browser_on_completion=close_browser_on_completion,
)
return step, detailed_output, next_step
except FailedToSendWebhook: except FailedToSendWebhook:
LOG.exception( LOG.exception(
"Failed to send webhook", "Failed to send webhook",

View file

@ -0,0 +1,44 @@
from playwright.async_api import Page
from skyvern.forge import app
from skyvern.forge.async_operations import AsyncOperation
from skyvern.forge.sdk.models import Organization, Step, StepStatus
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
class AgentFunction:
async def validate_step_execution(
self,
task: Task,
step: Step,
) -> None:
"""
Checks if the step can be executed.
:return: A tuple of whether the step can be executed and a list of reasons why it can't be executed.
"""
reasons = []
# can't execute if task status is not running
has_valid_task_status = task.status == TaskStatus.running
if not has_valid_task_status:
reasons.append(f"invalid_task_status:{task.status}")
# can't execute if the step is already running or completed
has_valid_step_status = step.status in [StepStatus.created, StepStatus.failed]
if not has_valid_step_status:
reasons.append(f"invalid_step_status:{step.status}")
# can't execute if the task has another step that is running
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
if not has_no_running_steps:
reasons.append(f"another_step_is_running_for_task:{task.task_id}")
can_execute = has_valid_task_status and has_valid_step_status and has_no_running_steps
if not can_execute:
raise Exception(f"Cannot execute step. Reasons: {reasons}, Step: {step}")
def generate_async_operations(
self,
organization: Organization,
task: Task,
page: Page,
) -> list[AsyncOperation]:
return []

View file

@ -3,10 +3,9 @@ from typing import Awaitable, Callable
from ddtrace import tracer from ddtrace import tracer
from ddtrace.filters import FilterRequestsOnUrl from ddtrace.filters import FilterRequestsOnUrl
from fastapi import FastAPI from fastapi import FastAPI
from playwright.async_api import Page
from skyvern.forge.agent import ForgeAgent from skyvern.forge.agent import ForgeAgent
from skyvern.forge.async_operations import AsyncOperation from skyvern.forge.agent_functions import AgentFunction
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.manager import ArtifactManager from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
@ -14,7 +13,6 @@ from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
from skyvern.forge.sdk.forge_log import setup_logger from skyvern.forge.sdk.forge_log import setup_logger
from skyvern.forge.sdk.models import Organization from skyvern.forge.sdk.models import Organization
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
from skyvern.forge.sdk.workflow.service import WorkflowService from skyvern.forge.sdk.workflow.service import WorkflowService
@ -41,7 +39,7 @@ EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvi
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY) LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService() WORKFLOW_SERVICE = WorkflowService()
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None AGENT_FUNCTION = AgentFunction()
authentication_function: Callable[[str], Awaitable[Organization]] | None = None authentication_function: Callable[[str], Awaitable[Organization]] | None = None
setup_api_app: Callable[[FastAPI], None] | None = None setup_api_app: Callable[[FastAPI], None] | None = None