mirror of https://github.com/Skyvern-AI/skyvern.git
synced 2025-04-25 17:09:10 +00:00

trim svg elements when prompt exceeds context window (#2106)

parent 5e427fc401
commit 3c612968ce
8 changed files with 126 additions and 14 deletions

poetry.lock (generated, 2 lines changed)
@@ -6521,4 +6521,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11,<3.12"
-content-hash = "b43cb55e0c18ac83f0e32444132fd7618ef5b8355b0a90dbed55599d068c2892"
+content-hash = "84b211a2b313b852996823fc4105d809b990e34cecd400c61d541561c010afdf"

@@ -54,6 +54,7 @@ json-repair = "^0.34.0"
 pypdf = "^5.1.0"
 fastmcp = "^0.4.1"
 psutil = ">=7.0.0"
+tiktoken = ">=0.9.0"

 [tool.poetry.group.dev.dependencies]
 isort = "^5.13.2"

@@ -68,6 +68,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
 from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
 from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
 from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
+from skyvern.utils.prompt_engine import load_prompt_with_elements
 from skyvern.webeye.actions.actions import (
     Action,
     ActionStatus,

@@ -1196,11 +1197,12 @@ class ForgeAgent:
         )
         scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False)

-        verification_prompt = prompt_engine.load_prompt(
-            "check-user-goal",
+        verification_prompt = load_prompt_with_elements(
+            scraped_page=scraped_page_refreshed,
+            prompt_engine=prompt_engine,
+            template_name="check-user-goal",
             navigation_goal=task.navigation_goal,
             navigation_payload=task.navigation_payload,
-            elements=scraped_page_refreshed.build_element_tree(ElementTreeFormat.HTML),
             complete_criterion=task.complete_criterion,
         )

@@ -1432,7 +1434,7 @@ class ForgeAgent:
             task,
             step,
             browser_state,
-            element_tree_in_prompt,
+            scraped_page,
             verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
             expire_verification_code=True,
         )

@@ -1470,7 +1472,7 @@ class ForgeAgent:
         task: Task,
         step: Step,
         browser_state: BrowserState,
-        element_tree_in_prompt: str,
+        scraped_page: ScrapedPage,
         verification_code_check: bool = False,
         expire_verification_code: bool = False,
     ) -> str:

@@ -1525,13 +1527,14 @@ class ForgeAgent:
             raise UnsupportedTaskType(task_type=task_type)

         context = skyvern_context.ensure_context()
-        return prompt_engine.load_prompt(
-            template=template,
+        return load_prompt_with_elements(
+            scraped_page=scraped_page,
+            prompt_engine=prompt_engine,
+            template_name=template,
             navigation_goal=navigation_goal,
             navigation_payload_str=json.dumps(final_navigation_payload),
             starting_url=starting_url,
             current_url=current_url,
-            elements=element_tree_in_prompt,
             data_extraction_goal=task.data_extraction_goal,
             action_history=actions_and_results_str,
             error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),

@@ -2300,12 +2303,11 @@ class ForgeAgent:
         current_context = skyvern_context.ensure_context()
         current_context.totp_codes[task.task_id] = verification_code

-        element_tree_in_prompt: str = scraped_page.build_element_tree(ElementTreeFormat.HTML)
         extract_action_prompt = await self._build_extract_action_prompt(
             task,
             step,
             browser_state,
-            element_tree_in_prompt,
+            scraped_page,
             verification_code_check=False,
             expire_verification_code=True,
         )

@@ -139,7 +139,9 @@ async def _convert_svg_to_string(

     skyvern_element = SkyvernElement(locator=locater, frame=skyvern_frame.get_frame(), static_element=element)

-    _, blocked = await skyvern_frame.get_blocking_element_id(await skyvern_element.get_element_handler())
+    _, blocked = await skyvern_frame.get_blocking_element_id(
+        await skyvern_element.get_element_handler(timeout=1000)
+    )
     if not skyvern_element.is_interactable() and blocked:
         _mark_element_as_dropped(element)
         return

@@ -53,6 +53,7 @@ from skyvern.forge.sdk.workflow.models.yaml import (
     WorkflowDefinitionYAML,
 )
 from skyvern.schemas.runs import ProxyLocation, RunType
+from skyvern.utils.prompt_engine import load_prompt_with_elements
 from skyvern.webeye.browser_factory import BrowserState
 from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
 from skyvern.webeye.utils.page import SkyvernFrame

@@ -462,10 +463,11 @@ async def run_task_v2_helper(
             continue
         current_url = current_url if current_url else str(await SkyvernFrame.get_url(frame=page) if page else url)

-        task_v2_prompt = prompt_engine.load_prompt(
+        task_v2_prompt = load_prompt_with_elements(
+            scraped_page,
+            prompt_engine,
             "task_v2",
             current_url=current_url,
-            elements=element_tree_in_prompt,
             user_goal=user_prompt,
             task_history=task_history,
             local_datetime=datetime.now(context.tz_info).isoformat(),

skyvern/utils/prompt_engine.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+from typing import Any
+
+import structlog
+
+from skyvern.forge.sdk.prompting import PromptEngine
+from skyvern.utils.token_counter import count_tokens
+from skyvern.webeye.scraper.scraper import ScrapedPage
+
+DEFAULT_MAX_TOKENS = 100000
+LOG = structlog.get_logger()
+
+
+def load_prompt_with_elements(
+    scraped_page: ScrapedPage,
+    prompt_engine: PromptEngine,
+    template_name: str,
+    **kwargs: Any,
+) -> str:
+    prompt = prompt_engine.load_prompt(template_name, elements=scraped_page.build_element_tree(), **kwargs)
+    token_count = count_tokens(prompt)
+    if token_count > DEFAULT_MAX_TOKENS:
+        # get rid of all the secondary elements like SVG, etc
+        economy_elements_tree = scraped_page.build_economy_elements_tree()
+        prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree, **kwargs)
+        economy_token_count = count_tokens(prompt)
+        LOG.warning(
+            "Prompt is longer than the max tokens. Going to use the economy elements tree.",
+            template_name=template_name,
+            token_count=token_count,
+            economy_token_count=economy_token_count,
+            max_tokens=DEFAULT_MAX_TOKENS,
+        )
+        if economy_token_count > DEFAULT_MAX_TOKENS:
+            # !!! HACK alert
+            # dump the last 1/3 of the html context and keep the first 2/3 of the html context
+            economy_elements_tree_dumped = scraped_page.build_economy_elements_tree(percent_to_keep=2 / 3)
+            prompt = prompt_engine.load_prompt(template_name, elements=economy_elements_tree_dumped, **kwargs)
+            token_count_after_dump = count_tokens(prompt)
+            LOG.warning(
+                "Prompt is still longer than the max tokens. Will only keep the first 2/3 of the html context.",
+                template_name=template_name,
+                token_count=token_count,
+                economy_token_count=economy_token_count,
+                token_count_after_dump=token_count_after_dump,
+                max_tokens=DEFAULT_MAX_TOKENS,
+            )
+    return prompt

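Illustrative note (not part of the commit): the helper above degrades in two steps when the rendered prompt exceeds DEFAULT_MAX_TOKENS, first re-rendering with the SVG-free "economy" element tree, then keeping only the first two thirds of that tree. A minimal standalone sketch of the same control flow, using made-up stand-ins for PromptEngine and ScrapedPage (none of these stub names exist in Skyvern):

MAX_TOKENS = 10  # tiny budget so the fallback actually triggers in this demo


def demo_count_tokens(text: str) -> int:
    return len(text.split())  # crude stand-in for the tiktoken-based count_tokens


class DemoPage:
    def build_element_tree(self) -> str:
        # full tree: padded with many inline SVG icons
        return "<div>" + "<svg>icon</svg> " * 20 + "<button>Go</button></div>"

    def build_economy_elements_tree(self, percent_to_keep: float = 1.0) -> str:
        # economy tree: SVGs already stripped; optionally truncated
        words = ("<div><button>Go</button></div> " * 12).split()
        return " ".join(words[: int(len(words) * percent_to_keep)])


def demo_load_prompt(page: DemoPage, template: str) -> str:
    prompt = f"{template}: {page.build_element_tree()}"
    if demo_count_tokens(prompt) > MAX_TOKENS:
        # fall back to the economy tree (no SVGs)
        prompt = f"{template}: {page.build_economy_elements_tree()}"
        if demo_count_tokens(prompt) > MAX_TOKENS:
            # still too big: keep only the first 2/3 of the economy tree
            prompt = f"{template}: {page.build_economy_elements_tree(percent_to_keep=2 / 3)}"
    return prompt


print(demo_load_prompt(DemoPage(), "check-user-goal"))
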
skyvern/utils/token_counter.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+import tiktoken
+
+
+def count_tokens(text: str) -> int:
+    return len(tiktoken.encoding_for_model("gpt-4o").encode(text))

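For illustration only (not part of the commit), count_tokens can be exercised directly; the gpt-4o encoding it requests from tiktoken is the o200k_base tokenizer, and the sample string below is arbitrary:

import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4o")  # same lookup count_tokens performs
sample = "<div><button id='submit'>Submit</button></div>"
print(len(encoding.encode(sample)))  # token count of the snippet

Because the count is always taken with this one tokenizer, it presumably serves as an approximation when a different model ultimately receives the prompt.
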
@@ -229,6 +229,7 @@ class ScrapedPage(BaseModel):
     hash_to_element_ids: dict[str, list[str]]
     element_tree: list[dict]
     element_tree_trimmed: list[dict]
+    economy_element_tree: list[dict] | None = None
     screenshots: list[bytes]
     url: str
     html: str

@@ -268,6 +269,58 @@

         raise UnknownElementTreeFormat(fmt=fmt)

+    def build_economy_elements_tree(
+        self,
+        fmt: ElementTreeFormat = ElementTreeFormat.HTML,
+        html_need_skyvern_attrs: bool = True,
+        percent_to_keep: float = 1,
+    ) -> str:
+        """
+        Economy elements tree doesn't include secondary elements like SVG, etc
+        """
+        if not self.economy_element_tree:
+            economy_elements = []
+            copied_element_tree_trimmed = copy.deepcopy(self.element_tree_trimmed)
+
+            # Process each root element
+            for root_element in copied_element_tree_trimmed:
+                processed_element = self._process_element_for_economy_tree(root_element)
+                if processed_element:
+                    economy_elements.append(processed_element)
+
+            self.economy_element_tree = economy_elements
+
+        final_element_tree = self.economy_element_tree[: int(len(self.economy_element_tree) * percent_to_keep)]
+
+        if fmt == ElementTreeFormat.JSON:
+            return json.dumps(final_element_tree)
+
+        if fmt == ElementTreeFormat.HTML:
+            return "".join(
+                json_to_html(element, need_skyvern_attrs=html_need_skyvern_attrs) for element in final_element_tree
+            )
+
+        raise UnknownElementTreeFormat(fmt=fmt)
+
+    def _process_element_for_economy_tree(self, element: dict) -> dict | None:
+        """
+        Helper method to process an element for the economy tree using BFS.
+        Removes SVG elements and their children.
+        """
+        # Skip SVG elements entirely
+        if element.get("tagName", "").lower() == "svg":
+            return None
+
+        # Process children using BFS
+        if "children" in element:
+            new_children = []
+            for child in element["children"]:
+                processed_child = self._process_element_for_economy_tree(child)
+                if processed_child:
+                    new_children.append(processed_child)
+            element["children"] = new_children
+        return element
+
     async def refresh(self, draw_boxes: bool = True) -> Self:
         refreshed_page = await scrape_website(
             browser_state=self._browser_state,
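
Illustrative note (not part of the commit): although the docstring above mentions BFS, _process_element_for_economy_tree effectively recurses depth-first through children, dropping any svg subtree it meets. A self-contained sketch of the same pruning on a toy element dict; the tagName/children shape mirrors the dicts used above, and prune_svg is a made-up name for this example:

def prune_svg(element: dict) -> dict | None:
    # Drop <svg> subtrees entirely; keep everything else, recursing into children.
    if element.get("tagName", "").lower() == "svg":
        return None
    if "children" in element:
        new_children = []
        for child in element["children"]:
            pruned = prune_svg(child)
            if pruned:
                new_children.append(pruned)
        element["children"] = new_children
    return element


tree = {
    "tagName": "div",
    "children": [
        {"tagName": "svg", "children": [{"tagName": "path", "children": []}]},
        {"tagName": "button", "children": []},
    ],
}
print(prune_svg(tree))
# {'tagName': 'div', 'children': [{'tagName': 'button', 'children': []}]}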