Sync cloud skyvern to oss skyvern (#55)

This commit is contained in:
Kerem Yilmaz 2024-03-12 22:28:16 -07:00 committed by GitHub
parent 647ea2ac0f
commit 15d78d7b08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 554 additions and 163 deletions

View file

@ -13,7 +13,7 @@ from skyvern.exceptions import (
)
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import ContextManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
LOG = structlog.get_logger()
@ -33,8 +33,12 @@ class Block(BaseModel, abc.ABC):
def get_subclasses(cls) -> tuple[type["Block"], ...]:
return tuple(cls.__subclasses__())
@staticmethod
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
@abc.abstractmethod
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
pass
@abc.abstractmethod
@ -48,9 +52,12 @@ class TaskBlock(Block):
block_type: Literal[BlockType.TASK] = BlockType.TASK
url: str | None = None
title: str = "Untitled Task"
navigation_goal: str | None = None
data_extraction_goal: str | None = None
data_schema: dict[str, Any] | None = None
# error code to error description for the LLM
error_code_mapping: dict[str, str] | None = None
max_retries: int = 0
parameters: list[PARAMETER_TYPE] = []
@ -89,8 +96,8 @@ class TaskBlock(Block):
return order, retry + 1
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
task = None
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
current_retry = 0
# initial value for will_retry is True, so that the loop runs at least once
will_retry = True
@ -104,7 +111,7 @@ class TaskBlock(Block):
task_block=self,
workflow=workflow,
workflow_run=workflow_run,
context_manager=context_manager,
workflow_run_context=workflow_run_context,
task_order=task_order,
task_retry=task_retry,
)
@ -131,7 +138,18 @@ class TaskBlock(Block):
if self.url:
await browser_state.page.goto(self.url)
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
try:
await app.agent.execute_step(organization=organization, task=task, step=step, workflow_run=workflow_run)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow.organization_id,
failure_reason=str(e),
)
raise e
# Check task status
updated_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=workflow.organization_id)
if not updated_task:
@ -188,9 +206,9 @@ class ForLoopBlock(Block):
return context_parameters
def get_loop_over_parameter_values(self, context_manager: ContextManager) -> list[Any]:
def get_loop_over_parameter_values(self, workflow_run_context: WorkflowRunContext) -> list[Any]:
if isinstance(self.loop_over, WorkflowParameter):
parameter_value = context_manager.get_value(self.loop_over.key)
parameter_value = workflow_run_context.get_value(self.loop_over.key)
if isinstance(parameter_value, list):
return parameter_value
else:
@ -200,8 +218,9 @@ class ForLoopBlock(Block):
# TODO (kerem): Implement this for context parameters
raise NotImplementedError
async def execute(self, workflow_run_id: str, context_manager: ContextManager, **kwargs: dict) -> Any:
loop_over_values = self.get_loop_over_parameter_values(context_manager)
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
LOG.info(
f"Number of loop_over values: {len(loop_over_values)}",
block_type=self.block_type,
@ -211,8 +230,8 @@ class ForLoopBlock(Block):
for loop_over_value in loop_over_values:
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
context_manager.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id, context_manager=context_manager)
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
await self.loop_block.execute(workflow_run_id=workflow_run_id)
return None