add observer task block (#1665)

This commit is contained in:
Shuchang Zheng 2025-01-28 16:59:54 +08:00 committed by GitHub
parent 1b79ef9ca3
commit 185fc330a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 224 additions and 22 deletions

View file

@ -48,7 +48,8 @@ from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core.validators import prepend_scheme_and_validate_url
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.schemas.observers import ObserverTaskStatus
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import (
FailedToFormatJinjaStyleParameter,
@ -71,6 +72,7 @@ LOG = structlog.get_logger()
class BlockType(StrEnum):
TASK = "task"
TaskV2 = "task_v2"
FOR_LOOP = "for_loop"
CODE = "code"
TEXT_PROMPT = "text_prompt"
@ -2072,6 +2074,80 @@ class UrlBlock(BaseTaskBlock):
url: str
# observer block
class TaskV2Block(Block):
block_type: Literal[BlockType.TaskV2] = BlockType.TaskV2
prompt: str
url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
max_iterations: int = 10
def get_all_parameters(
self,
workflow_run_id: str,
) -> list[PARAMETER_TYPE]:
return []
async def execute(
self,
workflow_run_id: str,
workflow_run_block_id: str,
organization_id: str | None = None,
browser_session_id: str | None = None,
**kwargs: dict,
) -> BlockResult:
from skyvern.forge.sdk.services import observer_service
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
if not organization_id:
raise ValueError("Running TaskV2Block requires organization_id")
organization = await app.DATABASE.get_organization(organization_id)
if not organization:
raise ValueError(f"Organization not found {organization_id}")
observer_task = await observer_service.initialize_observer_task(
organization,
user_prompt=self.prompt,
user_url=self.url,
parent_workflow_run_id=workflow_run_id,
proxy_location=ProxyLocation.NONE,
)
await app.DATABASE.update_observer_cruise(
observer_task.observer_cruise_id, status=ObserverTaskStatus.queued, organization_id=organization_id
)
if observer_task.workflow_run_id:
await app.DATABASE.update_workflow_run(
workflow_run_id=observer_task.workflow_run_id,
status=WorkflowRunStatus.queued,
)
await app.DATABASE.update_workflow_run_block(
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
block_workflow_run_id=observer_task.workflow_run_id,
)
observer_task = await observer_service.run_observer_task(
organization=organization,
observer_cruise_id=observer_task.observer_cruise_id,
request_id=None,
max_iterations_override=self.max_iterations,
browser_session_id=browser_session_id,
)
result_dict = None
if observer_task:
result_dict = observer_task.output
return await self.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=result_dict,
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
BlockSubclasses = Union[
ForLoopBlock,
TaskBlock,
@ -2090,5 +2166,6 @@ BlockSubclasses = Union[
WaitBlock,
FileDownloadBlock,
UrlBlock,
TaskV2Block,
]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]