mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-02 02:30:07 +00:00
Introduce ActionFunction to make it easy to patch and do extra validations before step starts (#365)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
parent
d3d38e2647
commit
2b4829f87a
4 changed files with 76 additions and 36 deletions
|
@ -262,3 +262,8 @@ class BitwardenLogoutError(BitwardenBaseError):
|
||||||
class UnknownElementTreeFormat(SkyvernException):
|
class UnknownElementTreeFormat(SkyvernException):
|
||||||
def __init__(self, fmt: str) -> None:
|
def __init__(self, fmt: str) -> None:
|
||||||
super().__init__(f"Unknown element tree format {fmt}")
|
super().__init__(f"Unknown element tree format {fmt}")
|
||||||
|
|
||||||
|
|
||||||
|
class StepTerminationError(SkyvernException):
|
||||||
|
def __init__(self, step_id: str, reason: str) -> None:
|
||||||
|
super().__init__(f"Step {step_id} cannot be executed and task is terminated. Reason: {reason}")
|
||||||
|
|
|
@ -17,6 +17,7 @@ from skyvern.exceptions import (
|
||||||
FailedToSendWebhook,
|
FailedToSendWebhook,
|
||||||
InvalidWorkflowTaskURLState,
|
InvalidWorkflowTaskURLState,
|
||||||
MissingBrowserStatePage,
|
MissingBrowserStatePage,
|
||||||
|
StepTerminationError,
|
||||||
TaskNotFound,
|
TaskNotFound,
|
||||||
)
|
)
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
|
@ -79,34 +80,6 @@ class ForgeAgent:
|
||||||
)
|
)
|
||||||
self.async_operation_pool = AsyncOperationPool()
|
self.async_operation_pool = AsyncOperationPool()
|
||||||
|
|
||||||
async def validate_step_execution(
|
|
||||||
self,
|
|
||||||
task: Task,
|
|
||||||
step: Step,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Checks if the step can be executed.
|
|
||||||
:return: A tuple of whether the step can be executed and a list of reasons why it can't be executed.
|
|
||||||
"""
|
|
||||||
reasons = []
|
|
||||||
# can't execute if task status is not running
|
|
||||||
has_valid_task_status = task.status == TaskStatus.running
|
|
||||||
if not has_valid_task_status:
|
|
||||||
reasons.append(f"invalid_task_status:{task.status}")
|
|
||||||
# can't execute if the step is already running or completed
|
|
||||||
has_valid_step_status = step.status in [StepStatus.created, StepStatus.failed]
|
|
||||||
if not has_valid_step_status:
|
|
||||||
reasons.append(f"invalid_step_status:{step.status}")
|
|
||||||
# can't execute if the task has another step that is running
|
|
||||||
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
|
|
||||||
has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
|
|
||||||
if not has_no_running_steps:
|
|
||||||
reasons.append(f"another_step_is_running_for_task:{task.task_id}")
|
|
||||||
|
|
||||||
can_execute = has_valid_task_status and has_valid_step_status and has_no_running_steps
|
|
||||||
if not can_execute:
|
|
||||||
raise Exception(f"Cannot execute step. Reasons: {reasons}, Step: {step}")
|
|
||||||
|
|
||||||
async def create_task_and_step_from_block(
|
async def create_task_and_step_from_block(
|
||||||
self,
|
self,
|
||||||
task_block: TaskBlock,
|
task_block: TaskBlock,
|
||||||
|
@ -211,9 +184,7 @@ class ForgeAgent:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None:
|
def register_async_operations(self, organization: Organization, task: Task, page: Page) -> None:
|
||||||
if not app.generate_async_operations:
|
operations = app.AGENT_FUNCTION.generate_async_operations(organization, task, page)
|
||||||
return
|
|
||||||
operations = app.generate_async_operations(organization, task, page)
|
|
||||||
self.async_operation_pool.add_operations(task.task_id, operations)
|
self.async_operation_pool.add_operations(task.task_id, operations)
|
||||||
|
|
||||||
async def execute_step(
|
async def execute_step(
|
||||||
|
@ -229,7 +200,7 @@ class ForgeAgent:
|
||||||
detailed_output: DetailedAgentStepOutput | None = None
|
detailed_output: DetailedAgentStepOutput | None = None
|
||||||
try:
|
try:
|
||||||
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
||||||
await self.validate_step_execution(task, step)
|
await app.AGENT_FUNCTION.validate_step_execution(task, step)
|
||||||
(
|
(
|
||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
|
@ -323,6 +294,28 @@ class ForgeAgent:
|
||||||
|
|
||||||
return step, detailed_output, next_step
|
return step, detailed_output, next_step
|
||||||
# TODO (kerem): Let's add other exceptions that we know about here as custom exceptions as well
|
# TODO (kerem): Let's add other exceptions that we know about here as custom exceptions as well
|
||||||
|
except StepTerminationError as e:
|
||||||
|
LOG.error(
|
||||||
|
"Step cannot be executed. Task terminated",
|
||||||
|
task_id=task.task_id,
|
||||||
|
step_id=step.step_id,
|
||||||
|
)
|
||||||
|
await self.update_step(
|
||||||
|
step=step,
|
||||||
|
status=StepStatus.failed,
|
||||||
|
)
|
||||||
|
task = await self.update_task(
|
||||||
|
task,
|
||||||
|
status=TaskStatus.failed,
|
||||||
|
failure_reason=e.message,
|
||||||
|
)
|
||||||
|
await self.send_task_response(
|
||||||
|
task=task,
|
||||||
|
last_step=step,
|
||||||
|
api_key=api_key,
|
||||||
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
|
)
|
||||||
|
return step, detailed_output, next_step
|
||||||
except FailedToSendWebhook:
|
except FailedToSendWebhook:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"Failed to send webhook",
|
"Failed to send webhook",
|
||||||
|
|
44
skyvern/forge/agent_functions.py
Normal file
44
skyvern/forge/agent_functions.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
from playwright.async_api import Page
|
||||||
|
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.async_operations import AsyncOperation
|
||||||
|
from skyvern.forge.sdk.models import Organization, Step, StepStatus
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFunction:
|
||||||
|
async def validate_step_execution(
|
||||||
|
self,
|
||||||
|
task: Task,
|
||||||
|
step: Step,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Checks if the step can be executed.
|
||||||
|
:return: A tuple of whether the step can be executed and a list of reasons why it can't be executed.
|
||||||
|
"""
|
||||||
|
reasons = []
|
||||||
|
# can't execute if task status is not running
|
||||||
|
has_valid_task_status = task.status == TaskStatus.running
|
||||||
|
if not has_valid_task_status:
|
||||||
|
reasons.append(f"invalid_task_status:{task.status}")
|
||||||
|
# can't execute if the step is already running or completed
|
||||||
|
has_valid_step_status = step.status in [StepStatus.created, StepStatus.failed]
|
||||||
|
if not has_valid_step_status:
|
||||||
|
reasons.append(f"invalid_step_status:{step.status}")
|
||||||
|
# can't execute if the task has another step that is running
|
||||||
|
steps = await app.DATABASE.get_task_steps(task_id=task.task_id, organization_id=task.organization_id)
|
||||||
|
has_no_running_steps = not any(step.status == StepStatus.running for step in steps)
|
||||||
|
if not has_no_running_steps:
|
||||||
|
reasons.append(f"another_step_is_running_for_task:{task.task_id}")
|
||||||
|
|
||||||
|
can_execute = has_valid_task_status and has_valid_step_status and has_no_running_steps
|
||||||
|
if not can_execute:
|
||||||
|
raise Exception(f"Cannot execute step. Reasons: {reasons}, Step: {step}")
|
||||||
|
|
||||||
|
def generate_async_operations(
|
||||||
|
self,
|
||||||
|
organization: Organization,
|
||||||
|
task: Task,
|
||||||
|
page: Page,
|
||||||
|
) -> list[AsyncOperation]:
|
||||||
|
return []
|
|
@ -3,10 +3,9 @@ from typing import Awaitable, Callable
|
||||||
from ddtrace import tracer
|
from ddtrace import tracer
|
||||||
from ddtrace.filters import FilterRequestsOnUrl
|
from ddtrace.filters import FilterRequestsOnUrl
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from playwright.async_api import Page
|
|
||||||
|
|
||||||
from skyvern.forge.agent import ForgeAgent
|
from skyvern.forge.agent import ForgeAgent
|
||||||
from skyvern.forge.async_operations import AsyncOperation
|
from skyvern.forge.agent_functions import AgentFunction
|
||||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||||
|
@ -14,7 +13,6 @@ from skyvern.forge.sdk.db.client import AgentDB
|
||||||
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
||||||
from skyvern.forge.sdk.forge_log import setup_logger
|
from skyvern.forge.sdk.forge_log import setup_logger
|
||||||
from skyvern.forge.sdk.models import Organization
|
from skyvern.forge.sdk.models import Organization
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task
|
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
|
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
|
||||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||||
|
@ -41,7 +39,7 @@ EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvi
|
||||||
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
||||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||||
WORKFLOW_SERVICE = WorkflowService()
|
WORKFLOW_SERVICE = WorkflowService()
|
||||||
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
|
AGENT_FUNCTION = AgentFunction()
|
||||||
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
|
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
|
||||||
setup_api_app: Callable[[FastAPI], None] | None = None
|
setup_api_app: Callable[[FastAPI], None] | None = None
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue