Add user-defined error detection on task failure (#4974)

This commit is contained in:
LawyZheng 2026-03-04 15:38:27 +08:00 committed by GitHub
parent d87a229e2b
commit c6d62e3fa0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 1498 additions and 42 deletions

View file

@ -104,6 +104,7 @@ from skyvern.schemas.runs import CUA_ENGINES, RunEngine
from skyvern.schemas.steps import AgentStepOutput
from skyvern.services import run_service, service_utils
from skyvern.services.action_service import get_action_history
from skyvern.services.error_detection_service import detect_user_defined_errors_for_task
from skyvern.services.otp_service import (
extract_totp_from_navigation_inputs,
poll_otp_value,
@ -419,6 +420,7 @@ class ForgeAgent:
next_step: Step | None = None
detailed_output: DetailedAgentStepOutput | None = None
list_files_before: list[str] = []
browser_state: BrowserState | None = None
try:
if task.workflow_run_id:
list_files_before = list_files_in_directory(
@ -727,7 +729,7 @@ class ForgeAgent:
"Step cannot be executed, marking task as failed",
exc_info=True,
)
is_task_marked_as_failed = await self.fail_task(task, step, e.message)
is_task_marked_as_failed = await self.fail_task(task, step, e.message, browser_state)
if is_task_marked_as_failed:
await self.clean_up_task(
task=task,
@ -753,7 +755,7 @@ class ForgeAgent:
url=e.url,
)
failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}"
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason, browser_state)
if is_task_marked_as_failed:
await self.clean_up_task(
task=task,
@ -798,7 +800,7 @@ class ForgeAgent:
step_order=step.order,
step_retry=step.retry_index,
)
await self.fail_task(task, step, e.message)
await self.fail_task(task, step, e.message, browser_state)
await self.clean_up_task(
task=task,
last_step=step,
@ -819,6 +821,7 @@ class ForgeAgent:
step,
sfe.reason
or "Skyvern failed to load the website. This usually happens when the website is not properly designed, and crashes the browser as a result.",
browser_state,
)
await self.clean_up_task(
task=task,
@ -834,6 +837,7 @@ class ForgeAgent:
task,
step,
"The browser does not have a valid page for skyvern to operate. This may be due to the website being empty or the browser crashing.",
browser_state,
)
await self.clean_up_task(
task=task,
@ -848,7 +852,7 @@ class ForgeAgent:
failure_reason = get_user_facing_exception_message(e)
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason, browser_state)
if is_task_marked_as_failed:
await self.clean_up_task(
task=task,
@ -866,7 +870,9 @@ class ForgeAgent:
context.step_id = None
context.task_id = None
async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> bool:
async def fail_task(
self, task: Task, step: Step | None, reason: str | None, browser_state: BrowserState | None = None
) -> bool:
try:
if step is not None:
await self.update_step(
@ -874,11 +880,50 @@ class ForgeAgent:
status=StepStatus.failed,
)
# Update task status first
await self.update_task(
task,
status=TaskStatus.failed,
failure_reason=reason,
)
# Detect user-defined errors if error_code_mapping is provided
if task.error_code_mapping and step is not None:
LOG.info(
"Task has error_code_mapping, attempting to detect user-defined errors",
task_id=task.task_id,
step_id=step.step_id,
error_code_mapping=task.error_code_mapping,
)
try:
detected_errors = await detect_user_defined_errors_for_task(
task=task,
step=step,
browser_state=browser_state,
failure_reason=reason,
)
# Update task errors if any were detected
# Only pass new errors — update_task() appends to existing errors
if detected_errors:
new_errors = [error.model_dump() for error in detected_errors]
await app.DATABASE.update_task(
task_id=task.task_id,
organization_id=task.organization_id,
errors=new_errors,
)
LOG.info(
"Updated task with detected user-defined errors",
task_id=task.task_id,
error_codes=[e.error_code for e in detected_errors],
)
except Exception:
LOG.exception(
"Failed to detect or store user-defined errors during task failure",
task_id=task.task_id,
)
return True
except TaskAlreadyCanceled:
LOG.info(
@ -4028,7 +4073,7 @@ class ForgeAgent:
if browser_state is not None:
page = await browser_state.get_working_page()
failure_reason = await self.summary_failure_reason_for_max_retries(
failure_response = await self.summary_failure_reason_for_max_retries(
organization=organization,
task=task,
step=step,
@ -4036,13 +4081,23 @@ class ForgeAgent:
max_retries=max_retries_per_step,
)
# Only pass new errors — update_task() appends to existing errors in the DB
new_errors: list[dict[str, Any]] = [ReachMaxRetriesError().model_dump()]
if failure_response.errors:
new_errors.extend([error.model_dump() for error in failure_response.errors])
LOG.info(
"Detected user-defined errors for max retries failure",
task_id=task.task_id,
error_codes=[e.error_code for e in failure_response.errors],
)
await self.update_task(
task,
TaskStatus.failed,
failure_reason=(
f"Max retries per step ({max_retries_per_step}) exceeded. Possible failure reasons: {failure_reason}"
f"Max retries per step ({max_retries_per_step}) exceeded. Possible failure reasons: {failure_response.reasoning}"
),
errors=[ReachMaxRetriesError().model_dump()],
errors=new_errors,
)
return None
else:
@ -4180,7 +4235,7 @@ class ForgeAgent:
step: Step,
page: Page | None,
max_retries: int,
) -> str:
) -> MaxStepsReasonResponse:
html = ""
screenshots: list[bytes] = []
steps_results = []
@ -4231,18 +4286,26 @@ class ForgeAgent:
# If we detected LLM errors, return a clear message without calling the LLM
if llm_errors:
llm_error_details = "; ".join(llm_errors)
return (
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
f"This is typically caused by rate limiting, service outages, or resource exhaustion from the LLM provider. "
f"Error details: {llm_error_details}"
return MaxStepsReasonResponse(
page_info="",
reasoning=(
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
f"This is typically caused by rate limiting, service outages, or resource exhaustion from the LLM provider. "
f"Error details: {llm_error_details}"
),
errors=[],
)
# If multiple steps failed without producing any actions, it's likely an LLM error during action extraction
if steps_without_actions >= max_retries:
return (
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
f"This is typically caused by LLM service errors during action extraction, such as rate limiting, "
f"service outages, or resource exhaustion from the LLM provider. Please check the LLM service status and try again."
return MaxStepsReasonResponse(
page_info="",
reasoning=(
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
f"This is typically caused by LLM service errors during action extraction, such as rate limiting, "
f"service outages, or resource exhaustion from the LLM provider. Please check the LLM service status and try again."
),
errors=[],
)
if page is not None:
@ -4257,6 +4320,7 @@ class ForgeAgent:
steps=steps_results,
page_html=html,
max_retries=max_retries,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
local_datetime=datetime.now(skyvern_context.ensure_context().tz_info).isoformat(),
)
json_response = await app.SECONDARY_LLM_API_HANDLER(
@ -4265,26 +4329,42 @@ class ForgeAgent:
step=step,
prompt_name="summarize-max-retries-reason",
)
return json_response.get("reasoning", "")
return MaxStepsReasonResponse.model_validate(json_response)
except Exception:
LOG.warning("Failed to summarize the failure reason for max retries")
# Check if we have LLM errors even if the summarization failed
if llm_errors:
llm_error_details = "; ".join(llm_errors)
return (
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
f"Error details: {llm_error_details}"
return MaxStepsReasonResponse(
page_info="",
reasoning=(
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
f"Error details: {llm_error_details}"
),
errors=[],
)
# If multiple steps failed without actions during summarization failure, still report it
if steps_without_actions >= max_retries:
return (
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
f"This is typically caused by LLM service errors during action extraction."
return MaxStepsReasonResponse(
page_info="",
reasoning=(
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
f"This is typically caused by LLM service errors during action extraction."
),
errors=[],
)
if steps_results:
last_step_result = steps_results[-1]
return f"Retry Step {last_step_result['order']}: {last_step_result['actions_result']}"
return ""
return MaxStepsReasonResponse(
page_info="",
reasoning=f"Retry Step {last_step_result['order']}: {last_step_result['actions_result']}",
errors=[],
)
return MaxStepsReasonResponse(
page_info="",
reasoning="",
errors=[],
)
async def handle_completed_step(
self,

View file

@ -4,12 +4,22 @@ Make sure to ONLY return the JSON object in this format with no additional text
```json
{
"page_info": str, // Think step by step. Describe useful information from the page HTML related to the user goal.
"reasoning": str, // Think step by step. Summarize why the actions failed based on 'page_info', screenshots, user goal and the failed actions. Keep it short and to the point.
"reasoning": str, // Think step by step. Summarize why the actions failed based on 'page_info', screenshots, user goal and the failed actions. Keep it short and to the point.{% if error_code_mapping_str %}
"errors": array // A list of errors. This is used to surface any errors that match the current situation. If no error description suits the current situation on the screenshots or the action history, return an empty list. You are allowed to return multiple errors if there are multiple errors on the page.
[{
"error_code": str, // The error code from the user's error code list
"reasoning": str, // The reasoning behind the error. Be specific, referencing any user information and their fields in your reasoning. Keep the reasoning short and to the point.
"confidence_float": float // The confidence of the error. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
}]{% endif %}
}
```
User Goal:
{{ navigation_goal }}
{% if error_code_mapping_str %}
Use the error codes and their descriptions to surface user-defined errors. Do not return any error that's not defined by the user. User defined errors:
{{ error_code_mapping_str }}
{% endif %}
User Details:
{{ navigation_payload }}

View file

@ -1,4 +1,4 @@
You are here to help the user use the error codes and their descriptions to surface user-defined errors based on the screenshots, user goal, user details, action history{{ ", context" if reasoning else "" }} and the HTML elements.
You are here to help the user use the error codes and their descriptions to surface user-defined errors based on the screenshots (if provided), user goal, user details, action history{{ ", context" if reasoning else "" }} and the HTML elements.
Do not return any error that's not defined by the user.
Reply in JSON format with the following keys:

View file

@ -71,6 +71,7 @@ from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_request
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Status
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
@ -102,6 +103,7 @@ from skyvern.forge.sdk.workflow.models.parameter import (
)
from skyvern.schemas.runs import RunEngine
from skyvern.schemas.workflows import BlockResult, BlockStatus, BlockType, FileStorageType, FileType
from skyvern.services.error_detection_service import detect_user_defined_errors_for_task
from skyvern.utils.strings import generate_random_string
from skyvern.utils.templating import get_missing_variables
from skyvern.utils.token_counter import count_tokens
@ -705,6 +707,49 @@ class BaseTaskBlock(Block):
return order, retry + 1
async def _handle_task_failure_with_error_detection(
self,
task: Task,
step: Step,
browser_state: BrowserState | None,
failure_reason: str,
organization_id: str,
) -> None:
"""
Handle task failure by updating the task status and detecting user-defined errors.
This helper method consolidates the error detection logic that was previously
duplicated across multiple exception handlers in the execute method.
"""
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=organization_id,
failure_reason=failure_reason,
)
# Detect user-defined errors if error_code_mapping is provided
if self.error_code_mapping:
try:
detected_errors = await detect_user_defined_errors_for_task(
task=task,
step=step,
browser_state=browser_state,
failure_reason=failure_reason,
)
if detected_errors:
# Only pass new errors — update_task() appends to existing errors
new_errors = [error.model_dump() for error in detected_errors]
await app.DATABASE.update_task(
task_id=task.task_id,
organization_id=organization_id,
errors=new_errors,
)
except Exception:
LOG.exception(
"Failed to detect or store user-defined errors during task failure",
task_id=task.task_id,
)
async def execute(
self,
workflow_run_id: str,
@ -850,12 +895,12 @@ class BaseTaskBlock(Block):
task_id=task.task_id,
workflow_run_id=workflow_run_id,
)
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow_run.organization_id,
await self._handle_task_failure_with_error_detection(
task=task,
step=step,
browser_state=browser_state,
failure_reason=str(e),
organization_id=workflow_run.organization_id,
)
raise e
@ -902,11 +947,12 @@ class BaseTaskBlock(Block):
try:
await browser_state.navigate_to_url(page=working_page, url=self.url)
except Exception as e:
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow_run.organization_id,
await self._handle_task_failure_with_error_detection(
task=task,
step=step,
browser_state=browser_state,
failure_reason=str(e),
organization_id=workflow_run.organization_id,
)
raise e
@ -926,11 +972,12 @@ class BaseTaskBlock(Block):
)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception
await app.DATABASE.update_task(
task.task_id,
status=TaskStatus.failed,
organization_id=workflow_run.organization_id,
await self._handle_task_failure_with_error_detection(
task=task,
step=step,
browser_state=browser_state,
failure_reason=str(e),
organization_id=workflow_run.organization_id,
)
raise e
finally:

View file

@ -0,0 +1,257 @@
"""
Service for detecting user-defined errors when tasks fail.
This module provides a centralized error detection service that can be used
by both agent execution and script execution to detect user-defined errors
based on the current page state or failure context.
"""
import asyncio
import json
from datetime import datetime
import structlog
from playwright.async_api import Page
from skyvern.errors.errors import UserDefinedError
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.services.action_service import get_action_history
from skyvern.webeye.actions.handler import extract_user_defined_errors
from skyvern.webeye.browser_state import BrowserState
LOG = structlog.get_logger()
async def detect_user_defined_errors_for_task(
    task: Task,
    step: Step,
    browser_state: BrowserState | None = None,
    failure_reason: str | None = None,
) -> list[UserDefinedError]:
    """Detect user-defined errors for a failed task.

    Prefers full page-based detection (screenshots, HTML, action history) when a
    working page is reachable through *browser_state*; otherwise falls back to a
    degraded, context-only detection driven by *failure_reason*.

    Args:
        task: The task that failed.
        step: The last step executed.
        browser_state: Optional browser state (may be None in early failures).
        failure_reason: The reason for task failure (used when browser_state unavailable).

    Returns:
        The detected UserDefinedError objects; an empty list when detection is
        skipped or fails.
    """
    # Without a mapping there is nothing to match against — skip entirely.
    if not task.error_code_mapping:
        LOG.debug(
            "No error_code_mapping defined for task, skipping error detection",
            task_id=task.task_id,
            step_id=step.step_id,
        )
        return []

    try:
        page = await browser_state.get_working_page() if browser_state is not None else None
        if page is not None:
            LOG.info(
                "Using page-based error detection",
                task_id=task.task_id,
                step_id=step.step_id,
                url=page.url,
            )
            return await _detect_errors_from_page(task, step, page, browser_state, failure_reason)

        LOG.info(
            "Browser state or page not available, using context-based error detection",
            task_id=task.task_id,
            step_id=step.step_id,
            has_browser_state=browser_state is not None,
            has_failure_reason=failure_reason is not None,
        )
        return await _detect_errors_from_context(task, step, failure_reason)
    except asyncio.CancelledError:
        # Don't swallow cancellation - let it propagate
        raise
    except Exception:
        # Detection failure must never prevent the task from being failed.
        LOG.exception(
            "Failed to detect user-defined errors, continuing with task failure",
            task_id=task.task_id,
            step_id=step.step_id,
        )
        return []
async def _detect_errors_from_page(
    task: Task,
    step: Step,
    page: Page,
    browser_state: BrowserState,
    failure_reason: str | None,
) -> list[UserDefinedError]:
    """Detect errors with full page context (screenshots, HTML, action history).

    Re-scrapes the current page and delegates to the existing
    ``extract_user_defined_errors`` helper. Returns an empty list on any failure.
    """
    try:
        LOG.info(
            "Scraping page for error detection",
            task_id=task.task_id,
            step_id=step.step_id,
            url=page.url,
        )
        # Fresh scrape so detection sees the page exactly as it is at failure time.
        scraped = await browser_state.scrape_website(
            url=page.url,
            cleanup_element_tree=app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step),
            take_screenshots=True,
            draw_boxes=False,
        )
        if scraped is None:
            LOG.warning(
                "Failed to scrape page for error detection",
                task_id=task.task_id,
                step_id=step.step_id,
            )
            return []

        LOG.info(
            "Calling extract_user_defined_errors with full page context",
            task_id=task.task_id,
            step_id=step.step_id,
            error_code_mapping=task.error_code_mapping,
        )
        detected = await extract_user_defined_errors(
            task=task,
            step=step,
            scraped_page=scraped,
            reasoning=failure_reason,
        )
        LOG.info(
            "Detected user-defined errors from page",
            task_id=task.task_id,
            step_id=step.step_id,
            error_count=len(detected),
            errors=[e.error_code for e in detected],
        )
        return detected
    except Exception:
        LOG.exception(
            "Failed to detect errors from page",
            task_id=task.task_id,
            step_id=step.step_id,
        )
        return []
async def _detect_errors_from_context(
    task: Task,
    step: Step,
    failure_reason: str | None,
) -> list[UserDefinedError]:
    """Detect errors from the failure context alone, when no page is available.

    Used for early failures (e.g. navigation or browser-init failures) where
    screenshots and HTML cannot be captured: the prompt is built only from the
    user goal, payload, action history (if retrievable), and *failure_reason*.
    """
    try:
        # Resolve the timezone: prefer the skyvern context, fall back to local time.
        ctx = skyvern_context.current()
        tz_info = ctx.tz_info if ctx and ctx.tz_info else datetime.now().astimezone().tzinfo

        # Action history may still be retrievable without a page and adds useful context.
        action_history = []
        try:
            action_history = await get_action_history(task=task, current_step=step)
        except Exception:
            LOG.debug(
                "Could not retrieve action history for context-based detection",
                task_id=task.task_id,
                step_id=step.step_id,
            )

        # Degraded prompt: no screenshots and no HTML elements are supplied.
        prompt = prompt_engine.load_prompt(
            "surface-user-defined-errors",
            navigation_goal=task.navigation_goal or "",
            navigation_payload_str=json.dumps(task.navigation_payload or {}),
            elements=[],
            current_url="",
            action_history=json.dumps(action_history),
            error_code_mapping_str=json.dumps(task.error_code_mapping),
            local_datetime=datetime.now(tz_info).isoformat(),
            reasoning=failure_reason,
        )

        LOG.info(
            "Calling LLM to detect user-defined errors from context only",
            task_id=task.task_id,
            step_id=step.step_id,
            error_code_mapping=task.error_code_mapping,
            failure_reason=failure_reason,
        )
        json_response = await app.EXTRACTION_LLM_API_HANDLER(
            prompt=prompt,
            screenshots=[],  # No screenshots available
            step=step,
            prompt_name="surface-user-defined-errors",
        )

        # Validate each returned error independently so one malformed entry
        # doesn't discard the rest.
        detected: list[UserDefinedError] = []
        for raw_error in json_response.get("errors", []):
            try:
                detected.append(UserDefinedError.model_validate(raw_error))
            except Exception:
                LOG.warning(
                    "Failed to validate user-defined error",
                    task_id=task.task_id,
                    step_id=step.step_id,
                    error_dict=raw_error,
                    exc_info=True,
                )
        LOG.info(
            "Detected user-defined errors from context",
            task_id=task.task_id,
            step_id=step.step_id,
            error_count=len(detected),
            errors=[e.error_code for e in detected],
        )
        return detected
    except Exception:
        LOG.exception(
            "Failed to detect errors from context",
            task_id=task.task_id,
            step_id=step.step_id,
        )
        return []

View file

@ -34,7 +34,7 @@ HTMLTreeStr = str
class MaxStepsReasonResponse(BaseModel):
page_info: str
reasoning: str
errors: list[UserDefinedError]
errors: list[UserDefinedError] = []
def load_prompt_with_elements(

View file

@ -0,0 +1,363 @@
"""
Integration tests for error detection across different failure scenarios.
Tests the complete flow from task failure to error detection and storage.
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.errors.errors import ReachMaxRetriesError, UserDefinedError
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.models import StepStatus
from skyvern.utils.prompt_engine import MaxStepsReasonResponse
from tests.unit.helpers import make_organization, make_step, make_task
@pytest.fixture
def agent():
    """Provide a fresh ForgeAgent for each test."""
    forge_agent = ForgeAgent()
    return forge_agent
@pytest.fixture
def mock_browser_state():
    """Build a fully wired mock browser state with a working page and scraper."""
    page = MagicMock()
    page.url = "https://example.com/checkout"

    state = MagicMock()
    # Use AsyncMock so "await browser_state.get_working_page()" never hits a MagicMock
    state.get_working_page = AsyncMock(return_value=page)
    state.cleanup_element_tree = MagicMock()

    # Canned scrape result returned by scrape_website regardless of arguments.
    fake_scraped = MagicMock()
    fake_scraped.url = "https://example.com/checkout"
    fake_scraped.build_element_tree = MagicMock(
        return_value='<html><body><div class="error">Payment failed</div></body></html>'
    )
    fake_scraped.screenshots = [b"screenshot_data"]

    async def fake_scrape_website(**kwargs):
        return fake_scraped

    state.scrape_website = fake_scrape_website
    return state
def create_error_detection_mocks(detected_errors):
    """Return a patcher that replaces the module-level error-detection entry point.

    The patch swaps ``skyvern.forge.agent.detect_user_defined_errors_for_task``
    for an AsyncMock returning *detected_errors*.
    """
    patcher = patch(
        "skyvern.forge.agent.detect_user_defined_errors_for_task",
        new_callable=AsyncMock,
        return_value=detected_errors,
    )
    return patcher
@pytest.mark.asyncio
async def test_navigate_failure_with_error_detection(agent, mock_browser_state):
    """Test error detection when navigation fails.

    Verifies that fail_task, given a browser state, runs user-defined error
    detection and stores the detected errors via app.DATABASE.update_task.
    """
    now = datetime.now()
    organization = make_organization(now)
    task = make_task(
        now,
        organization,
        error_code_mapping={
            "page_not_found": "The requested page does not exist",
            "server_error": "Server is experiencing issues",
        },
    )
    step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    detected_errors = [
        UserDefinedError(error_code="page_not_found", reasoning="404 error page shown", confidence_float=0.95)
    ]
    # Patch order matters: agent methods first, then the detection entry point,
    # then the module-level app so DATABASE.update_task is a capturable AsyncMock.
    with patch.object(agent, "update_step", new_callable=AsyncMock):
        with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
            mock_update_task.return_value = task
            with create_error_detection_mocks(detected_errors):
                with patch("skyvern.forge.agent.app") as mock_app:
                    mock_app.DATABASE.update_task = AsyncMock()
                    # Simulate FailedToNavigateToUrl scenario
                    from skyvern.exceptions import FailedToNavigateToUrl

                    try:
                        raise FailedToNavigateToUrl(url=task.url, error_message="Navigation timeout")
                    except FailedToNavigateToUrl:
                        # Call fail_task as the exception handler would
                        result = await agent.fail_task(
                            task,
                            step,
                            f"Failed to navigate to URL. URL:{task.url}, Error:Navigation timeout",
                            mock_browser_state,
                        )
                    assert result is True
                    # Verify errors were stored
                    mock_app.DATABASE.update_task.assert_called_once()
                    call_kwargs = mock_app.DATABASE.update_task.call_args[1]
                    assert len(call_kwargs["errors"]) == 1
                    assert call_kwargs["errors"][0]["error_code"] == "page_not_found"
@pytest.mark.asyncio
async def test_max_retries_with_error_detection(agent, mock_browser_state):
    """Test error detection when max retries are exceeded.

    Verifies that handle_failed_step stores ReachMaxRetriesError first, followed
    by the user-defined errors returned from the failure-reason summary.
    """
    now = datetime.now()
    organization = make_organization(now).model_copy(update={"max_retries_per_step": 3})
    task = make_task(
        now,
        organization,
        error_code_mapping={
            "captcha_failed": "CAPTCHA verification failed",
            "rate_limited": "Too many requests",
        },
    )
    # retry_index=3 means the step has already exhausted its allowed retries.
    step = make_step(now, task, step_id="step-3", status=StepStatus.failed, order=1, retry_index=3, output=None)
    detected_errors = [
        UserDefinedError(error_code="rate_limited", reasoning="Rate limit message displayed", confidence_float=0.90)
    ]

    # Mock summary_failure_reason_for_max_retries to return MaxStepsReasonResponse with detected errors
    async def mock_summary(*args, **kwargs):
        return MaxStepsReasonResponse(
            page_info="",
            reasoning="Multiple retry failures",
            errors=detected_errors,
        )

    with patch("skyvern.forge.agent.app") as mock_app:
        mock_app.BROWSER_MANAGER.get_for_task.return_value = mock_browser_state
        mock_app.DATABASE.get_task_steps = AsyncMock(return_value=[step, step, step])
        mock_app.DATABASE.get_task = AsyncMock(return_value=task)
        mock_app.DATABASE.update_task = AsyncMock(return_value=task)
        # create_step is awaited in handle_failed_step retry branch; avoid MagicMock in await
        next_step = make_step(
            now,
            task,
            step_id="step-next",
            status=StepStatus.running,
            order=step.order,
            retry_index=step.retry_index + 1,
            output=None,
        )
        mock_app.DATABASE.create_step = AsyncMock(return_value=next_step)

        # Async mock that forwards to mock_app.DATABASE.update_task so we never await MagicMock inside real update_task
        async def mock_update_task(
            _self,
            task,
            status,
            extracted_information=None,
            failure_reason=None,
            webhook_failure_reason=None,
            errors=None,
        ):
            # Only forward the fields the caller actually set, mirroring the
            # real update_task's sparse-update behavior.
            updates = {}
            if status is not None:
                updates["status"] = status
            if failure_reason is not None:
                updates["failure_reason"] = failure_reason
            if errors is not None:
                updates["errors"] = errors
            return await mock_app.DATABASE.update_task(task.task_id, organization_id=task.organization_id, **updates)

        with patch.object(ForgeAgent, "summary_failure_reason_for_max_retries", mock_summary):
            with patch.object(ForgeAgent, "update_task", mock_update_task):
                result = await agent.handle_failed_step(organization, task, step)

        assert result is None  # No next step when max retries exceeded
        # Verify errors include both system and user-defined errors
        mock_app.DATABASE.update_task.assert_called_once()
        call_kwargs = mock_app.DATABASE.update_task.call_args[1]
        errors = call_kwargs["errors"]
        assert len(errors) == 2
        # First should be ReachMaxRetriesError
        assert errors[0]["error_code"] == ReachMaxRetriesError().error_code
        # Second should be detected user error
        assert errors[1]["error_code"] == "rate_limited"
@pytest.mark.asyncio
async def test_scraping_failure_with_error_detection(agent, mock_browser_state):
    """Test error detection when page scraping fails.

    Same flow as the navigation-failure test, but driven by a ScrapingFailed
    exception; detection should still run and persist the detected error.
    """
    now = datetime.now()
    organization = make_organization(now)
    task = make_task(
        now,
        organization,
        error_code_mapping={
            "login_required": "User must be logged in",
            "access_denied": "Access to resource denied",
        },
    )
    step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    detected_errors = [
        UserDefinedError(error_code="login_required", reasoning="Login prompt detected", confidence_float=0.85)
    ]
    with patch.object(agent, "update_step", new_callable=AsyncMock):
        with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
            mock_update_task.return_value = task
            with create_error_detection_mocks(detected_errors):
                with patch("skyvern.forge.agent.app") as mock_app:
                    mock_app.DATABASE.update_task = AsyncMock()
                    # Simulate ScrapingFailed scenario
                    from skyvern.exceptions import ScrapingFailed

                    try:
                        raise ScrapingFailed(reason="Failed to scrape page elements")
                    except ScrapingFailed as e:
                        result = await agent.fail_task(task, step, e.reason, mock_browser_state)
                    assert result is True
                    # Verify errors were stored
                    mock_app.DATABASE.update_task.assert_called_once()
                    call_kwargs = mock_app.DATABASE.update_task.call_args[1]
                    assert len(call_kwargs["errors"]) == 1
                    assert call_kwargs["errors"][0]["error_code"] == "login_required"
@pytest.mark.asyncio
async def test_multiple_failures_accumulate_errors(agent, mock_browser_state):
    """Test that errors accumulate across multiple failures.

    A task that already carries errors should have only the newly detected
    errors passed to update_task — the DB layer appends them to the existing list.
    """
    now = datetime.now()
    organization = make_organization(now)
    # Start with an existing error
    initial_errors = [{"error_code": "initial_error", "reasoning": "First error"}]
    task = make_task(
        now,
        organization,
        error_code_mapping={
            "payment_failed": "Payment declined",
            "address_invalid": "Invalid address",
        },
        errors=initial_errors,
    )
    step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    # First failure detects payment error
    first_detected = [UserDefinedError(error_code="payment_failed", reasoning="Card declined", confidence_float=0.92)]
    with patch.object(agent, "update_step", new_callable=AsyncMock):
        with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
            mock_update_task.return_value = task
            with create_error_detection_mocks(first_detected):
                with patch("skyvern.forge.agent.app") as mock_app:
                    mock_app.DATABASE.update_task = AsyncMock()
                    await agent.fail_task(task, step, "First failure", mock_browser_state)
                    # Only new errors are passed — DB handles appending to existing ones
                    call_kwargs = mock_app.DATABASE.update_task.call_args[1]
                    assert len(call_kwargs["errors"]) == 1
                    assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
@pytest.mark.asyncio
async def test_error_detection_with_workflow_task(agent, mock_browser_state):
    """Test error detection works for workflow tasks.

    Tasks attached to a workflow run should go through the same detection and
    storage path; the update call is keyed by task_id only, not workflow_run_id.
    """
    now = datetime.now()
    organization = make_organization(now)
    task = make_task(
        now,
        organization,
        workflow_run_id="wr-123",
        workflow_permanent_id="wp-456",
        error_code_mapping={
            "workflow_error": "Workflow-specific error",
        },
    )
    step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    detected_errors = [
        UserDefinedError(error_code="workflow_error", reasoning="Workflow condition not met", confidence_float=0.88)
    ]
    with patch.object(agent, "update_step", new_callable=AsyncMock):
        with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
            mock_update_task.return_value = task
            with create_error_detection_mocks(detected_errors):
                with patch("skyvern.forge.agent.app") as mock_app:
                    mock_app.DATABASE.update_task = AsyncMock()
                    result = await agent.fail_task(task, step, "Workflow task failed", mock_browser_state)
                    assert result is True
                    # Verify errors were stored for workflow task
                    mock_app.DATABASE.update_task.assert_called_once()
                    call_kwargs = mock_app.DATABASE.update_task.call_args[1]
                    assert call_kwargs["task_id"] == task.task_id
                    # workflow_run_id is not passed in the update call, only task_id and errors
                    assert "workflow_run_id" not in call_kwargs
@pytest.mark.asyncio
async def test_error_detection_performance_doesnt_block_failure(agent, mock_browser_state):
    """Test that slow error detection doesn't significantly delay task failure.

    Measures elapsed time with the monotonic ``time.perf_counter`` clock so the
    timing assertions cannot be broken by wall-clock adjustments (NTP, DST)
    during the test — ``time.time()`` is not monotonic and can jump.
    """
    # Function-scoped imports as in the sibling tests, but hoisted to the top
    # of the function instead of sitting mid-body.
    import asyncio
    import time

    now = datetime.now()
    organization = make_organization(now)
    task = make_task(
        now,
        organization,
        error_code_mapping={
            "timeout": "Operation timed out",
        },
    )
    step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    with patch.object(agent, "update_step", new_callable=AsyncMock):
        with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
            mock_update_task.return_value = task

            # Simulate slow error detection
            async def slow_detection(*args, **kwargs):
                await asyncio.sleep(0.1)  # Simulate some delay
                return [UserDefinedError(error_code="timeout", reasoning="Timeout detected", confidence_float=0.80)]

            with patch(
                "skyvern.forge.agent.detect_user_defined_errors_for_task",
                new_callable=AsyncMock,
            ) as mock_detect:
                mock_detect.side_effect = slow_detection
                with patch("skyvern.forge.agent.app") as mock_app:
                    mock_app.DATABASE.update_task = AsyncMock()
                    start_time = time.perf_counter()  # monotonic: immune to clock jumps
                    result = await agent.fail_task(task, step, "Task timeout", mock_browser_state)
                    elapsed = time.perf_counter() - start_time
                    # Should complete (error detection runs but doesn't block indefinitely)
                    assert result is True
                    # Should take at least 0.1s (the sleep time)
                    assert elapsed >= 0.1
                    # But not much more (no retry loops or hangs)
                    assert elapsed < 1.0

View file

@ -0,0 +1,346 @@
"""
Unit tests for error_detection_service module.
Tests the user-defined error detection functionality for failed tasks.
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.services.error_detection_service import detect_user_defined_errors_for_task
from tests.unit.helpers import make_organization, make_task
@pytest.fixture
def mock_browser_state():
    """Stubbed browser state with a reachable page and canned scrape output."""
    fake_page = MagicMock()
    fake_page.url = "https://example.com/checkout"

    fake_scraped_page = MagicMock()
    fake_scraped_page.url = "https://example.com/checkout"
    fake_scraped_page.build_element_tree = MagicMock(
        return_value='<html><body><div class="error">Payment failed</div></body></html>'
    )
    fake_scraped_page.screenshots = [b"screenshot_data"]

    async def _working_page():
        return fake_page

    async def _scrape(**kwargs):
        return fake_scraped_page

    state = MagicMock()
    state.get_working_page = _working_page
    state.cleanup_element_tree = MagicMock()
    state.scrape_website = _scrape
    return state
@pytest.fixture
def mock_step():
    """A single failed Step attached to task-123 / org-123."""
    timestamp = datetime.now()
    return Step(
        created_at=timestamp,
        modified_at=timestamp,
        task_id="task-123",
        step_id="step-456",
        status=StepStatus.failed,
        output=None,
        order=1,
        is_last=True,
        retry_index=0,
        organization_id="org-123",
    )
@pytest.mark.asyncio
async def test_detect_errors_with_valid_error_code_mapping(mock_browser_state, mock_step):
    """Detection surfaces handler errors when a mapping and page are available."""
    from skyvern.errors.errors import UserDefinedError

    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
            "out_of_stock": "Product is not available",
        },
    )
    stubbed_errors = [
        UserDefinedError(
            error_code="payment_failed",
            reasoning="The page shows a payment declined message",
            confidence_float=0.95,
        )
    ]
    # Stub the handler-level extraction so only the service wiring is exercised.
    with patch(
        "skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
    ) as extract_mock:
        extract_mock.return_value = stubbed_errors
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=mock_browser_state
        )

    assert len(found) == 1
    assert found[0].error_code == "payment_failed"
    assert found[0].reasoning == "The page shows a payment declined message"
    extract_mock.assert_called_once()
@pytest.mark.asyncio
async def test_detect_errors_no_error_code_mapping(mock_browser_state, mock_step):
    """With no error_code_mapping configured, detection is a no-op."""
    created = datetime.now()
    org = make_organization(created)
    plain_task = make_task(created, org, error_code_mapping=None)

    result = await detect_user_defined_errors_for_task(
        task=plain_task, step=mock_step, browser_state=mock_browser_state
    )

    assert result == []
@pytest.mark.asyncio
async def test_detect_errors_no_browser_state(mock_step):
    """Without a browser state, the context-only (LLM) fallback path is used."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
        },
    )
    llm_response = {
        "errors": [{"error_code": "payment_failed", "reasoning": "Navigation failed", "confidence_float": 0.80}]
    }
    # No browser_state at all -> service should ask the extraction LLM directly.
    with patch("skyvern.services.error_detection_service.app") as app_mock:
        app_mock.EXTRACTION_LLM_API_HANDLER = AsyncMock(return_value=llm_response)
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=None, failure_reason="Navigation timeout"
        )

    assert len(found) == 1
    assert found[0].error_code == "payment_failed"
@pytest.mark.asyncio
async def test_detect_errors_no_working_page(mock_step):
    """A browser state with no working page also falls back to the LLM-only path."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
        },
    )

    async def _no_page():
        return None

    # Browser state exists but cannot produce a working page.
    pageless_state = MagicMock()
    pageless_state.get_working_page = _no_page

    with patch("skyvern.services.error_detection_service.app") as app_mock:
        app_mock.EXTRACTION_LLM_API_HANDLER = AsyncMock(
            return_value={
                "errors": [{"error_code": "payment_failed", "reasoning": "Page unavailable", "confidence_float": 0.75}]
            }
        )
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=pageless_state, failure_reason="Page load failed"
        )

    assert len(found) == 1
    assert found[0].error_code == "payment_failed"
