extend select agent to support date picker (#2849)

This commit is contained in:
Shuchang Zheng 2025-07-01 14:12:39 +09:00 committed by GitHub
parent 7a96642c12
commit cb17dbbb6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 140 additions and 60 deletions

View file

@ -463,6 +463,11 @@ class FailToSelectByIndex(SkyvernException):
super().__init__(f"Failed to select by index. element_id={element_id}") super().__init__(f"Failed to select by index. element_id={element_id}")
class EmptyDomOrHtmlTree(SkyvernException):
    """Error raised with the fixed message "Empty dom or html tree".

    Takes no arguments; the raise sites are not visible in this view.
    """

    def __init__(self) -> None:
        message = "Empty dom or html tree"
        super().__init__(message)
class OptionIndexOutOfBound(SkyvernException): class OptionIndexOutOfBound(SkyvernException):
def __init__(self, element_id: str): def __init__(self, element_id: str):
super().__init__(f"Option index is out of bound. element_id={element_id}") super().__init__(f"Option index is out of bound. element_id={element_id}")

View file

@ -1373,7 +1373,7 @@ class ForgeAgent:
reasoning = reasonings[0].summary[0].text if reasonings and reasonings[0].summary else None reasoning = reasonings[0].summary[0].text if reasonings and reasonings[0].summary else None
assistant_message = assistant_messages[0].content[0].text if assistant_messages else None assistant_message = assistant_messages[0].content[0].text if assistant_messages else None
skyvern_repsonse_prompt = load_prompt_with_elements( skyvern_repsonse_prompt = load_prompt_with_elements(
scraped_page=scraped_page, element_tree_builder=scraped_page,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="cua-answer-question", template_name="cua-answer-question",
navigation_goal=task.navigation_goal, navigation_goal=task.navigation_goal,
@ -1597,7 +1597,7 @@ class ForgeAgent:
actions_and_results_str = await self._get_action_results(task, current_step=step) actions_and_results_str = await self._get_action_results(task, current_step=step)
verification_prompt = load_prompt_with_elements( verification_prompt = load_prompt_with_elements(
scraped_page=scraped_page_refreshed, element_tree_builder=scraped_page_refreshed,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="check-user-goal", template_name="check-user-goal",
navigation_goal=task.navigation_goal, navigation_goal=task.navigation_goal,
@ -1974,7 +1974,7 @@ class ForgeAgent:
context = skyvern_context.ensure_context() context = skyvern_context.ensure_context()
return load_prompt_with_elements( return load_prompt_with_elements(
scraped_page=scraped_page, element_tree_builder=scraped_page,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name=template, template_name=template,
navigation_goal=navigation_goal, navigation_goal=navigation_goal,

View file

@ -1251,7 +1251,7 @@ async def _generate_extraction_task(
# extract the data # extract the data
context = skyvern_context.ensure_context() context = skyvern_context.ensure_context()
generate_extraction_task_prompt = load_prompt_with_elements( generate_extraction_task_prompt = load_prompt_with_elements(
scraped_page=scraped_page, element_tree_builder=scraped_page,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="task_v2_generate_extraction_task", template_name="task_v2_generate_extraction_task",
current_url=current_url, current_url=current_url,

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
from skyvern.constants import DEFAULT_MAX_TOKENS from skyvern.constants import DEFAULT_MAX_TOKENS
from skyvern.forge.sdk.prompting import PromptEngine from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.utils.token_counter import count_tokens from skyvern.utils.token_counter import count_tokens
from skyvern.webeye.scraper.scraper import ScrapedPage from skyvern.webeye.scraper.scraper import ElementTreeBuilder
LOG = structlog.get_logger() LOG = structlog.get_logger()
@ -20,22 +20,26 @@ class CheckPhoneNumberFormatResponse(BaseModel):
recommended_phone_number: str | None recommended_phone_number: str | None
HTMLTreeStr = str
def load_prompt_with_elements( def load_prompt_with_elements(
scraped_page: ScrapedPage, element_tree_builder: ElementTreeBuilder,
prompt_engine: PromptEngine, prompt_engine: PromptEngine,
template_name: str, template_name: str,
html_need_skyvern_attrs: bool = True, html_need_skyvern_attrs: bool = True,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
elements = element_tree_builder.build_element_tree(html_need_skyvern_attrs=html_need_skyvern_attrs)
prompt = prompt_engine.load_prompt( prompt = prompt_engine.load_prompt(
template_name, template_name,
elements=scraped_page.build_element_tree(html_need_skyvern_attrs=html_need_skyvern_attrs), elements=elements,
**kwargs, **kwargs,
) )
token_count = count_tokens(prompt) token_count = count_tokens(prompt)
if token_count > DEFAULT_MAX_TOKENS: if token_count > DEFAULT_MAX_TOKENS and element_tree_builder.support_economy_elements_tree():
# get rid of all the secondary elements like SVG, etc # get rid of all the secondary elements like SVG, etc
economy_elements_tree = scraped_page.build_economy_elements_tree( economy_elements_tree = element_tree_builder.build_economy_elements_tree(
html_need_skyvern_attrs=html_need_skyvern_attrs html_need_skyvern_attrs=html_need_skyvern_attrs
) )
prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs) prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs)
@ -50,7 +54,7 @@ def load_prompt_with_elements(
if economy_token_count > DEFAULT_MAX_TOKENS: if economy_token_count > DEFAULT_MAX_TOKENS:
# !!! HACK alert # !!! HACK alert
# dump the last 1/3 of the html context and keep the first 2/3 of the html context # dump the last 1/3 of the html context and keep the first 2/3 of the html context
economy_elements_tree_dumped = scraped_page.build_economy_elements_tree( economy_elements_tree_dumped = element_tree_builder.build_economy_elements_tree(
html_need_skyvern_attrs=html_need_skyvern_attrs, html_need_skyvern_attrs=html_need_skyvern_attrs,
percent_to_keep=2 / 3, percent_to_keep=2 / 3,
) )

View file

@ -91,6 +91,7 @@ from skyvern.webeye.actions.actions import (
from skyvern.webeye.actions.responses import ActionAbort, ActionFailure, ActionResult, ActionSuccess from skyvern.webeye.actions.responses import ActionAbort, ActionFailure, ActionResult, ActionSuccess
from skyvern.webeye.scraper.scraper import ( from skyvern.webeye.scraper.scraper import (
CleanupElementTreeFunc, CleanupElementTreeFunc,
ElementTreeBuilder,
IncrementalScrapePage, IncrementalScrapePage,
ScrapedPage, ScrapedPage,
hash_element, hash_element,
@ -1352,10 +1353,14 @@ async def handle_select_option_action(
step_id=step.step_id, step_id=step.step_id,
exc_info=True, exc_info=True,
) )
return await normal_select(action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step) return await normal_select(
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
if not exist: if not exist:
return await normal_select(action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step) return await normal_select(
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
if blocking_element is None: if blocking_element is None:
LOG.info( LOG.info(
@ -1373,11 +1378,13 @@ async def handle_select_option_action(
exc_info=True, exc_info=True,
) )
return await normal_select( return await normal_select(
action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
) )
if not exist or blocking_element is None: if not exist or blocking_element is None:
return await normal_select(action=action, skyvern_element=skyvern_element, dom=dom, task=task, step=step) return await normal_select(
action=action, skyvern_element=skyvern_element, builder=dom.scraped_page, task=task, step=step
)
LOG.info( LOG.info(
"<select> is blocked by another element, going to select on the blocking element", "<select> is blocked by another element, going to select on the blocking element",
task_id=task.task_id, task_id=task.task_id,
@ -2439,10 +2446,11 @@ async def sequentially_select_from_dropdown(
) )
return None return None
# TODO: only support the third-level dropdown selection now # TODO: only support the third-level dropdown selection now, but for date picker, we need to support more levels as it will move the month, year, etc.
MAX_SELECT_DEPTH = 3 MAX_SELECT_DEPTH = 5 if input_or_select_context.is_date_related else 3
values: list[str | None] = [] values: list[str | None] = []
select_history: list[CustomSingleSelectResult] = [] select_history: list[CustomSingleSelectResult] = []
single_select_result: CustomSingleSelectResult | None = None
check_filter_funcs: list[CheckFilterOutElementIDFunc] = [check_existed_but_not_option_element_in_dom_factory(dom)] check_filter_funcs: list[CheckFilterOutElementIDFunc] = [check_existed_but_not_option_element_in_dom_factory(dom)]
for i in range(MAX_SELECT_DEPTH): for i in range(MAX_SELECT_DEPTH):
@ -2465,39 +2473,6 @@ async def sequentially_select_from_dropdown(
# wait 1s until DOM finished updating # wait 1s until DOM finished updating
await asyncio.sleep(1) await asyncio.sleep(1)
# HACK: if agent took mini actions 2 times, stop executing the rest actions
# this is a hack to fix some date picker issues.
if input_or_select_context.is_date_related and i >= 1:
if skyvern_element.get_tag_name() == InteractiveElement.INPUT and action.option.label:
try:
LOG.info(
"Try to input the date directly",
step_id=step.step_id,
task_id=task.task_id,
)
await skyvern_element.input_sequentially(action.option.label)
result = CustomSingleSelectResult(skyvern_frame=skyvern_frame)
result.action_result = ActionSuccess()
return result
except Exception:
LOG.warning(
"Failed to input the date directly",
exc_info=True,
step_id=step.step_id,
task_id=task.task_id,
)
if single_select_result.action_result:
LOG.warning(
"It's a date picker, going to skip reamaining actions",
depth=i,
task_id=task.task_id,
step_id=step.step_id,
)
single_select_result.action_result.skip_remaining_actions = True
break
if await single_select_result.is_done(): if await single_select_result.is_done():
return single_select_result return single_select_result
@ -2580,6 +2555,31 @@ async def sequentially_select_from_dropdown(
if json_response.get("is_mini_goal_finished", False): if json_response.get("is_mini_goal_finished", False):
LOG.info("The user has finished the selection for the current opened dropdown", step_id=step.step_id) LOG.info("The user has finished the selection for the current opened dropdown", step_id=step.step_id)
return single_select_result return single_select_result
else:
if input_or_select_context.is_date_related:
if skyvern_element.get_tag_name() == InteractiveElement.INPUT and action.option.label:
try:
LOG.info(
"Try to input the date directly",
step_id=step.step_id,
task_id=task.task_id,
)
await skyvern_element.input_sequentially(action.option.label)
result = CustomSingleSelectResult(skyvern_frame=skyvern_frame)
result.action_result = ActionSuccess()
return result
except Exception:
LOG.warning(
"Failed to input the date directly",
exc_info=True,
step_id=step.step_id,
task_id=task.task_id,
)
if single_select_result and single_select_result.action_result:
single_select_result.action_result.skip_remaining_actions = True
return single_select_result
return select_history[-1] if len(select_history) > 0 else None return select_history[-1] if len(select_history) > 0 else None
@ -2640,7 +2640,7 @@ async def select_from_emerging_elements(
raise NoIncrementalElementFoundForCustomSelection(element_id=current_element_id) raise NoIncrementalElementFoundForCustomSelection(element_id=current_element_id)
prompt = load_prompt_with_elements( prompt = load_prompt_with_elements(
scraped_page=scraped_page_after_open, element_tree_builder=scraped_page_after_open,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="custom-select", template_name="custom-select",
is_date_related=options.is_date_related, is_date_related=options.is_date_related,
@ -2759,8 +2759,8 @@ async def select_from_dropdown(
trimmed_element_tree = await incremental_scraped.get_incremental_element_tree( trimmed_element_tree = await incremental_scraped.get_incremental_element_tree(
clean_and_remove_element_tree_factory(task=task, step=step, check_filter_funcs=check_filter_funcs), clean_and_remove_element_tree_factory(task=task, step=step, check_filter_funcs=check_filter_funcs),
) )
incremental_scraped.set_element_tree_trimmed(trimmed_element_tree)
html = incremental_scraped.build_html_tree(element_tree=trimmed_element_tree) html = incremental_scraped.build_element_tree(html_need_skyvern_attrs=True)
skyvern_context = ensure_context() skyvern_context = ensure_context()
prompt = prompt_engine.load_prompt( prompt = prompt_engine.load_prompt(
@ -2837,6 +2837,21 @@ async def select_from_dropdown(
try: try:
selected_element = await SkyvernElement.create_from_incremental(incremental_scraped, element_id) selected_element = await SkyvernElement.create_from_incremental(incremental_scraped, element_id)
# TODO Some popup dropdowns include <select> element, we only handle the <select> element now, to prevent infinite recursion. Need to support more types of dropdowns.
if selected_element.get_tag_name() == InteractiveElement.SELECT and value:
await selected_element.scroll_into_view()
action = SelectOptionAction(
reasoning=select_reason,
element_id=element_id,
option=SelectOption(label=value),
)
results = await normal_select(
action=action, skyvern_element=selected_element, task=task, step=step, builder=incremental_scraped
)
assert len(results) > 0
single_select_result.action_result = results[0]
return single_select_result
if await selected_element.get_attr("role") == "listbox": if await selected_element.get_attr("role") == "listbox":
single_select_result.action_result = ActionFailure( single_select_result.action_result = ActionFailure(
exception=InteractWithDropdownContainer(element_id=element_id) exception=InteractWithDropdownContainer(element_id=element_id)
@ -3193,9 +3208,9 @@ async def scroll_down_to_load_all_options(
async def normal_select( async def normal_select(
action: actions.SelectOptionAction, action: actions.SelectOptionAction,
skyvern_element: SkyvernElement, skyvern_element: SkyvernElement,
dom: DomUtil,
task: Task, task: Task,
step: Step, step: Step,
builder: ElementTreeBuilder,
) -> List[ActionResult]: ) -> List[ActionResult]:
try: try:
current_text = await skyvern_element.get_attr("selected") current_text = await skyvern_element.get_attr("selected")
@ -3209,7 +3224,7 @@ async def normal_select(
locator = skyvern_element.get_locator() locator = skyvern_element.get_locator()
prompt = load_prompt_with_elements( prompt = load_prompt_with_elements(
scraped_page=dom.scraped_page, element_tree_builder=builder,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="parse-input-or-select-context", template_name="parse-input-or-select-context",
action_reasoning=action.reasoning, action_reasoning=action.reasoning,
@ -3382,7 +3397,7 @@ async def extract_information_for_navigation_goal(
scraped_page_refreshed = await scraped_page.refresh() scraped_page_refreshed = await scraped_page.refresh()
context = ensure_context() context = ensure_context()
extract_information_prompt = load_prompt_with_elements( extract_information_prompt = load_prompt_with_elements(
scraped_page=scraped_page_refreshed, element_tree_builder=scraped_page_refreshed,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="extract-information", template_name="extract-information",
html_need_skyvern_attrs=False, html_need_skyvern_attrs=False,
@ -3572,7 +3587,7 @@ async def _get_input_or_select_context(
action: InputTextAction | SelectOptionAction | AbstractActionForContextParse, scraped_page: ScrapedPage, step: Step action: InputTextAction | SelectOptionAction | AbstractActionForContextParse, scraped_page: ScrapedPage, step: Step
) -> InputOrSelectContext: ) -> InputOrSelectContext:
prompt = load_prompt_with_elements( prompt = load_prompt_with_elements(
scraped_page=scraped_page, element_tree_builder=scraped_page,
prompt_engine=prompt_engine, prompt_engine=prompt_engine,
template_name="parse-input-or-select-context", template_name="parse-input-or-select-context",
action_reasoning=action.reasoning, action_reasoning=action.reasoning,

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import copy import copy
import json import json
from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from enum import StrEnum from enum import StrEnum
from typing import Any, Awaitable, Callable, Self from typing import Any, Awaitable, Callable, Self
@ -212,7 +213,28 @@ class ElementTreeFormat(StrEnum):
HTML = "html" HTML = "html"
class ElementTreeBuilder(ABC):
    """Abstract interface for objects that can serialize their element tree.

    Implementations render the tree to a string in the requested
    ``ElementTreeFormat`` and declare whether they can also produce the
    reduced "economy" variant.
    """

    @abstractmethod
    def support_economy_elements_tree(self) -> bool:
        """Return True when ``build_economy_elements_tree`` is usable."""

    @abstractmethod
    def build_element_tree(
        self,
        fmt: ElementTreeFormat = ElementTreeFormat.HTML,
        html_need_skyvern_attrs: bool = True,
    ) -> str:
        """Return the full element tree serialized as ``fmt``.

        ``html_need_skyvern_attrs`` is forwarded to the HTML rendering; its
        exact effect is decided by the implementation.
        """

    @abstractmethod
    def build_economy_elements_tree(
        self,
        fmt: ElementTreeFormat = ElementTreeFormat.HTML,
        html_need_skyvern_attrs: bool = True,
        percent_to_keep: float = 1,
    ) -> str:
        """Return a reduced element tree serialized as ``fmt``.

        NOTE(review): ``percent_to_keep`` appears to be a fraction in
        ``(0, 1]`` rather than a percentage (the default is ``1``) -- confirm
        against the implementations.
        """
class ScrapedPage(BaseModel, ElementTreeBuilder):
""" """
Scraped response from a webpage, including: Scraped response from a webpage, including:
1. List of elements 1. List of elements
@ -259,6 +281,9 @@ class ScrapedPage(BaseModel):
self._clean_up_func = clean_up_func self._clean_up_func = clean_up_func
self._scrape_exclude = scrape_exclude self._scrape_exclude = scrape_exclude
def support_economy_elements_tree(self) -> bool:
    """Report that this builder can produce the economy elements tree."""
    supported = True
    return supported
def build_element_tree( def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str: ) -> str:
@ -675,7 +700,7 @@ async def get_interactable_element_tree(
return elements, element_tree return elements, element_tree
class IncrementalScrapePage: class IncrementalScrapePage(ElementTreeBuilder):
def __init__(self, skyvern_frame: SkyvernFrame) -> None: def __init__(self, skyvern_frame: SkyvernFrame) -> None:
self.id_to_element_dict: dict[str, dict] = dict() self.id_to_element_dict: dict[str, dict] = dict()
self.id_to_css_dict: dict[str, str] = dict() self.id_to_css_dict: dict[str, str] = dict()
@ -684,6 +709,9 @@ class IncrementalScrapePage:
self.element_tree_trimmed: list[dict] = list() self.element_tree_trimmed: list[dict] = list()
self.skyvern_frame = skyvern_frame self.skyvern_frame = skyvern_frame
def set_element_tree_trimmed(self, element_tree_trimmed: list[dict]) -> None:
    """Store ``element_tree_trimmed`` on the instance, replacing the previous value.

    The list is kept by reference, not copied.
    """
    setattr(self, "element_tree_trimmed", element_tree_trimmed)
def check_id_in_page(self, element_id: str) -> bool: def check_id_in_page(self, element_id: str) -> bool:
css_selector = self.id_to_css_dict.get(element_id, "") css_selector = self.id_to_css_dict.get(element_id, "")
if css_selector: if css_selector:
@ -798,8 +826,36 @@ class IncrementalScrapePage:
return locator return locator
return None return None
def build_html_tree(self, element_tree: list[dict] | None = None, need_skyvern_attrs: bool = True) -> str:
    """Render a list of element dicts to one concatenated HTML string.

    Each element is converted with ``json_to_html`` and the pieces are joined
    without a separator. ``need_skyvern_attrs`` is forwarded to
    ``json_to_html`` unchanged.
    """
    # NOTE(review): because of `or`, an explicitly passed empty list is
    # treated like None and falls back to self.element_tree_trimmed.
    source_tree = element_tree or self.element_tree_trimmed
    return "".join(json_to_html(node, need_skyvern_attrs=need_skyvern_attrs) for node in source_tree)
def support_economy_elements_tree(self) -> bool:
    """Report that this builder cannot produce the economy elements tree."""
    supported = False
    return supported
def build_element_tree(
self, fmt: ElementTreeFormat = ElementTreeFormat.HTML, html_need_skyvern_attrs: bool = True
) -> str:
if fmt == ElementTreeFormat.HTML:
return self.build_html_tree(
element_tree=self.element_tree_trimmed, need_skyvern_attrs=html_need_skyvern_attrs
)
if fmt == ElementTreeFormat.JSON:
return json.dumps(self.element_tree_trimmed)
raise UnknownElementTreeFormat(fmt=fmt)
def build_economy_elements_tree(
    self,
    fmt: ElementTreeFormat = ElementTreeFormat.HTML,
    html_need_skyvern_attrs: bool = True,
    percent_to_keep: float = 1,
) -> str:
    """Always raise ``NotImplementedError``; no economy tree is built here.

    All arguments are accepted only to match the shared signature and are
    ignored.
    """
    message = "Not implemented"
    raise NotImplementedError(message)
def _should_keep_unique_id(element: dict) -> bool: def _should_keep_unique_id(element: dict) -> bool: