mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-26 10:41:14 +00:00
Add user-defined error detection on task failure (#4974)
This commit is contained in:
parent
d87a229e2b
commit
c6d62e3fa0
9 changed files with 1498 additions and 42 deletions
|
|
@ -104,6 +104,7 @@ from skyvern.schemas.runs import CUA_ENGINES, RunEngine
|
|||
from skyvern.schemas.steps import AgentStepOutput
|
||||
from skyvern.services import run_service, service_utils
|
||||
from skyvern.services.action_service import get_action_history
|
||||
from skyvern.services.error_detection_service import detect_user_defined_errors_for_task
|
||||
from skyvern.services.otp_service import (
|
||||
extract_totp_from_navigation_inputs,
|
||||
poll_otp_value,
|
||||
|
|
@ -419,6 +420,7 @@ class ForgeAgent:
|
|||
next_step: Step | None = None
|
||||
detailed_output: DetailedAgentStepOutput | None = None
|
||||
list_files_before: list[str] = []
|
||||
browser_state: BrowserState | None = None
|
||||
try:
|
||||
if task.workflow_run_id:
|
||||
list_files_before = list_files_in_directory(
|
||||
|
|
@ -727,7 +729,7 @@ class ForgeAgent:
|
|||
"Step cannot be executed, marking task as failed",
|
||||
exc_info=True,
|
||||
)
|
||||
is_task_marked_as_failed = await self.fail_task(task, step, e.message)
|
||||
is_task_marked_as_failed = await self.fail_task(task, step, e.message, browser_state)
|
||||
if is_task_marked_as_failed:
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
|
|
@ -753,7 +755,7 @@ class ForgeAgent:
|
|||
url=e.url,
|
||||
)
|
||||
failure_reason = f"Failed to navigate to URL. URL:{e.url}, Error:{e.error_message}"
|
||||
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
|
||||
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason, browser_state)
|
||||
if is_task_marked_as_failed:
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
|
|
@ -798,7 +800,7 @@ class ForgeAgent:
|
|||
step_order=step.order,
|
||||
step_retry=step.retry_index,
|
||||
)
|
||||
await self.fail_task(task, step, e.message)
|
||||
await self.fail_task(task, step, e.message, browser_state)
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
last_step=step,
|
||||
|
|
@ -819,6 +821,7 @@ class ForgeAgent:
|
|||
step,
|
||||
sfe.reason
|
||||
or "Skyvern failed to load the website. This usually happens when the website is not properly designed, and crashes the browser as a result.",
|
||||
browser_state,
|
||||
)
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
|
|
@ -834,6 +837,7 @@ class ForgeAgent:
|
|||
task,
|
||||
step,
|
||||
"The browser does not have a valid page for skyvern to operate. This may be due to the website being empty or the browser crashing.",
|
||||
browser_state,
|
||||
)
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
|
|
@ -848,7 +852,7 @@ class ForgeAgent:
|
|||
|
||||
failure_reason = get_user_facing_exception_message(e)
|
||||
|
||||
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason)
|
||||
is_task_marked_as_failed = await self.fail_task(task, step, failure_reason, browser_state)
|
||||
if is_task_marked_as_failed:
|
||||
await self.clean_up_task(
|
||||
task=task,
|
||||
|
|
@ -866,7 +870,9 @@ class ForgeAgent:
|
|||
context.step_id = None
|
||||
context.task_id = None
|
||||
|
||||
async def fail_task(self, task: Task, step: Step | None, reason: str | None) -> bool:
|
||||
async def fail_task(
|
||||
self, task: Task, step: Step | None, reason: str | None, browser_state: BrowserState | None = None
|
||||
) -> bool:
|
||||
try:
|
||||
if step is not None:
|
||||
await self.update_step(
|
||||
|
|
@ -874,11 +880,50 @@ class ForgeAgent:
|
|||
status=StepStatus.failed,
|
||||
)
|
||||
|
||||
# Update task status first
|
||||
await self.update_task(
|
||||
task,
|
||||
status=TaskStatus.failed,
|
||||
failure_reason=reason,
|
||||
)
|
||||
|
||||
# Detect user-defined errors if error_code_mapping is provided
|
||||
if task.error_code_mapping and step is not None:
|
||||
LOG.info(
|
||||
"Task has error_code_mapping, attempting to detect user-defined errors",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_code_mapping=task.error_code_mapping,
|
||||
)
|
||||
|
||||
try:
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
failure_reason=reason,
|
||||
)
|
||||
|
||||
# Update task errors if any were detected
|
||||
# Only pass new errors — update_task() appends to existing errors
|
||||
if detected_errors:
|
||||
new_errors = [error.model_dump() for error in detected_errors]
|
||||
await app.DATABASE.update_task(
|
||||
task_id=task.task_id,
|
||||
organization_id=task.organization_id,
|
||||
errors=new_errors,
|
||||
)
|
||||
LOG.info(
|
||||
"Updated task with detected user-defined errors",
|
||||
task_id=task.task_id,
|
||||
error_codes=[e.error_code for e in detected_errors],
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to detect or store user-defined errors during task failure",
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
return True
|
||||
except TaskAlreadyCanceled:
|
||||
LOG.info(
|
||||
|
|
@ -4028,7 +4073,7 @@ class ForgeAgent:
|
|||
if browser_state is not None:
|
||||
page = await browser_state.get_working_page()
|
||||
|
||||
failure_reason = await self.summary_failure_reason_for_max_retries(
|
||||
failure_response = await self.summary_failure_reason_for_max_retries(
|
||||
organization=organization,
|
||||
task=task,
|
||||
step=step,
|
||||
|
|
@ -4036,13 +4081,23 @@ class ForgeAgent:
|
|||
max_retries=max_retries_per_step,
|
||||
)
|
||||
|
||||
# Only pass new errors — update_task() appends to existing errors in the DB
|
||||
new_errors: list[dict[str, Any]] = [ReachMaxRetriesError().model_dump()]
|
||||
if failure_response.errors:
|
||||
new_errors.extend([error.model_dump() for error in failure_response.errors])
|
||||
LOG.info(
|
||||
"Detected user-defined errors for max retries failure",
|
||||
task_id=task.task_id,
|
||||
error_codes=[e.error_code for e in failure_response.errors],
|
||||
)
|
||||
|
||||
await self.update_task(
|
||||
task,
|
||||
TaskStatus.failed,
|
||||
failure_reason=(
|
||||
f"Max retries per step ({max_retries_per_step}) exceeded. Possible failure reasons: {failure_reason}"
|
||||
f"Max retries per step ({max_retries_per_step}) exceeded. Possible failure reasons: {failure_response.reasoning}"
|
||||
),
|
||||
errors=[ReachMaxRetriesError().model_dump()],
|
||||
errors=new_errors,
|
||||
)
|
||||
return None
|
||||
else:
|
||||
|
|
@ -4180,7 +4235,7 @@ class ForgeAgent:
|
|||
step: Step,
|
||||
page: Page | None,
|
||||
max_retries: int,
|
||||
) -> str:
|
||||
) -> MaxStepsReasonResponse:
|
||||
html = ""
|
||||
screenshots: list[bytes] = []
|
||||
steps_results = []
|
||||
|
|
@ -4231,18 +4286,26 @@ class ForgeAgent:
|
|||
# If we detected LLM errors, return a clear message without calling the LLM
|
||||
if llm_errors:
|
||||
llm_error_details = "; ".join(llm_errors)
|
||||
return (
|
||||
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
|
||||
f"This is typically caused by rate limiting, service outages, or resource exhaustion from the LLM provider. "
|
||||
f"Error details: {llm_error_details}"
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning=(
|
||||
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
|
||||
f"This is typically caused by rate limiting, service outages, or resource exhaustion from the LLM provider. "
|
||||
f"Error details: {llm_error_details}"
|
||||
),
|
||||
errors=[],
|
||||
)
|
||||
|
||||
# If multiple steps failed without producing any actions, it's likely an LLM error during action extraction
|
||||
if steps_without_actions >= max_retries:
|
||||
return (
|
||||
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
|
||||
f"This is typically caused by LLM service errors during action extraction, such as rate limiting, "
|
||||
f"service outages, or resource exhaustion from the LLM provider. Please check the LLM service status and try again."
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning=(
|
||||
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
|
||||
f"This is typically caused by LLM service errors during action extraction, such as rate limiting, "
|
||||
f"service outages, or resource exhaustion from the LLM provider. Please check the LLM service status and try again."
|
||||
),
|
||||
errors=[],
|
||||
)
|
||||
|
||||
if page is not None:
|
||||
|
|
@ -4257,6 +4320,7 @@ class ForgeAgent:
|
|||
steps=steps_results,
|
||||
page_html=html,
|
||||
max_retries=max_retries,
|
||||
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
|
||||
local_datetime=datetime.now(skyvern_context.ensure_context().tz_info).isoformat(),
|
||||
)
|
||||
json_response = await app.SECONDARY_LLM_API_HANDLER(
|
||||
|
|
@ -4265,26 +4329,42 @@ class ForgeAgent:
|
|||
step=step,
|
||||
prompt_name="summarize-max-retries-reason",
|
||||
)
|
||||
return json_response.get("reasoning", "")
|
||||
return MaxStepsReasonResponse.model_validate(json_response)
|
||||
except Exception:
|
||||
LOG.warning("Failed to summarize the failure reason for max retries")
|
||||
# Check if we have LLM errors even if the summarization failed
|
||||
if llm_errors:
|
||||
llm_error_details = "; ".join(llm_errors)
|
||||
return (
|
||||
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
|
||||
f"Error details: {llm_error_details}"
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning=(
|
||||
f"The task failed due to LLM service errors. The LLM provider encountered errors and was unable to process the requests. "
|
||||
f"Error details: {llm_error_details}"
|
||||
),
|
||||
errors=[],
|
||||
)
|
||||
# If multiple steps failed without actions during summarization failure, still report it
|
||||
if steps_without_actions >= max_retries:
|
||||
return (
|
||||
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
|
||||
f"This is typically caused by LLM service errors during action extraction."
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning=(
|
||||
f"The task failed because all {max_retries} retry attempts failed to generate actions. "
|
||||
f"This is typically caused by LLM service errors during action extraction."
|
||||
),
|
||||
errors=[],
|
||||
)
|
||||
if steps_results:
|
||||
last_step_result = steps_results[-1]
|
||||
return f"Retry Step {last_step_result['order']}: {last_step_result['actions_result']}"
|
||||
return ""
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning=f"Retry Step {last_step_result['order']}: {last_step_result['actions_result']}",
|
||||
errors=[],
|
||||
)
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning="",
|
||||
errors=[],
|
||||
)
|
||||
|
||||
async def handle_completed_step(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,22 @@ Make sure to ONLY return the JSON object in this format with no additional text
|
|||
```json
|
||||
{
|
||||
"page_info": str, // Think step by step. Describe useful information from the page HTML related to the user goal.
|
||||
"reasoning": str, // Think step by step. Summarize why the actions failed based on 'page_info', screenshots, user goal and the failed actions. Keep it short and to the point.
|
||||
"reasoning": str, // Think step by step. Summarize why the actions failed based on 'page_info', screenshots, user goal and the failed actions. Keep it short and to the point.{% if error_code_mapping_str %}
|
||||
"errors": array // A list of errors. This is used to surface any errors that matches the current situation. If no error description suits the current situation on the screenshots or the action history, return an empty list. You are allowed to return multiple errors if there are multiple errors on the page.
|
||||
[{
|
||||
"error_code": str, // The error code from the user's error code list
|
||||
"reasoning": str, // The reasoning behind the error. Be specific, referencing any user information and their fields in your reasoning. Keep the reasoning short and to the point.
|
||||
"confidence_float": float // The confidence of the error. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
|
||||
}]{% endif %}
|
||||
}
|
||||
```
|
||||
|
||||
User Goal:
|
||||
{{ navigation_goal }}
|
||||
{% if error_code_mapping_str %}
|
||||
Use the error codes and their descriptions to surface user-defined errors. Do not return any error that's not defined by the user. User defined errors:
|
||||
{{ error_code_mapping_str }}
|
||||
{% endif %}
|
||||
|
||||
User Details:
|
||||
{{ navigation_payload }}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
You are here to help the user use the error codes and their descriptions to surface user-defined errors based on the screenshots, user goal, user details, action history{{ ", context" if reasoning else "" }} and the HTML elements.
|
||||
You are here to help the user use the error codes and their descriptions to surface user-defined errors based on the screenshots(if provided), user goal, user details, action history{{ ", context" if reasoning else "" }} and the HTML elements.
|
||||
Do not return any error that's not defined by the user.
|
||||
|
||||
Reply in JSON format with the following keys:
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_request
|
|||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Status
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
||||
|
|
@ -102,6 +103,7 @@ from skyvern.forge.sdk.workflow.models.parameter import (
|
|||
)
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
from skyvern.schemas.workflows import BlockResult, BlockStatus, BlockType, FileStorageType, FileType
|
||||
from skyvern.services.error_detection_service import detect_user_defined_errors_for_task
|
||||
from skyvern.utils.strings import generate_random_string
|
||||
from skyvern.utils.templating import get_missing_variables
|
||||
from skyvern.utils.token_counter import count_tokens
|
||||
|
|
@ -705,6 +707,49 @@ class BaseTaskBlock(Block):
|
|||
|
||||
return order, retry + 1
|
||||
|
||||
async def _handle_task_failure_with_error_detection(
|
||||
self,
|
||||
task: Task,
|
||||
step: Step,
|
||||
browser_state: BrowserState | None,
|
||||
failure_reason: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Handle task failure by updating the task status and detecting user-defined errors.
|
||||
|
||||
This helper method consolidates the error detection logic that was previously
|
||||
duplicated across multiple exception handlers in the execute method.
|
||||
"""
|
||||
await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
status=TaskStatus.failed,
|
||||
organization_id=organization_id,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
# Detect user-defined errors if error_code_mapping is provided
|
||||
if self.error_code_mapping:
|
||||
try:
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
if detected_errors:
|
||||
# Only pass new errors — update_task() appends to existing errors
|
||||
new_errors = [error.model_dump() for error in detected_errors]
|
||||
await app.DATABASE.update_task(
|
||||
task_id=task.task_id,
|
||||
organization_id=organization_id,
|
||||
errors=new_errors,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to detect or store user-defined errors during task failure",
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
|
@ -850,12 +895,12 @@ class BaseTaskBlock(Block):
|
|||
task_id=task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
# Make sure the task is marked as failed in the database before raising the exception
|
||||
await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
status=TaskStatus.failed,
|
||||
organization_id=workflow_run.organization_id,
|
||||
await self._handle_task_failure_with_error_detection(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
failure_reason=str(e),
|
||||
organization_id=workflow_run.organization_id,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
|
@ -902,11 +947,12 @@ class BaseTaskBlock(Block):
|
|||
try:
|
||||
await browser_state.navigate_to_url(page=working_page, url=self.url)
|
||||
except Exception as e:
|
||||
await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
status=TaskStatus.failed,
|
||||
organization_id=workflow_run.organization_id,
|
||||
await self._handle_task_failure_with_error_detection(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
failure_reason=str(e),
|
||||
organization_id=workflow_run.organization_id,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
|
@ -926,11 +972,12 @@ class BaseTaskBlock(Block):
|
|||
)
|
||||
except Exception as e:
|
||||
# Make sure the task is marked as failed in the database before raising the exception
|
||||
await app.DATABASE.update_task(
|
||||
task.task_id,
|
||||
status=TaskStatus.failed,
|
||||
organization_id=workflow_run.organization_id,
|
||||
await self._handle_task_failure_with_error_detection(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=browser_state,
|
||||
failure_reason=str(e),
|
||||
organization_id=workflow_run.organization_id,
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
|
|
|
|||
257
skyvern/services/error_detection_service.py
Normal file
257
skyvern/services/error_detection_service.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""
|
||||
Service for detecting user-defined errors when tasks fail.
|
||||
|
||||
This module provides a centralized error detection service that can be used
|
||||
by both agent execution and script execution to detect user-defined errors
|
||||
based on the current page state or failure context.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Page
|
||||
|
||||
from skyvern.errors.errors import UserDefinedError
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.services.action_service import get_action_history
|
||||
from skyvern.webeye.actions.handler import extract_user_defined_errors
|
||||
from skyvern.webeye.browser_state import BrowserState
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def detect_user_defined_errors_for_task(
|
||||
task: Task,
|
||||
step: Step,
|
||||
browser_state: BrowserState | None = None,
|
||||
failure_reason: str | None = None,
|
||||
) -> list[UserDefinedError]:
|
||||
"""
|
||||
Detect user-defined errors for a failed task.
|
||||
|
||||
This function uses the existing extract_user_defined_errors when browser_state
|
||||
and page are available. When they're not available (early failures), it falls
|
||||
back to detecting errors from the failure context.
|
||||
|
||||
Args:
|
||||
task: The task that failed
|
||||
step: The last step executed
|
||||
browser_state: Optional browser state (may be None in early failures)
|
||||
failure_reason: The reason for task failure (used when browser_state unavailable)
|
||||
|
||||
Returns:
|
||||
List of detected UserDefinedError objects (empty if detection fails)
|
||||
"""
|
||||
# Skip detection if no error_code_mapping defined
|
||||
if not task.error_code_mapping:
|
||||
LOG.debug(
|
||||
"No error_code_mapping defined for task, skipping error detection",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
# Try to use full page-based detection if browser state is available
|
||||
if browser_state is not None:
|
||||
page = await browser_state.get_working_page()
|
||||
if page is not None:
|
||||
LOG.info(
|
||||
"Using page-based error detection",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
url=page.url,
|
||||
)
|
||||
return await _detect_errors_from_page(task, step, page, browser_state, failure_reason)
|
||||
|
||||
# Fall back to context-based detection when page is not available
|
||||
LOG.info(
|
||||
"Browser state or page not available, using context-based error detection",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
has_browser_state=browser_state is not None,
|
||||
has_failure_reason=failure_reason is not None,
|
||||
)
|
||||
return await _detect_errors_from_context(task, step, failure_reason)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Don't swallow cancellation - let it propagate
|
||||
raise
|
||||
except Exception:
|
||||
# Gracefully handle any errors during detection
|
||||
# Error detection failure should never prevent task from failing
|
||||
LOG.exception(
|
||||
"Failed to detect user-defined errors, continuing with task failure",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
async def _detect_errors_from_page(
|
||||
task: Task,
|
||||
step: Step,
|
||||
page: Page,
|
||||
browser_state: BrowserState,
|
||||
failure_reason: str | None,
|
||||
) -> list[UserDefinedError]:
|
||||
"""
|
||||
Detect errors using full page context (screenshots, HTML, action history).
|
||||
|
||||
Reuses the existing extract_user_defined_errors function.
|
||||
"""
|
||||
try:
|
||||
# Scrape the current page
|
||||
LOG.info(
|
||||
"Scraping page for error detection",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
url=page.url,
|
||||
)
|
||||
scraped_page = await browser_state.scrape_website(
|
||||
url=page.url,
|
||||
cleanup_element_tree=app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step),
|
||||
take_screenshots=True,
|
||||
draw_boxes=False,
|
||||
)
|
||||
|
||||
if scraped_page is None:
|
||||
LOG.warning(
|
||||
"Failed to scrape page for error detection",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
return []
|
||||
|
||||
# Use the existing extract_user_defined_errors function
|
||||
LOG.info(
|
||||
"Calling extract_user_defined_errors with full page context",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_code_mapping=task.error_code_mapping,
|
||||
)
|
||||
user_defined_errors = await extract_user_defined_errors(
|
||||
task=task,
|
||||
step=step,
|
||||
scraped_page=scraped_page,
|
||||
reasoning=failure_reason,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"Detected user-defined errors from page",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_count=len(user_defined_errors),
|
||||
errors=[e.error_code for e in user_defined_errors],
|
||||
)
|
||||
|
||||
return user_defined_errors
|
||||
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to detect errors from page",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
async def _detect_errors_from_context(
|
||||
task: Task,
|
||||
step: Step,
|
||||
failure_reason: str | None,
|
||||
) -> list[UserDefinedError]:
|
||||
"""
|
||||
Detect errors using only failure context when page is not available.
|
||||
|
||||
This is used for early failures (e.g., navigation failures, browser init failures)
|
||||
where we don't have access to the page or screenshots.
|
||||
"""
|
||||
try:
|
||||
# Get current timezone
|
||||
context = skyvern_context.current()
|
||||
tz_info = datetime.now().astimezone().tzinfo
|
||||
if context and context.tz_info:
|
||||
tz_info = context.tz_info
|
||||
|
||||
# Try to get action history even without page - may provide useful context
|
||||
|
||||
action_history = []
|
||||
try:
|
||||
action_history = await get_action_history(task=task, current_step=step)
|
||||
except Exception:
|
||||
LOG.debug(
|
||||
"Could not retrieve action history for context-based detection",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
|
||||
# Build a degraded prompt with available context
|
||||
# Note: No screenshots, no HTML elements, but we try to include action history if available
|
||||
prompt = prompt_engine.load_prompt(
|
||||
"surface-user-defined-errors",
|
||||
navigation_goal=task.navigation_goal or "",
|
||||
navigation_payload_str=json.dumps(task.navigation_payload or {}),
|
||||
elements=[],
|
||||
current_url="",
|
||||
action_history=json.dumps(action_history),
|
||||
error_code_mapping_str=json.dumps(task.error_code_mapping),
|
||||
local_datetime=datetime.now(tz_info).isoformat(),
|
||||
reasoning=failure_reason,
|
||||
)
|
||||
|
||||
# Call LLM without screenshots
|
||||
LOG.info(
|
||||
"Calling LLM to detect user-defined errors from context only",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_code_mapping=task.error_code_mapping,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
json_response = await app.EXTRACTION_LLM_API_HANDLER(
|
||||
prompt=prompt,
|
||||
screenshots=[], # No screenshots available
|
||||
step=step,
|
||||
prompt_name="surface-user-defined-errors",
|
||||
)
|
||||
|
||||
# Parse and validate errors
|
||||
errors_list = json_response.get("errors", [])
|
||||
user_defined_errors = []
|
||||
|
||||
for error_dict in errors_list:
|
||||
try:
|
||||
user_defined_error = UserDefinedError.model_validate(error_dict)
|
||||
user_defined_errors.append(user_defined_error)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to validate user-defined error",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_dict=error_dict,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"Detected user-defined errors from context",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
error_count=len(user_defined_errors),
|
||||
errors=[e.error_code for e in user_defined_errors],
|
||||
)
|
||||
|
||||
return user_defined_errors
|
||||
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to detect errors from context",
|
||||
task_id=task.task_id,
|
||||
step_id=step.step_id,
|
||||
)
|
||||
return []
|
||||
|
|
@ -34,7 +34,7 @@ HTMLTreeStr = str
|
|||
class MaxStepsReasonResponse(BaseModel):
|
||||
page_info: str
|
||||
reasoning: str
|
||||
errors: list[UserDefinedError]
|
||||
errors: list[UserDefinedError] = []
|
||||
|
||||
|
||||
def load_prompt_with_elements(
|
||||
|
|
|
|||
363
tests/unit/test_error_detection_integration.py
Normal file
363
tests/unit/test_error_detection_integration.py
Normal file
|
|
@ -0,0 +1,363 @@
|
|||
"""
|
||||
Integration tests for error detection across different failure scenarios.
|
||||
|
||||
Tests the complete flow from task failure to error detection and storage.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.errors.errors import ReachMaxRetriesError, UserDefinedError
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.models import StepStatus
|
||||
from skyvern.utils.prompt_engine import MaxStepsReasonResponse
|
||||
from tests.unit.helpers import make_organization, make_step, make_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Create a ForgeAgent instance."""
|
||||
return ForgeAgent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_browser_state():
|
||||
"""Create a complete mock browser state."""
|
||||
browser_state = MagicMock()
|
||||
page = MagicMock()
|
||||
page.url = "https://example.com/checkout"
|
||||
|
||||
# Use AsyncMock so "await browser_state.get_working_page()" never hits a MagicMock
|
||||
browser_state.get_working_page = AsyncMock(return_value=page)
|
||||
browser_state.cleanup_element_tree = MagicMock()
|
||||
|
||||
# Mock scrape_website
|
||||
scraped_page = MagicMock()
|
||||
scraped_page.url = "https://example.com/checkout"
|
||||
scraped_page.build_element_tree = MagicMock(
|
||||
return_value='<html><body><div class="error">Payment failed</div></body></html>'
|
||||
)
|
||||
scraped_page.screenshots = [b"screenshot_data"]
|
||||
|
||||
async def scrape_website(**kwargs):
|
||||
return scraped_page
|
||||
|
||||
browser_state.scrape_website = scrape_website
|
||||
|
||||
return browser_state
|
||||
|
||||
|
||||
def create_error_detection_mocks(detected_errors):
|
||||
"""Helper to create standard error detection mocks."""
|
||||
# Mock the top-level detect_user_defined_errors_for_task function
|
||||
return patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
return_value=detected_errors,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_navigate_failure_with_error_detection(agent, mock_browser_state):
|
||||
"""Test error detection when navigation fails."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"page_not_found": "The requested page does not exist",
|
||||
"server_error": "Server is experiencing issues",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
detected_errors = [
|
||||
UserDefinedError(error_code="page_not_found", reasoning="404 error page shown", confidence_float=0.95)
|
||||
]
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with create_error_detection_mocks(detected_errors):
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
# Simulate FailedToNavigateToUrl scenario
|
||||
from skyvern.exceptions import FailedToNavigateToUrl
|
||||
|
||||
try:
|
||||
raise FailedToNavigateToUrl(url=task.url, error_message="Navigation timeout")
|
||||
except FailedToNavigateToUrl:
|
||||
# Call fail_task as the exception handler would
|
||||
result = await agent.fail_task(
|
||||
task,
|
||||
step,
|
||||
f"Failed to navigate to URL. URL:{task.url}, Error:Navigation timeout",
|
||||
mock_browser_state,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify errors were stored
|
||||
mock_app.DATABASE.update_task.assert_called_once()
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
assert len(call_kwargs["errors"]) == 1
|
||||
assert call_kwargs["errors"][0]["error_code"] == "page_not_found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_retries_with_error_detection(agent, mock_browser_state):
|
||||
"""Test error detection when max retries are exceeded."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now).model_copy(update={"max_retries_per_step": 3})
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"captcha_failed": "CAPTCHA verification failed",
|
||||
"rate_limited": "Too many requests",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-3", status=StepStatus.failed, order=1, retry_index=3, output=None)
|
||||
|
||||
detected_errors = [
|
||||
UserDefinedError(error_code="rate_limited", reasoning="Rate limit message displayed", confidence_float=0.90)
|
||||
]
|
||||
|
||||
# Mock summary_failure_reason_for_max_retries to return MaxStepsReasonResponse with detected errors
|
||||
async def mock_summary(*args, **kwargs):
|
||||
return MaxStepsReasonResponse(
|
||||
page_info="",
|
||||
reasoning="Multiple retry failures",
|
||||
errors=detected_errors,
|
||||
)
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.BROWSER_MANAGER.get_for_task.return_value = mock_browser_state
|
||||
mock_app.DATABASE.get_task_steps = AsyncMock(return_value=[step, step, step])
|
||||
mock_app.DATABASE.get_task = AsyncMock(return_value=task)
|
||||
mock_app.DATABASE.update_task = AsyncMock(return_value=task)
|
||||
# create_step is awaited in handle_failed_step retry branch; avoid MagicMock in await
|
||||
next_step = make_step(
|
||||
now,
|
||||
task,
|
||||
step_id="step-next",
|
||||
status=StepStatus.running,
|
||||
order=step.order,
|
||||
retry_index=step.retry_index + 1,
|
||||
output=None,
|
||||
)
|
||||
mock_app.DATABASE.create_step = AsyncMock(return_value=next_step)
|
||||
|
||||
# Async mock that forwards to mock_app.DATABASE.update_task so we never await MagicMock inside real update_task
|
||||
async def mock_update_task(
|
||||
_self,
|
||||
task,
|
||||
status,
|
||||
extracted_information=None,
|
||||
failure_reason=None,
|
||||
webhook_failure_reason=None,
|
||||
errors=None,
|
||||
):
|
||||
updates = {}
|
||||
if status is not None:
|
||||
updates["status"] = status
|
||||
if failure_reason is not None:
|
||||
updates["failure_reason"] = failure_reason
|
||||
if errors is not None:
|
||||
updates["errors"] = errors
|
||||
return await mock_app.DATABASE.update_task(task.task_id, organization_id=task.organization_id, **updates)
|
||||
|
||||
with patch.object(ForgeAgent, "summary_failure_reason_for_max_retries", mock_summary):
|
||||
with patch.object(ForgeAgent, "update_task", mock_update_task):
|
||||
result = await agent.handle_failed_step(organization, task, step)
|
||||
|
||||
assert result is None # No next step when max retries exceeded
|
||||
|
||||
# Verify errors include both system and user-defined errors
|
||||
mock_app.DATABASE.update_task.assert_called_once()
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
|
||||
errors = call_kwargs["errors"]
|
||||
assert len(errors) == 2
|
||||
|
||||
# First should be ReachMaxRetriesError
|
||||
assert errors[0]["error_code"] == ReachMaxRetriesError().error_code
|
||||
|
||||
# Second should be detected user error
|
||||
assert errors[1]["error_code"] == "rate_limited"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scraping_failure_with_error_detection(agent, mock_browser_state):
|
||||
"""Test error detection when page scraping fails."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"login_required": "User must be logged in",
|
||||
"access_denied": "Access to resource denied",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
detected_errors = [
|
||||
UserDefinedError(error_code="login_required", reasoning="Login prompt detected", confidence_float=0.85)
|
||||
]
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with create_error_detection_mocks(detected_errors):
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
# Simulate ScrapingFailed scenario
|
||||
from skyvern.exceptions import ScrapingFailed
|
||||
|
||||
try:
|
||||
raise ScrapingFailed(reason="Failed to scrape page elements")
|
||||
except ScrapingFailed as e:
|
||||
result = await agent.fail_task(task, step, e.reason, mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify errors were stored
|
||||
mock_app.DATABASE.update_task.assert_called_once()
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
assert len(call_kwargs["errors"]) == 1
|
||||
assert call_kwargs["errors"][0]["error_code"] == "login_required"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_failures_accumulate_errors(agent, mock_browser_state):
|
||||
"""Test that errors accumulate across multiple failures."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
|
||||
# Start with an existing error
|
||||
initial_errors = [{"error_code": "initial_error", "reasoning": "First error"}]
|
||||
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment declined",
|
||||
"address_invalid": "Invalid address",
|
||||
},
|
||||
errors=initial_errors,
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
# First failure detects payment error
|
||||
first_detected = [UserDefinedError(error_code="payment_failed", reasoning="Card declined", confidence_float=0.92)]
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with create_error_detection_mocks(first_detected):
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
await agent.fail_task(task, step, "First failure", mock_browser_state)
|
||||
|
||||
# Only new errors are passed — DB handles appending to existing ones
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
assert len(call_kwargs["errors"]) == 1
|
||||
assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_detection_with_workflow_task(agent, mock_browser_state):
|
||||
"""Test error detection works for workflow tasks."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
workflow_run_id="wr-123",
|
||||
workflow_permanent_id="wp-456",
|
||||
error_code_mapping={
|
||||
"workflow_error": "Workflow-specific error",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
detected_errors = [
|
||||
UserDefinedError(error_code="workflow_error", reasoning="Workflow condition not met", confidence_float=0.88)
|
||||
]
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with create_error_detection_mocks(detected_errors):
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
result = await agent.fail_task(task, step, "Workflow task failed", mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify errors were stored for workflow task
|
||||
mock_app.DATABASE.update_task.assert_called_once()
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
assert call_kwargs["task_id"] == task.task_id
|
||||
# workflow_run_id is not passed in the update call, only task_id and errors
|
||||
assert "workflow_run_id" not in call_kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_detection_performance_doesnt_block_failure(agent, mock_browser_state):
|
||||
"""Test that slow error detection doesn't significantly delay task failure."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"timeout": "Operation timed out",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
# Simulate slow error detection
|
||||
import asyncio
|
||||
|
||||
async def slow_detection(*args, **kwargs):
|
||||
await asyncio.sleep(0.1) # Simulate some delay
|
||||
return [UserDefinedError(error_code="timeout", reasoning="Timeout detected", confidence_float=0.80)]
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
mock_detect.side_effect = slow_detection
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
result = await agent.fail_task(task, step, "Task timeout", mock_browser_state)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Should complete (error detection runs but doesn't block indefinitely)
|
||||
assert result is True
|
||||
# Should take at least 0.1s (the sleep time)
|
||||
assert elapsed >= 0.1
|
||||
# But not much more (no retry loops or hangs)
|
||||
assert elapsed < 1.0
|
||||
346
tests/unit/test_error_detection_service.py
Normal file
346
tests/unit/test_error_detection_service.py
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
"""
|
||||
Unit tests for error_detection_service module.
|
||||
|
||||
Tests the user-defined error detection functionality for failed tasks.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.services.error_detection_service import detect_user_defined_errors_for_task
|
||||
from tests.unit.helpers import make_organization, make_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_browser_state():
|
||||
"""Create a mock browser state with a working page."""
|
||||
browser_state = MagicMock()
|
||||
page = MagicMock()
|
||||
page.url = "https://example.com/checkout"
|
||||
|
||||
async def get_working_page():
|
||||
return page
|
||||
|
||||
browser_state.get_working_page = get_working_page
|
||||
browser_state.cleanup_element_tree = MagicMock()
|
||||
|
||||
# Mock scrape_website
|
||||
scraped_page = MagicMock()
|
||||
scraped_page.url = "https://example.com/checkout"
|
||||
scraped_page.build_element_tree = MagicMock(
|
||||
return_value='<html><body><div class="error">Payment failed</div></body></html>'
|
||||
)
|
||||
scraped_page.screenshots = [b"screenshot_data"]
|
||||
|
||||
async def scrape_website(**kwargs):
|
||||
return scraped_page
|
||||
|
||||
browser_state.scrape_website = scrape_website
|
||||
|
||||
return browser_state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_step():
|
||||
"""Create a mock step."""
|
||||
now = datetime.now()
|
||||
return Step(
|
||||
created_at=now,
|
||||
modified_at=now,
|
||||
task_id="task-123",
|
||||
step_id="step-456",
|
||||
status=StepStatus.failed,
|
||||
output=None,
|
||||
order=1,
|
||||
is_last=True,
|
||||
retry_index=0,
|
||||
organization_id="org-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_with_valid_error_code_mapping(mock_browser_state, mock_step):
|
||||
"""Test error detection with valid error_code_mapping and browser state."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
"out_of_stock": "Product is not available",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock extract_user_defined_errors to return expected errors
|
||||
from skyvern.errors.errors import UserDefinedError
|
||||
|
||||
expected_errors = [
|
||||
UserDefinedError(
|
||||
error_code="payment_failed",
|
||||
reasoning="The page shows a payment declined message",
|
||||
confidence_float=0.95,
|
||||
)
|
||||
]
|
||||
|
||||
# Mock extract_user_defined_errors from handler
|
||||
with patch(
|
||||
"skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
|
||||
) as mock_extract:
|
||||
mock_extract.return_value = expected_errors
|
||||
|
||||
# Call detect_user_defined_errors_for_task
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=mock_browser_state
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert len(detected_errors) == 1
|
||||
assert detected_errors[0].error_code == "payment_failed"
|
||||
assert detected_errors[0].reasoning == "The page shows a payment declined message"
|
||||
|
||||
# Verify extract_user_defined_errors was called
|
||||
mock_extract.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_no_error_code_mapping(mock_browser_state, mock_step):
|
||||
"""Test that detection is skipped when no error_code_mapping is provided."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization, error_code_mapping=None)
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=mock_browser_state
|
||||
)
|
||||
|
||||
assert detected_errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_no_browser_state(mock_step):
|
||||
"""Test that detection uses context-based method when browser_state is None."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock the LLM API handler for context-based detection
|
||||
with patch("skyvern.services.error_detection_service.app") as mock_app:
|
||||
mock_app.EXTRACTION_LLM_API_HANDLER = AsyncMock(
|
||||
return_value={
|
||||
"errors": [{"error_code": "payment_failed", "reasoning": "Navigation failed", "confidence_float": 0.80}]
|
||||
}
|
||||
)
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=None, failure_reason="Navigation timeout"
|
||||
)
|
||||
|
||||
assert len(detected_errors) == 1
|
||||
assert detected_errors[0].error_code == "payment_failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_no_working_page(mock_step):
|
||||
"""Test that detection uses context-based method when there's no working page."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
},
|
||||
)
|
||||
|
||||
# Create browser state that returns None for working page
|
||||
browser_state = MagicMock()
|
||||
|
||||
async def get_working_page():
|
||||
return None
|
||||
|
||||
browser_state.get_working_page = get_working_page
|
||||
|
||||
# Mock the LLM API handler for context-based detection
|
||||
with patch("skyvern.services.error_detection_service.app") as mock_app:
|
||||
mock_app.EXTRACTION_LLM_API_HANDLER = AsyncMock(
|
||||
return_value={
|
||||
"errors": [{"error_code": "payment_failed", "reasoning": "Page unavailable", "confidence_float": 0.75}]
|
||||
}
|
||||
)
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=browser_state, failure_reason="Page load failed"
|
||||
)
|
||||
|
||||
assert len(detected_errors) == 1
|
||||
assert detected_errors[0].error_code == "payment_failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_scraping_fails(mock_step):
|
||||
"""Test that detection handles scraping failures gracefully."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
},
|
||||
)
|
||||
|
||||
# Create browser state that raises an exception during scraping
|
||||
browser_state = MagicMock()
|
||||
page = MagicMock()
|
||||
page.url = "https://example.com"
|
||||
|
||||
async def get_working_page():
|
||||
return page
|
||||
|
||||
async def scrape_website(**kwargs):
|
||||
raise Exception("Scraping failed")
|
||||
|
||||
browser_state.get_working_page = get_working_page
|
||||
browser_state.scrape_website = scrape_website
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(task=task, step=mock_step, browser_state=browser_state)
|
||||
|
||||
# Should return empty list, not raise exception
|
||||
assert detected_errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_llm_call_fails(mock_browser_state, mock_step):
|
||||
"""Test that detection handles LLM call failures gracefully."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock extract_user_defined_errors to raise exception
|
||||
with patch(
|
||||
"skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
|
||||
) as mock_extract:
|
||||
mock_extract.side_effect = Exception("LLM call failed")
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=mock_browser_state
|
||||
)
|
||||
|
||||
# Should return empty list, not raise exception
|
||||
assert detected_errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_multiple_errors(mock_browser_state, mock_step):
|
||||
"""Test detection of multiple errors."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
"address_invalid": "Shipping address is invalid",
|
||||
},
|
||||
)
|
||||
|
||||
from skyvern.errors.errors import UserDefinedError
|
||||
|
||||
expected_errors = [
|
||||
UserDefinedError(
|
||||
error_code="payment_failed", reasoning="Payment declined message visible", confidence_float=0.95
|
||||
),
|
||||
UserDefinedError(
|
||||
error_code="address_invalid", reasoning="Address validation error shown", confidence_float=0.90
|
||||
),
|
||||
]
|
||||
|
||||
# Mock extract_user_defined_errors
|
||||
with patch(
|
||||
"skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
|
||||
) as mock_extract:
|
||||
mock_extract.return_value = expected_errors
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=mock_browser_state
|
||||
)
|
||||
|
||||
assert len(detected_errors) == 2
|
||||
assert detected_errors[0].error_code == "payment_failed"
|
||||
assert detected_errors[1].error_code == "address_invalid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_invalid_error_format(mock_browser_state, mock_step):
|
||||
"""Test that invalid error formats are skipped (handled by extract_user_defined_errors)."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
},
|
||||
)
|
||||
|
||||
from skyvern.errors.errors import UserDefinedError
|
||||
|
||||
# Mock extract_user_defined_errors to return only valid errors
|
||||
expected_errors = [
|
||||
UserDefinedError(error_code="payment_failed", reasoning="Payment declined", confidence_float=0.90)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
|
||||
) as mock_extract:
|
||||
mock_extract.return_value = expected_errors
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=mock_browser_state
|
||||
)
|
||||
|
||||
# Only valid error should be returned
|
||||
assert len(detected_errors) == 1
|
||||
assert detected_errors[0].error_code == "payment_failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_errors_empty_llm_response(mock_browser_state, mock_step):
|
||||
"""Test handling of empty LLM response."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Credit card was declined",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock extract_user_defined_errors to return empty list
|
||||
with patch(
|
||||
"skyvern.services.error_detection_service.extract_user_defined_errors", new_callable=AsyncMock
|
||||
) as mock_extract:
|
||||
mock_extract.return_value = []
|
||||
|
||||
detected_errors = await detect_user_defined_errors_for_task(
|
||||
task=task, step=mock_step, browser_state=mock_browser_state
|
||||
)
|
||||
|
||||
assert detected_errors == []
|
||||
353
tests/unit/test_fail_task_error_detection.py
Normal file
353
tests/unit/test_fail_task_error_detection.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
"""
|
||||
Unit tests for fail_task error detection integration.
|
||||
|
||||
Tests the integration between ForgeAgent.fail_task() and the error detection service.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.errors.errors import UserDefinedError
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.models import StepStatus
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from tests.unit.helpers import make_organization, make_step, make_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Create a ForgeAgent instance."""
|
||||
return ForgeAgent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_browser_state():
|
||||
"""Create a mock browser state."""
|
||||
browser_state = MagicMock()
|
||||
page = MagicMock()
|
||||
page.url = "https://example.com/error"
|
||||
|
||||
async def get_working_page():
|
||||
return page
|
||||
|
||||
async def scrape_website(*args, **kwargs):
|
||||
return None
|
||||
|
||||
browser_state.get_working_page = get_working_page
|
||||
browser_state.scrape_website = scrape_website
|
||||
return browser_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_with_error_code_mapping_detects_errors(agent, mock_browser_state):
|
||||
"""Test that fail_task detects errors when error_code_mapping is provided."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
"out_of_stock": "Product unavailable",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
detected_errors = [
|
||||
UserDefinedError(
|
||||
error_code="payment_failed", reasoning="Payment declined message shown on page", confidence_float=0.95
|
||||
)
|
||||
]
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
mock_detect.return_value = detected_errors
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify error detection was called
|
||||
mock_detect.assert_called_once_with(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=mock_browser_state,
|
||||
failure_reason="Task failed",
|
||||
)
|
||||
|
||||
# Verify task errors were updated in database
|
||||
mock_app.DATABASE.update_task.assert_called_once()
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
assert call_kwargs["task_id"] == task.task_id
|
||||
assert call_kwargs["organization_id"] == task.organization_id
|
||||
assert len(call_kwargs["errors"]) == 1
|
||||
assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_without_error_code_mapping(agent, mock_browser_state):
|
||||
"""Test that fail_task skips detection when no error_code_mapping."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(now, organization, error_code_mapping=None)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify error detection was NOT called
|
||||
mock_detect.assert_not_called()
|
||||
|
||||
# Verify database update was NOT called for errors
|
||||
mock_app.DATABASE.update_task.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_without_browser_state(agent):
|
||||
"""Test that fail_task handles missing browser_state gracefully."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
mock_detect.return_value = []
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
# Call without browser_state
|
||||
result = await agent.fail_task(task, step, "Task failed", browser_state=None)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Error detection should still be called (will skip internally)
|
||||
mock_detect.assert_called_once_with(
|
||||
task=task,
|
||||
step=step,
|
||||
browser_state=None,
|
||||
failure_reason="Task failed",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_without_step(agent, mock_browser_state):
|
||||
"""Test that fail_task handles missing step gracefully."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock) as mock_update_step:
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
# Call without step
|
||||
result = await agent.fail_task(task, None, "Task failed", mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Error detection should not be called (step is required)
|
||||
mock_detect.assert_not_called()
|
||||
|
||||
# update_step should not be called
|
||||
mock_update_step.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_error_detection_fails_gracefully(agent, mock_browser_state):
|
||||
"""Test that fail_task continues even if error detection fails."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
# Error detection raises exception
|
||||
mock_detect.side_effect = Exception("Detection failed")
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
# Should not raise exception
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
# Task should still be marked as failed
|
||||
assert result is True
|
||||
mock_update_task.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_multiple_errors_detected(agent, mock_browser_state):
|
||||
"""Test that fail_task handles multiple detected errors."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
"address_invalid": "Address validation failed",
|
||||
},
|
||||
errors=[{"error_code": "existing_error", "reasoning": "Pre-existing error"}],
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
detected_errors = [
|
||||
UserDefinedError(error_code="payment_failed", reasoning="Payment declined", confidence_float=0.90),
|
||||
UserDefinedError(error_code="address_invalid", reasoning="Invalid shipping address", confidence_float=0.85),
|
||||
]
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
mock_detect.return_value = detected_errors
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify only new errors were passed (DB handles appending to existing errors)
|
||||
call_kwargs = mock_app.DATABASE.update_task.call_args[1]
|
||||
assert len(call_kwargs["errors"]) == 2
|
||||
assert call_kwargs["errors"][0]["error_code"] == "payment_failed"
|
||||
assert call_kwargs["errors"][1]["error_code"] == "address_invalid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_no_errors_detected(agent, mock_browser_state):
|
||||
"""Test that fail_task handles case where no errors are detected."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
mock_update_task.return_value = task
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
mock_detect.return_value = []
|
||||
|
||||
with patch("skyvern.forge.agent.app") as mock_app:
|
||||
mock_app.DATABASE.update_task = AsyncMock()
|
||||
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Database update for errors should not be called
|
||||
mock_app.DATABASE.update_task.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_task_with_task_already_canceled(agent, mock_browser_state):
|
||||
"""Test that fail_task returns False when task is already canceled."""
|
||||
now = datetime.now()
|
||||
organization = make_organization(now)
|
||||
task = make_task(
|
||||
now,
|
||||
organization,
|
||||
status=TaskStatus.canceled,
|
||||
error_code_mapping={
|
||||
"payment_failed": "Payment was declined",
|
||||
},
|
||||
)
|
||||
step = make_step(now, task, step_id="step-1", status=StepStatus.running, order=1, output=None)
|
||||
|
||||
with patch.object(agent, "update_step", new_callable=AsyncMock):
|
||||
with patch.object(agent, "update_task", new_callable=AsyncMock) as mock_update_task:
|
||||
# Simulate TaskAlreadyCanceled exception
|
||||
from skyvern.exceptions import TaskAlreadyCanceled
|
||||
|
||||
mock_update_task.side_effect = TaskAlreadyCanceled("new_status", task.task_id)
|
||||
|
||||
with patch(
|
||||
"skyvern.forge.agent.detect_user_defined_errors_for_task",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
result = await agent.fail_task(task, step, "Task failed", mock_browser_state)
|
||||
|
||||
# Should return False
|
||||
assert result is False
|
||||
|
||||
# Error detection should not be called
|
||||
mock_detect.assert_not_called()
|
||||
Loading…
Add table
Add a link
Reference in a new issue