extend select agent to support date picker (#2849)

This commit is contained in:
Shuchang Zheng 2025-07-01 14:12:39 +09:00 committed by GitHub
parent 7a96642c12
commit cb17dbbb6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 140 additions and 60 deletions

View file

@ -463,6 +463,11 @@ class FailToSelectByIndex(SkyvernException):
super().__init__(f"Failed to select by index. element_id={element_id}")
class EmptyDomOrHtmlTree(SkyvernException):
def __init__(self) -> None:
super().__init__("Empty dom or html tree")
class OptionIndexOutOfBound(SkyvernException):
def __init__(self, element_id: str):
super().__init__(f"Option index is out of bound. element_id={element_id}")

View file

@ -1373,7 +1373,7 @@ class ForgeAgent:
reasoning = reasonings[0].summary[0].text if reasonings and reasonings[0].summary else None
assistant_message = assistant_messages[0].content[0].text if assistant_messages else None
skyvern_repsonse_prompt = load_prompt_with_elements(
scraped_page=scraped_page,
element_tree_builder=scraped_page,
prompt_engine=prompt_engine,
template_name="cua-answer-question",
navigation_goal=task.navigation_goal,
@ -1597,7 +1597,7 @@ class ForgeAgent:
actions_and_results_str = await self._get_action_results(task, current_step=step)
verification_prompt = load_prompt_with_elements(
scraped_page=scraped_page_refreshed,
element_tree_builder=scraped_page_refreshed,
prompt_engine=prompt_engine,
template_name="check-user-goal",
navigation_goal=task.navigation_goal,
@ -1974,7 +1974,7 @@ class ForgeAgent:
context = skyvern_context.ensure_context()
return load_prompt_with_elements(
scraped_page=scraped_page,
element_tree_builder=scraped_page,
prompt_engine=prompt_engine,
template_name=template,
navigation_goal=navigation_goal,

View file

@ -1251,7 +1251,7 @@ async def _generate_extraction_task(
# extract the data
context = skyvern_context.ensure_context()
generate_extraction_task_prompt = load_prompt_with_elements(
scraped_page=scraped_page,
element_tree_builder=scraped_page,
prompt_engine=prompt_engine,
template_name="task_v2_generate_extraction_task",
current_url=current_url,

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
from skyvern.constants import DEFAULT_MAX_TOKENS
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.utils.token_counter import count_tokens
from skyvern.webeye.scraper.scraper import ScrapedPage
from skyvern.webeye.scraper.scraper import ElementTreeBuilder
LOG = structlog.get_logger()
@ -20,22 +20,26 @@ class CheckPhoneNumberFormatResponse(BaseModel):
recommended_phone_number: str | None
HTMLTreeStr = str
def load_prompt_with_elements(
scraped_page: ScrapedPage,
element_tree_builder: ElementTreeBuilder,
prompt_engine: PromptEngine,
template_name: str,
html_need_skyvern_attrs: bool = True,
**kwargs: Any,
) -> str:
elements = element_tree_builder.build_element_tree(html_need_skyvern_attrs=html_need_skyvern_attrs)
prompt = prompt_engine.load_prompt(
template_name,
elements=scraped_page.build_element_tree(html_need_skyvern_attrs=html_need_skyvern_attrs),
elements=elements,
**kwargs,
)
token_count = count_tokens(prompt)
if token_count > DEFAULT_MAX_TOKENS:
if token_count > DEFAULT_MAX_TOKENS and element_tree_builder.support_economy_elements_tree():
# get rid of all the secondary elements like SVG, etc
economy_elements_tree = scraped_page.build_economy_elements_tree(
economy_elements_tree = element_tree_builder.build_economy_elements_tree(
html_need_skyvern_attrs=html_need_skyvern_attrs
)
prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs)
@ -50,7 +54,7 @@ def load_prompt_with_elements(
if economy_token_count > DEFAULT_MAX_TOKENS:
# !!! HACK alert
# dump the last 1/3 of the html context and keep the first 2/3 of the html context
economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(
economy_elements_tree_dumped = element_tree_builder.build_economy_elements_tree(
html_need_skyvern_attrs=html_need_skyvern_attrs,
percent_to_keep=2 / 3,
)

View file

@ -91,6 +91,7 @@ from skyvern.webeye.actions.actions import (
from skyvern.webeye.actions.responses import ActionAbort, ActionFailure, ActionResult, ActionSuccess
from skyvern.webeye.scraper.scraper import (
CleanupElementTreeFunc,
ElementTreeBuilder,
IncrementalScrapePage,
ScrapedPage,
hash_element,
@ -1352,10 +1353,14 @@ async def handle_select_option_action(
step_id=step.step_id,
exc_info=True,
)
return await normal_select(action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step)
return await normal_select(
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
if not exist:
return await normal_select(action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step)
return await normal_select(
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
if blocking_element is None:
LOG.info(
@ -1373,11 +1378,13 @@ async def handle_select_option_action(
exc_info=True,
)
return await normal_select(
action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
if not exist or blocking_element is None:
return await normal_select(action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step)
return await normal_select(
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
LOG.info(
"<select> is blocked by another element, going to select on the blocking element",
task_id=task.task_id,
@ -2439,10 +2446,11 @@ async def sequentially_select_from_dropdown(
)
return None
# TODO: only support the third-level dropdown selection now
MAX_SELECT_DEPTH = 3
# TODO: only support the third-level dropdown selection now, but for date picker, we need to support more levels as it will move the month, year, etc.
MAX_SELECT_DEPTH = 5 if input_or_select_context.is_date_related else 3
values: list[str | None] = []
select_history: list[CustomSingleSelectResult] = []
single_select_result: CustomSingleSelectResult | None = None
check_filter_funcs: list[CheckFilterOutElementIDFunc] = [check_existed_but_not_option_element_in_dom_factory(dom)]
for i in range(MAX_SELECT_DEPTH):
@ -2465,39 +2473,6 @@ async def sequentially_select_from_dropdown(
# wait 1s until DOM finished updating
await asyncio.sleep(1)
# HACK: if agent took mini actions 2 times, stop executing the rest actions
# this is a hack to fix some date picker issues.
if input_or_select_context.is_date_related and i >= 1:
if skyvern_element.get_tag_name() == InteractiveElement.INPUT and action.option.label:
try:
LOG.info(
"Try to input the date directly",
step_id=step.step_id,
task_id=task.task_id,
)
await skyvern_element.input_sequentially(action.option.label)
result = CustomSingleSelectResult(skyvern_frame=skyvern_frame)
result.action_result = ActionSuccess()
return result
except Exception:
LOG.warning(
"Failed to input the date directly",
exc_info=True,
step_id=step.step_id,
task_id=task.task_id,
)
if single_select_result.action_result:
LOG.warning(
"It's a date picker, going to skip reamaining actions",
depth=i,
task_id=task.task_id,
step_id=step.step_id,
)
single_select_result.action_result.skip_remaining_actions = True
break
if await single_select_result.is_done():
return single_select_result
@ -2580,6 +2555,31 @@ async def sequentially_select_from_dropdown(
if json_response.get("is_mini_goal_finished", False):
LOG.info("The user has finished the selection for the current opened dropdown", step_id=step.step_id)
return single_select_result
else:
if input_or_select_context.is_date_related:
if skyvern_element.get_tag_name() == InteractiveElement.INPUT and action.option.label:
try:
LOG.info(
"Try to input the date directly",
step_id=step.step_id,
task_id=task.task_id,
)
await skyvern_element.input_sequentially(action.option.label)
result = CustomSingleSelectResult(skyvern_frame=skyvern_frame)
result.action_result = ActionSuccess()
return result
except Exception:
LOG.warning(
"Failed to input the date directly",
exc_info=True,
step_id=step.step_id,
task_id=task.task_id,
)
if single_select_result and single_select_result.action_result:
single_select_result.action_result.skip_remaining_actions = True
return single_select_result
return select_history[-1] if len(select_history) > 0 else None
@ -2640,7 +2640,7 @@ async def select_from_emerging_elements(
raise NoIncrementalElementFoundForCustomSelection(element_id=current_element_id)
prompt = load_prompt_with_elements(
scraped_page=scraped_page_after_open,
element_tree_builder=scraped_page_after_open,
prompt_engine=prompt_engine,
template_name="custom-select",
is_date_related=options.is_date_related,
@ -2759,8 +2759,8 @@ async def select_from_dropdown(
trimmed_element_tree = await incremental_scraped.get_incremental_element_tree(
clean_and_remove_element_tree_factory(task=task, step=step, check_filter_funcs=check_filter_funcs),
)
html = incremental_scraped.build_html_tree(element_tree=trimmed_element_tree)
incremental_scraped.set_element_tree_trimmed(trimmed_element_tree)
html = incremental_scraped.build_element_tree(html_need_skyvern_attrs=True)
skyvern_context = ensure_context()
prompt = prompt_engine.load_prompt(
@ -2837,6 +2837,21 @@ async def select_from_dropdown(
try:
selected_element = await SkyvernElement.create_from_incremental(incremental_scraped, element_id)
# TODO Some popup dropdowns include <select> element, we only handle the <select> element now, to prevent infinite recursion. Need to support more types of dropdowns.
if selected_element.get_tag_name() == InteractiveElement.SELECT and value:
await selected_element.scroll_into_view()
action = SelectOptionAction(
reasoning=select_reason,
element_id=element_id,
option=SelectOption(label=value),
)
results = await normal_select(
action=action, skyvern_element=selected_element, task=task, step=step, builder=incremental_scraped
)
assert len(results) > 0
single_select_result.action_result = results[0]
return single_select_result
if await selected_element.get_attr("role") == "listbox":
single_select_result.action_result = ActionFailure(
exception=InteractWithDropdownContainer(element_id=element_id)
@ -3193,9 +3208,9 @@ async def scroll_down_to_load_all_options(
async def normal_select(
action: actions.SelectOptionAction,
skyvern_element: SkyvernElement,
dom: DomUtil,
task: Task,
step: Step,
builder: ElementTreeBuilder,
) -> List[ActionResult]:
try:
current_text = await skyvern_element.get_attr("selected")
@ -3209,7 +3224,7 @@ async def normal_select(
locator = skyvern_element.get_locator()
prompt = load_prompt_with_elements(
scraped_page=dom.scraped_page,
element_tree_builder=builder,
prompt_engine=prompt_engine,
template_name="parse-input-or-select-context",
action_reasoning=action.reasoning,
@ -3382,7 +3397,7 @@ async def extract_information_for_navigation_goal(
scraped_page_refreshed = await scraped_page.refresh()
context = ensure_context()
extract_information_prompt = load_prompt_with_elements(
scraped_page=scraped_page_refreshed,
element_tree_builder=scraped_page_refreshed,
prompt_engine=prompt_engine,
template_name="extract-information",
html_need_skyvern_attrs=False,
@ -3572,7 +3587,7 @@ async def _get_input_or_select_context(
action: InputTextAction | SelectOptionAction | AbstractActionForContextParse, scraped_page: ScrapedPage, step: Step
) -> InputOrSelectContext:
prompt = load_prompt_with_elements(
scraped_page=scraped_page,
element_tree_builder=scraped_page,
prompt_engine=prompt_engine,
template_name="parse-input-or-select-context",
action_reasoning=action.reasoning,

View file

@ -1,6 +1,7 @@
import asyncio
import copy
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import StrEnum
from typing import Any, Awaitable, Callable, Self
@ -212,7 +213,28 @@ class ElementTreeFormat(StrEnum):
HTML = "html"
class ScrapedPage(BaseModel):
class ElementTreeBuilder(ABC):
@abstractmethod
def support_economy_elements_tree(self) -> bool:
pass
@abstractmethod
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
pass
@abstractmethod
def build_economy_elements_tree(
self,
fmt: ElementTreeFormat = ElementTreeFormat.HTML,
html_need_skyvern_attrs: bool = True,
percent_to_keep: float = 1,
) -> str:
pass
class ScrapedPage(BaseModel, ElementTreeBuilder):
"""
Scraped response from a webpage, including:
1. List of elements
@ -259,6 +281,9 @@ class ScrapedPage(BaseModel):
self._clean_up_func = clean_up_func
self._scrape_exclude = scrape_exclude
def support_economy_elements_tree(self) -> bool:
return True
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
@ -675,7 +700,7 @@ async def get_interactable_element_tree(
return elements, element_tree
class IncrementalScrapePage:
class IncrementalScrapePage(ElementTreeBuilder):
def __init__(self, skyvern_frame: SkyvernFrame) -> None:
self.id_to_element_dict: dict[str, dict] = dict()
self.id_to_css_dict: dict[str, str] = dict()
@ -684,6 +709,9 @@ class IncrementalScrapePage:
self.element_tree_trimmed: list[dict] = list()
self.skyvern_frame = skyvern_frame
def set_element_tree_trimmed(self, element_tree_trimmed: list[dict]) -> None:
self.element_tree_trimmed = element_tree_trimmed
def check_id_in_page(self, element_id: str) -> bool:
css_selector = self.id_to_css_dict.get(element_id, "")
if css_selector:
@ -798,8 +826,36 @@ class IncrementalScrapePage:
return locator
return None
def build_html_tree(self, element_tree: list[dict] | None = None) -> str:
return "".join([json_to_html(element) for element in (element_tree or self.element_tree_trimmed)])
def build_html_tree(self, element_tree: list[dict] | None = None, need_skyvern_attrs: bool = True) -> str:
return "".join(
[
json_to_html(element, need_skyvern_attrs=need_skyvern_attrs)
for element in (element_tree or self.element_tree_trimmed)
]
)
def support_economy_elements_tree(self) -> bool:
return False
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
if fmt == ElementTreeFormat.HTML:
return self.build_html_tree(
element_tree=self.element_tree_trimmed, need_skyvern_attrs=html_need_skyvern_attrs
)
if fmt == ElementTreeFormat.JSON:
return json.dumps(self.element_tree_trimmed)
raise UnknownElementTreeFormat(fmt=fmt)
def build_economy_elements_tree(
self,
fmt: ElementTreeFormat = ElementTreeFormat.HTML,
html_need_skyvern_attrs: bool = True,
percent_to_keep: float = 1,
) -> str:
raise NotImplementedError("Not implemented")
def _should_keep_unique_id(element: dict) -> bool: