trim svg elements when prompt exceeds context window ()

This commit is contained in:
Shuchang Zheng 2025-04-04 22:33:52 -04:00 committed by GitHub
parent 5e427fc401
commit 3c612968ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 126 additions and 14 deletions

2
poetry.lock generated
View file

@ -6521,4 +6521,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11,<3.12"
content-hash = "b43cb55e0c18ac83f0e32444132fd7618ef5b8355b0a90dbed55599d068c2892"
content-hash = "84b211a2b313b852996823fc4105d809b990e34cecd400c61d541561c010afdf"

View file

@ -54,6 +54,7 @@ json-repair = "^0.34.0"
pypdf = "^5.1.0"
fastmcp = "^0.4.1"
psutil = ">=7.0.0"
tiktoken = ">=0.9.0"
[tool.poetry.group.dev.dependencies]
isort = "^5.13.2"

View file

@ -68,6 +68,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.actions.actions import (
Action,
ActionStatus,
@ -1196,11 +1197,12 @@ class ForgeAgent:
)
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False)
verification_prompt = prompt_engine.load_prompt(
"check-user-goal",
verification_prompt = load_prompt_with_elements(
scraped_page=scraped_page_refreshed,
prompt_engine=prompt_engine,
template_name="check-user-goal",
navigation_goal=task.navigation_goal,
navigation_payload=task.navigation_payload,
elements=scraped_page_refreshed.build_element_tree(ElementTreeFormat.HTML),
complete_criterion=task.complete_criterion,
)
@ -1432,7 +1434,7 @@ class ForgeAgent:
task,
step,
browser_state,
element_tree_in_prompt,
scraped_page,
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
expire_verification_code=True,
)
@ -1470,7 +1472,7 @@ class ForgeAgent:
task: Task,
step: Step,
browser_state: BrowserState,
element_tree_in_prompt: str,
scraped_page: ScrapedPage,
verification_code_check: bool = False,
expire_verification_code: bool = False,
) -> str:
@ -1525,13 +1527,14 @@ class ForgeAgent:
raise UnsupportedTaskType(task_type=task_type)
context = skyvern_context.ensure_context()
return prompt_engine.load_prompt(
template=template,
return load_prompt_with_elements(
scraped_page=scraped_page,
prompt_engine=prompt_engine,
template_name=template,
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
starting_url=starting_url,
current_url=current_url,
elements=element_tree_in_prompt,
data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
@ -2300,12 +2303,11 @@ class ForgeAgent:
current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = verification_code
element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
extract_action_prompt = await self._build_extract_action_prompt(
task,
step,
browser_state,
element_tree_in_prompt,
scraped_page,
verification_code_check=False,
expire_verification_code=True,
)

View file

@ -139,7 +139,9 @@ async def _convert_svg_to_string(
skyvern_element = SkyvernElement(locator=locater, frame=skyvern_frame.get_frame(), static_element=element)
_, blocked = await skyvern_frame.get_blocking_element_id(await skyvern_element.get_element_handler())
_, blocked = await skyvern_frame.get_blocking_element_id(
await skyvern_element.get_element_handler(timeout=1000)
)
if not skyvern_element.is_interactable() and blocked:
_mark_element_as_dropped(element)
return

View file

@ -53,6 +53,7 @@ from skyvern.forge.sdk.workflow.models.yaml import (
WorkflowDefinitionYAML,
)
from skyvern.schemas.runs import ProxyLocation, RunType
from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
from skyvern.webeye.utils.page import SkyvernFrame
@ -462,10 +463,11 @@ async def run_task_v2_helper(
continue
current_url = current_url if current_url else str(await SkyvernFrame.get_url(frame=page) if page else url)
task_v2_prompt = prompt_engine.load_prompt(
task_v2_prompt = load_prompt_with_elements(
scraped_page,
prompt_engine,
"task_v2",
current_url=current_url,
elements=element_tree_in_prompt,
user_goal=user_prompt,
task_history=task_history,
local_datetime=datetime.now(context.tz_info).isoformat(),

View file

@ -0,0 +1,47 @@
from typing import Any
import structlog
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.utils.token_counter import count_tokens
from skyvern.webeye.scraper.scraper import ScrapedPage
DEFAULT_MAX_TOKENS = 100000
LOG = structlog.get_logger()
def load_prompt_with_elements(
scraped_page: ScrapedPage,
prompt_engine: PromptEngine,
template_name: str,
**kwargs: Any,
) -> str:
prompt = prompt_engine.load_prompt(template_name, elements=scraped_page.build_element_tree(), **kwargs)
token_count = count_tokens(prompt)
if token_count > DEFAULT_MAX_TOKENS:
# get rid of all the secondary elements like SVG, etc
economy_elements_tree = scraped_page.build_economy_elements_tree()
prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs)
economy_token_count = count_tokens(prompt)
LOG.warning(
"Prompt is longer than the max tokens. Going to use the economy elements tree.",
template_name=template_name,
token_count=token_count,
economy_token_count=economy_token_count,
max_tokens=DEFAULT_MAX_TOKENS,
)
if economy_token_count > DEFAULT_MAX_TOKENS:
# !!! HACK alert
# dump the last 1/3 of the html context and keep the first 2/3 of the html context
economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(percent_to_keep=2 / 3)
prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree_dumped, **kwargs)
token_count_after_dump = count_tokens(prompt)
LOG.warning(
"Prompt is still longer than the max tokens. Will only keep the first 2/3 of the html context.",
template_name=template_name,
token_count=token_count,
economy_token_count=economy_token_count,
token_count_after_dump=token_count_after_dump,
max_tokens=DEFAULT_MAX_TOKENS,
)
return prompt

View file

@ -0,0 +1,5 @@
import tiktoken
def count_tokens(text: str) -> int:
return len(tiktoken.encoding_for_model("gpt-4o").encode(text))

View file

@ -229,6 +229,7 @@ class ScrapedPage(BaseModel):
hash_to_element_ids: dict[str, list[str]]
element_tree: list[dict]
element_tree_trimmed: list[dict]
economy_element_tree: list[dict] | None = None
screenshots: list[bytes]
url: str
html: str
@ -268,6 +269,58 @@ class ScrapedPage(BaseModel):
raise UnknownElementTreeFormat(fmt=fmt)
def build_economy_elements_tree(
self,
fmt: ElementTreeFormat = ElementTreeFormat.HTML,
html_need_skyvern_attrs: bool = True,
percent_to_keep: float = 1,
) -> str:
"""
Economy elements tree doesn't include secondary elements like SVG, etc
"""
if not self.economy_element_tree:
economy_elements = []
copied_element_tree_trimmed = copy.deepcopy(self.element_tree_trimmed)
# Process each root element
for root_element in copied_element_tree_trimmed:
processed_element = self._process_element_for_economy_tree(root_element)
if processed_element:
economy_elements.append(processed_element)
self.economy_element_tree = economy_elements
final_element_tree = self.economy_element_tree[: int(len(self.economy_element_tree) * percent_to_keep)]
if fmt == ElementTreeFormat.JSON:
return json.dumps(final_element_tree)
if fmt == ElementTreeFormat.HTML:
return "".join(
json_to_html(element, need_skyvern_attrs=html_need_skyvern_attrs) for element in final_element_tree
)
raise UnknownElementTreeFormat(fmt=fmt)
def _process_element_for_economy_tree(self, element: dict) -> dict | None:
"""
Helper method to process an element for the economy tree using BFS.
Removes SVG elements and their children.
"""
# Skip SVG elements entirely
if element.get("tagName", "").lower() == "svg":
return None
# Process children using BFS
if "children" in element:
new_children = []
for child in element["children"]:
processed_child = self._process_element_for_economy_tree(child)
if processed_child:
new_children.append(processed_child)
element["children"] = new_children
return element
async def refresh(self, draw_boxes: bool = True) -> Self:
refreshed_page = await scrape_website(
browser_state=self._browser_state,