@pytest.mark.asyncio
async def test_detect_errors_scraping_fails(mock_step):
    """A scrape_website failure yields an empty result instead of raising."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
        },
    )

    working_page = MagicMock()
    working_page.url = "https://example.com"

    async def _page():
        return working_page

    async def _broken_scrape(**kwargs):
        raise Exception("Scraping failed")

    broken_state = MagicMock()
    broken_state.get_working_page = _page
    broken_state.scrape_website = _broken_scrape

    found = await detect_user_defined_errors_for_task(task=failing_task, step=mock_step, browser_state=broken_state)

    # A scrape failure must not propagate — detection just yields nothing.
    assert found == []
@pytest.mark.asyncio
async def test_detect_errors_llm_call_fails(mock_browser_state, mock_step):
    """An exception inside the extraction handler is swallowed, not re-raised."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
        },
    )
    # The handler blowing up must be absorbed by the service.
    with patch(
        "skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
    ) as extract_mock:
        extract_mock.side_effect = Exception("LLM call failed")
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=mock_browser_state
        )

    assert found == []
@pytest.mark.asyncio
async def test_detect_errors_multiple_errors(mock_browser_state, mock_step):
    """Every error returned by the handler is surfaced, in order."""
    from skyvern.errors.errors import UserDefinedError

    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
            "address_invalid": "Shipping address is invalid",
        },
    )
    stubbed_errors = [
        UserDefinedError(
            error_code="payment_failed", reasoning="Payment declined message visible", confidence_float=0.95
        ),
        UserDefinedError(
            error_code="address_invalid", reasoning="Address validation error shown", confidence_float=0.90
        ),
    ]
    with patch(
        "skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
    ) as extract_mock:
        extract_mock.return_value = stubbed_errors
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=mock_browser_state
        )

    assert len(found) == 2
    assert found[0].error_code == "payment_failed"
    assert found[1].error_code == "address_invalid"
@pytest.mark.asyncio
async def test_detect_errors_invalid_error_format(mock_browser_state, mock_step):
    """Malformed entries are filtered upstream; only valid errors come back."""
    from skyvern.errors.errors import UserDefinedError

    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
        },
    )
    # extract_user_defined_errors already drops invalid formats before returning.
    valid_only = [
        UserDefinedError(error_code="payment_failed", reasoning="Payment declined", confidence_float=0.90)
    ]
    with patch(
        "skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
    ) as extract_mock:
        extract_mock.return_value = valid_only
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=mock_browser_state
        )

    assert len(found) == 1
    assert found[0].error_code == "payment_failed"
@pytest.mark.asyncio
async def test_detect_errors_empty_llm_response(mock_browser_state, mock_step):
    """An empty handler result means no detected errors."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Credit card was declined",
        },
    )
    with patch(
        "skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
    ) as extract_mock:
        extract_mock.return_value = []
        found = await detect_user_defined_errors_for_task(
            task=failing_task, step=mock_step, browser_state=mock_browser_state
        )

    assert found == []

View file

@ -0,0 +1,353 @@
"""
Unit tests for fail_task error detection integration.
Tests the integration between ForgeAgent.fail_task() and the error detection service.
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.errors.errors import UserDefinedError
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.models import StepStatus
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from tests.unit.helpers import make_organization, make_step, make_task
@pytest.fixture
def agent():
    """A fresh ForgeAgent for each test."""
    return ForgeAgent()
@pytest.fixture
def mock_browser_state():
    """Stubbed browser state: a page exists, but scraping returns nothing."""
    error_page = MagicMock()
    error_page.url = "https://example.com/error"

    async def _page():
        return error_page

    async def _scrape(*args, **kwargs):
        return None

    state = MagicMock()
    state.get_working_page = _page
    state.scrape_website = _scrape
    return state
@pytest.mark.asyncio
async def test_fail_task_with_error_code_mapping_detects_errors(agent, mock_browser_state):
    """Happy path: mapping present -> detection runs and errors are persisted."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Payment was declined",
            "out_of_stock": "Product unavailable",
        },
    )
    failing_step = make_step(created, failing_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    stubbed_errors = [
        UserDefinedError(
            error_code="payment_failed", reasoning="Payment declined message shown on page", confidence_float=0.95
        )
    ]
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = failing_task
        detect_mock.return_value = stubbed_errors
        app_mock.DATABASE.update_task = AsyncMock()

        result = await agent.fail_task(failing_task, failing_step, "Task failed", mock_browser_state)

        assert result is True
        # Detection received the full failure context.
        detect_mock.assert_called_once_with(
            task=failing_task,
            step=failing_step,
            browser_state=mock_browser_state,
            failure_reason="Task failed",
        )
        # Detected errors were written back to the task record.
        app_mock.DATABASE.update_task.assert_called_once()
        persisted = app_mock.DATABASE.update_task.call_args[1]
        assert persisted["task_id"] == failing_task.task_id
        assert persisted["organization_id"] == failing_task.organization_id
        assert len(persisted["errors"]) == 1
        assert persisted["errors"][0]["error_code"] == "payment_failed"
@pytest.mark.asyncio
async def test_fail_task_without_error_code_mapping(agent, mock_browser_state):
    """No error_code_mapping -> detection and error persistence are both skipped."""
    created = datetime.now()
    org = make_organization(created)
    plain_task = make_task(created, org, error_code_mapping=None)
    failing_step = make_step(created, plain_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = plain_task
        app_mock.DATABASE.update_task = AsyncMock()

        result = await agent.fail_task(plain_task, failing_step, "Task failed", mock_browser_state)

        assert result is True
        # No mapping -> no detection and no DB write for errors.
        detect_mock.assert_not_called()
        app_mock.DATABASE.update_task.assert_not_called()
@pytest.mark.asyncio
async def test_fail_task_without_browser_state(agent):
    """browser_state=None is passed through; detection decides what to do."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Payment was declined",
        },
    )
    failing_step = make_step(created, failing_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = failing_task
        detect_mock.return_value = []
        app_mock.DATABASE.update_task = AsyncMock()

        # No browser_state at all — detection is still invoked and skips internally.
        result = await agent.fail_task(failing_task, failing_step, "Task failed", browser_state=None)

        assert result is True
        detect_mock.assert_called_once_with(
            task=failing_task,
            step=failing_step,
            browser_state=None,
            failure_reason="Task failed",
        )
@pytest.mark.asyncio
async def test_fail_task_without_step(agent, mock_browser_state):
    """step=None -> no step update and no detection, but the task still fails."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Payment was declined",
        },
    )
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock) as update_step_mock,
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = failing_task
        app_mock.DATABASE.update_task = AsyncMock()

        # step is required for detection, so passing None short-circuits it.
        result = await agent.fail_task(failing_task, None, "Task failed", mock_browser_state)

        assert result is True
        detect_mock.assert_not_called()
        update_step_mock.assert_not_called()
@pytest.mark.asyncio
async def test_fail_task_error_detection_fails_gracefully(agent, mock_browser_state):
    """A crash inside detection never blocks marking the task failed."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Payment was declined",
        },
    )
    failing_step = make_step(created, failing_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = failing_task
        # Detection blowing up must not abort the failure path.
        detect_mock.side_effect = Exception("Detection failed")
        app_mock.DATABASE.update_task = AsyncMock()

        result = await agent.fail_task(failing_task, failing_step, "Task failed", mock_browser_state)

        # The task is still marked failed despite the detection error.
        assert result is True
        update_task_mock.assert_called_once()
@pytest.mark.asyncio
async def test_fail_task_multiple_errors_detected(agent, mock_browser_state):
    """All newly detected errors are handed to the DB in one call."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Payment was declined",
            "address_invalid": "Address validation failed",
        },
        errors=[{"error_code": "existing_error", "reasoning": "Pre-existing error"}],
    )
    failing_step = make_step(created, failing_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    stubbed_errors = [
        UserDefinedError(error_code="payment_failed", reasoning="Payment declined", confidence_float=0.90),
        UserDefinedError(error_code="address_invalid", reasoning="Invalid shipping address", confidence_float=0.85),
    ]
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = failing_task
        detect_mock.return_value = stubbed_errors
        app_mock.DATABASE.update_task = AsyncMock()

        result = await agent.fail_task(failing_task, failing_step, "Task failed", mock_browser_state)

        assert result is True
        # Only the newly detected errors go to the DB; appending to existing ones is the DB's job.
        persisted = app_mock.DATABASE.update_task.call_args[1]
        assert len(persisted["errors"]) == 2
        assert persisted["errors"][0]["error_code"] == "payment_failed"
        assert persisted["errors"][1]["error_code"] == "address_invalid"
@pytest.mark.asyncio
async def test_fail_task_no_errors_detected(agent, mock_browser_state):
    """An empty detection result means no DB write for errors."""
    created = datetime.now()
    org = make_organization(created)
    failing_task = make_task(
        created,
        org,
        error_code_mapping={
            "payment_failed": "Payment was declined",
        },
    )
    failing_step = make_step(created, failing_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
        patch("skyvern.forge.agent.app") as app_mock,
    ):
        update_task_mock.return_value = failing_task
        detect_mock.return_value = []
        app_mock.DATABASE.update_task = AsyncMock()

        result = await agent.fail_task(failing_task, failing_step, "Task failed", mock_browser_state)

        assert result is True
        # Nothing detected -> nothing persisted.
        app_mock.DATABASE.update_task.assert_not_called()
@pytest.mark.asyncio
async def test_fail_task_with_task_already_canceled(agent, mock_browser_state):
    """TaskAlreadyCanceled short-circuits: fail_task returns False, no detection."""
    from skyvern.exceptions import TaskAlreadyCanceled

    created = datetime.now()
    org = make_organization(created)
    canceled_task = make_task(
        created,
        org,
        status=TaskStatus.canceled,
        error_code_mapping={
            "payment_failed": "Payment was declined",
        },
    )
    failing_step = make_step(created, canceled_task, step_id="step-1", status=StepStatus.running, order=1, output=None)
    with (
        patch.object(agent, "update_step", new_callable=AsyncMock),
        patch.object(agent, "update_task", new_callable=AsyncMock) as update_task_mock,
        patch(
            "skyvern.forge.agent.detect_user_defined_errors_for_task",
            new_callable=AsyncMock,
        ) as detect_mock,
    ):
        # The status update path reports the task is already canceled.
        update_task_mock.side_effect = TaskAlreadyCanceled("new_status", canceled_task.task_id)

        result = await agent.fail_task(canceled_task, failing_step, "Task failed", mock_browser_state)

        assert result is False
        detect_mock.assert_not_called()