Skyvern/skyvern/cli/core/browser_ops.py
Marc Kelechava f422c37d46
Some checks are pending
Run tests and pre-commit / Run tests and pre-commit hooks (push) Waiting to run
Run tests and pre-commit / Frontend Lint and Build (push) Waiting to run
Publish Fern Docs / run (push) Waiting to run
feat: MCP observe/execute batch tools + AI speed optimizations (#5386)
2026-04-05 16:38:13 -07:00

784 lines
22 KiB
Python

"""Shared browser operations for MCP tools and CLI commands.
Each function: validate inputs -> call SDK -> return typed result.
Session resolution and output formatting are caller responsibilities.
"""
from __future__ import annotations
import json
import os
import re
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
from .guards import GuardError
@dataclass
class NavigateResult:
url: str
title: str
@dataclass
class ScreenshotResult:
data: bytes
full_page: bool = False
@dataclass
class ActResult:
prompt: str
completed: bool = True
@dataclass
class ExtractResult:
extracted: Any = None
def parse_extract_schema(schema: str | dict[str, Any] | None) -> dict[str, Any] | None:
"""Parse and validate an extraction schema payload."""
if schema is None:
return None
if isinstance(schema, dict):
return schema
try:
return json.loads(schema)
except (json.JSONDecodeError, TypeError) as e:
raise GuardError(f"Invalid JSON schema: {e}", "Provide schema as a valid JSON string")
async def do_navigate(
page: Any,
url: str,
timeout: int = 30000,
wait_until: str | None = None,
) -> NavigateResult:
await page.goto(url, timeout=timeout, wait_until=wait_until)
return NavigateResult(url=page.url, title=await page.title())
async def do_screenshot(
page: Any,
full_page: bool = False,
selector: str | None = None,
) -> ScreenshotResult:
if selector:
element = page.locator(selector)
data = await element.screenshot()
else:
data = await page.screenshot(full_page=full_page)
return ScreenshotResult(data=data, full_page=full_page)
async def do_act(
page: Any,
prompt: str,
skip_refresh: bool = False,
use_economy_tree: bool = False,
) -> ActResult:
await page.act(prompt, skip_refresh=skip_refresh, use_economy_tree=use_economy_tree)
return ActResult(prompt=prompt, completed=True)
async def do_extract(
page: Any,
prompt: str,
schema: str | dict[str, Any] | None = None,
skip_refresh: bool = False,
) -> ExtractResult:
parsed_schema = parse_extract_schema(schema)
extracted = await page.extract(prompt=prompt, schema=parsed_schema, skip_refresh=skip_refresh)
return ExtractResult(extracted=extracted)
# -- Semantic locators --
@dataclass
class FindResult:
selector: str
count: int
first_text: str | None
first_visible: bool
locator_map: dict[str, str] = {
"role": "get_by_role",
"text": "get_by_text",
"label": "get_by_label",
"placeholder": "get_by_placeholder",
"alt": "get_by_alt_text",
"testid": "get_by_test_id",
}
LOCATOR_TYPES = frozenset(locator_map.keys())
async def do_find(page: Any, by: str, value: str) -> FindResult:
"""Locate elements using Playwright's semantic locator API."""
if by not in locator_map:
raise GuardError(
f"Invalid locator type: {by!r}. Must be one of: {', '.join(sorted(LOCATOR_TYPES))}",
f"Use one of: {', '.join(sorted(LOCATOR_TYPES))}",
)
locator = getattr(page, locator_map[by])(value)
count = await locator.count()
first_text = await locator.first.text_content() if count > 0 else None
first_visible = await locator.first.is_visible() if count > 0 else False
return FindResult(
selector=f"{locator_map[by]}({value!r})",
count=count,
first_text=first_text,
first_visible=first_visible,
)
# -- Frame operations --
@dataclass
class FrameInfo:
index: int
name: str
url: str
is_main: bool
@dataclass
class FrameSwitchResult:
name: str | None
url: str | None
selector: str | None = None
requested_name: str | None = None
index: int | None = None
async def do_frame_switch(
page: Any,
*,
selector: str | None = None,
name: str | None = None,
index: int | None = None,
) -> FrameSwitchResult:
result = await page.frame_switch(selector=selector, name=name, index=index)
return FrameSwitchResult(
name=result.get("name"),
url=result.get("url"),
selector=selector,
requested_name=name,
index=index,
)
def do_frame_main(page: Any) -> None:
page.frame_main()
async def do_frame_list(page: Any) -> list[FrameInfo]:
frames = await page.frame_list()
return [FrameInfo(index=f["index"], name=f["name"], url=f["url"], is_main=f["is_main"]) for f in frames]
# -- Auth state persistence --
@dataclass
class StateSaveResult:
file_path: str
cookie_count: int
local_storage_count: int
session_storage_count: int
url: str
@dataclass
class StateLoadResult:
cookie_count: int
local_storage_count: int
session_storage_count: int
source_url: str
skipped_cookies: int
def _cookie_domain_matches(cookie_domain: str, page_domain: str) -> bool:
"""Check if a cookie's domain matches the current page domain per RFC 6265.
Handles leading dots (wildcard subdomains).
Rejects suffix attacks: 'evil-example.com' must NOT match 'example.com'.
"""
if not cookie_domain or not page_domain:
return False
cd = cookie_domain.lstrip(".")
if not cd:
return False
return page_domain == cd or page_domain.endswith("." + cd)
async def do_state_save(page: Any, browser: Any, file_path: Path) -> StateSaveResult:
"""Save browser auth state to a JSON file.
``page`` is the raw Playwright Page (not SkyvernBrowserPage).
``browser`` is a SkyvernBrowser — cookies accessed via ``browser._browser_context``.
"""
pw_context = browser._browser_context
cookies = await pw_context.cookies()
local_storage = await page.evaluate("() => Object.fromEntries(Object.entries(window.localStorage))")
session_storage = await page.evaluate("() => Object.fromEntries(Object.entries(window.sessionStorage))")
state = {
"version": 1,
"url": page.url,
"timestamp": datetime.now(timezone.utc).isoformat(),
"cookies": cookies,
"local_storage": local_storage,
"session_storage": session_storage,
}
file_path.parent.mkdir(parents=True, exist_ok=True)
fd = os.open(str(file_path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w") as f:
json.dump(state, f, indent=2)
return StateSaveResult(
file_path=str(file_path),
cookie_count=len(cookies),
local_storage_count=len(local_storage),
session_storage_count=len(session_storage),
url=page.url,
)
async def do_state_load(
page: Any,
browser: Any,
file_path: Path,
current_domain: str,
) -> StateLoadResult:
"""Load browser auth state from a JSON file.
Validates JSON schema version. Filters cookies to only apply those matching
``current_domain`` to prevent cross-domain session injection.
"""
raw = file_path.read_text()
state = json.loads(raw)
if state.get("version") != 1:
raise ValueError(f"Unsupported state file version: {state.get('version')}")
pw_context = browser._browser_context
all_cookies = state.get("cookies", [])
safe_cookies = [c for c in all_cookies if _cookie_domain_matches(c.get("domain", ""), current_domain)]
skipped = len(all_cookies) - len(safe_cookies)
if safe_cookies:
await pw_context.add_cookies(safe_cookies)
local_storage = state.get("local_storage", {})
for k, v in local_storage.items():
await page.evaluate(
"(args) => window.localStorage.setItem(args[0], args[1])",
[k, v],
)
session_storage = state.get("session_storage", {})
for k, v in session_storage.items():
await page.evaluate(
"(args) => window.sessionStorage.setItem(args[0], args[1])",
[k, v],
)
return StateLoadResult(
cookie_count=len(safe_cookies),
local_storage_count=len(local_storage),
session_storage_count=len(session_storage),
source_url=state.get("url", ""),
skipped_cookies=skipped,
)
# -- DOM inspection --
async def do_get_html(page: Any, selector: str, outer: bool = False) -> str:
"""Get innerHTML or outerHTML from an element. ``page`` is raw Playwright Page."""
prop = "outerHTML" if outer else "innerHTML"
return await page.locator(selector).evaluate(f"el => el.{prop}")
async def do_get_value(page: Any, selector: str) -> str | None:
"""Get the current value of a form input element."""
return await page.locator(selector).input_value()
async def do_get_styles(page: Any, selector: str, properties: list[str] | None = None) -> dict[str, str]:
"""Get computed CSS styles from an element."""
if properties is not None:
if not properties:
return {}
return await page.locator(selector).evaluate(
"""(el, props) => {
const styles = window.getComputedStyle(el);
return Object.fromEntries(props.map(p => [p, styles.getPropertyValue(p)]));
}""",
properties,
)
return await page.locator(selector).evaluate(
"""el => {
const styles = window.getComputedStyle(el);
const result = {};
for (let i = 0; i < Math.min(styles.length, 100); i++) {
result[styles[i]] = styles.getPropertyValue(styles[i]);
}
return result;
}"""
)
# ---------------------------------------------------------------------------
# Network operations
# ---------------------------------------------------------------------------
# Fields stripped from list view to reduce payload. The detail tool returns
# the full entry dict (including these fields) via do_network_request_detail.
_LIST_STRIP_KEYS = frozenset({"response_headers"})
@dataclass
class NetworkRequestsResult:
requests: list[dict[str, Any]]
count: int
error: dict[str, Any] | None = None
@dataclass
class NetworkRequestDetailResult:
request: dict[str, Any] | None = None
body: str | None = None
found: bool = False
@dataclass
class NetworkRouteResult:
url_pattern: str = ""
action: str = ""
active_routes: list[str] = field(default_factory=list)
@dataclass
class NetworkUnrouteResult:
url_pattern: str = ""
removed: bool = False
active_routes: list[str] = field(default_factory=list)
def do_network_requests(
state: Any,
*,
url_pattern: str | None = None,
status_code: int | None = None,
method: str | None = None,
resource_type: str | None = None,
) -> NetworkRequestsResult:
"""Filter and return network request entries from state. Sync — no Playwright calls."""
entries = list(state.network_requests)
if url_pattern:
try:
compiled = re.compile(url_pattern)
entries = [e for e in entries if compiled.search(e.get("url", ""))]
except re.error:
from .result import ErrorCode, make_error
return NetworkRequestsResult(
requests=[],
count=0,
error=make_error(
ErrorCode.INVALID_INPUT,
f"Invalid regex pattern: {url_pattern}",
"Provide a valid Python regex pattern",
),
)
if status_code is not None:
entries = [e for e in entries if e.get("status") == status_code]
if method:
method_upper = method.upper()
entries = [e for e in entries if e.get("method") == method_upper]
if resource_type:
rt_lower = resource_type.lower()
entries = [e for e in entries if e.get("resource_type", "").lower() == rt_lower]
# Strip heavy fields for list view
display = [{k: v for k, v in e.items() if k not in _LIST_STRIP_KEYS} for e in entries]
return NetworkRequestsResult(requests=display, count=len(display))
def do_network_request_detail(state: Any, request_id: int) -> NetworkRequestDetailResult:
"""Look up a single request by ID and return full metadata + body."""
for entry in state.network_requests:
if entry.get("request_id") == request_id:
body = state.get_response_body(request_id)
return NetworkRequestDetailResult(request=dict(entry), body=body, found=True)
return NetworkRequestDetailResult()
async def do_network_route(
raw_page: Any,
state: Any,
*,
url_pattern: str,
action: Literal["abort", "mock"],
mock_status: int = 200,
mock_body: str | None = None,
mock_content_type: str | None = None,
) -> NetworkRouteResult:
"""Register a route handler on the Playwright page and track in SessionState."""
async def _handler(route: Any) -> None:
try:
if action == "abort":
await route.abort()
elif action == "mock":
headers: dict[str, str] = {}
if mock_content_type:
headers["content-type"] = mock_content_type
elif mock_body is not None:
headers["content-type"] = "application/json"
await route.fulfill(
status=mock_status,
headers=headers if headers else None,
body=mock_body or "",
)
else:
await route.abort()
except Exception:
try:
await route.abort()
except Exception:
pass
page_id = id(raw_page)
page_routes = state.active_routes.setdefault(page_id, set())
# Re-register: unroute existing handler for this pattern first
if url_pattern in page_routes:
try:
await raw_page.unroute(url_pattern)
page_routes.discard(url_pattern)
except Exception:
pass
await raw_page.route(url_pattern, _handler)
page_routes.add(url_pattern)
return NetworkRouteResult(
url_pattern=url_pattern,
action=action,
active_routes=sorted(page_routes),
)
async def do_network_unroute(raw_page: Any, state: Any, url_pattern: str) -> NetworkUnrouteResult:
"""Remove a route handler and update SessionState tracking."""
page_id = id(raw_page)
page_routes = state.active_routes.get(page_id, set())
removed = url_pattern in page_routes
if removed:
await raw_page.unroute(url_pattern)
page_routes.discard(url_pattern)
return NetworkUnrouteResult(
url_pattern=url_pattern,
removed=removed,
active_routes=sorted(page_routes),
)
# ---------------------------------------------------------------------------
# Observe — scoped accessibility tree snapshot with stable refs
# ---------------------------------------------------------------------------
INTERACTIVE_ROLES = frozenset(
{
"button",
"checkbox",
"combobox",
"link",
"listbox",
"menuitem",
"menuitemcheckbox",
"menuitemradio",
"option",
"radio",
"searchbox",
"slider",
"spinbutton",
"switch",
"tab",
"textbox",
"treeitem",
}
)
_ROLE_TO_TAG: dict[str, str] = {
"textbox": "input",
"searchbox": "input",
"checkbox": "input",
"radio": "input",
"slider": "input",
"spinbutton": "input",
"switch": "input",
"button": "button",
"link": "a",
"combobox": "select",
"listbox": "select",
"option": "option",
"tab": "button",
"menuitem": "li",
"menuitemcheckbox": "li",
"menuitemradio": "li",
"treeitem": "li",
}
_PASSWORD_NAME_RE = re.compile(
r"\bpass(?:word|phrase|code)s?\b|\bsecret\b|\btoken\b|\bcredential\b|\bpwd\b|\bpasswd\b|\bpin\b",
re.IGNORECASE,
)
# Structural fields always kept in serialized output; display fields filtered if empty.
_ELEMENT_KEEP_ALWAYS = frozenset({"ref", "role"})
@dataclass
class ObservedElement:
ref: str
role: str
name: str
tag: str
value: str | None = None
options: list[str] | None = None
@dataclass
class ObserveResult:
url: str
title: str
elements: list[ObservedElement]
element_count: int
total_on_page: int
def _flatten_a11y_tree(node: dict[str, Any] | None) -> list[dict[str, Any]]:
"""Recursively flatten an accessibility tree into a flat element list."""
if node is None:
return []
result: list[dict[str, Any]] = []
if node.get("role") and node["role"] != "WebArea":
result.append(node)
for child in node.get("children", []):
result.extend(_flatten_a11y_tree(child))
return result
def _is_password_field(role: str, name: str) -> bool:
"""DESIGN-2: Detect password-type fields for value redaction."""
if _PASSWORD_NAME_RE.search(name):
return True
return role == "textbox" and "password" in name.lower()
def _extract_options(node: dict[str, Any]) -> list[str] | None:
"""Extract option labels from combobox/listbox children."""
children = node.get("children")
if not children:
return None
opts = [c.get("name", "") for c in children if c.get("role") == "option"]
return opts if opts else None
async def do_observe(
page: Any,
selector: str | None = None,
interactive_only: bool = True,
max_elements: int = 50,
) -> ObserveResult:
"""Capture interactive elements with stable refs for batch operations."""
if selector:
element_handle = await page.locator(selector).first.element_handle()
snapshot = await page.accessibility.snapshot(root=element_handle)
else:
snapshot = await page.accessibility.snapshot()
all_elements = _flatten_a11y_tree(snapshot)
if interactive_only:
all_elements = [e for e in all_elements if e.get("role") in INTERACTIVE_ROLES]
total = len(all_elements)
capped = all_elements[:max_elements]
observed: list[ObservedElement] = []
for i, elem in enumerate(capped):
role = elem.get("role", "")
name = elem.get("name", "")
value = elem.get("value")
# DESIGN-2: Redact password field values
if value and _is_password_field(role, name):
value = "***"
observed.append(
ObservedElement(
ref=f"e{i}",
role=role,
name=name,
tag=_ROLE_TO_TAG.get(role, ""),
value=value,
options=_extract_options(elem),
)
)
return ObserveResult(
url=page.url,
title=await page.title(),
elements=observed,
element_count=len(observed),
total_on_page=total,
)
def serialize_elements(elements: list[ObservedElement]) -> list[dict[str, Any]]:
"""Serialize observed elements to dicts, filtering empty display fields."""
return [
{
k: v
for k, v in {
"ref": e.ref,
"role": e.role,
"name": e.name,
"tag": e.tag,
"value": e.value,
"options": e.options,
}.items()
if k in _ELEMENT_KEEP_ALWAYS or (v is not None and v != "")
}
for e in elements
]
def ref_to_selector(elem: dict[str, Any]) -> str:
"""Convert an observed element's a11y data to a Playwright role selector."""
role = elem.get("role", "")
name = elem.get("name", "")
if name:
escaped = name.replace('"', '\\"')
return f'role={role}[name="{escaped}"]'
return f"role={role}"
# ---------------------------------------------------------------------------
# Execute — batch multi-step execution with ref threading
# ---------------------------------------------------------------------------
# Tools that are blocked after a failed navigate step (DESIGN-3)
_SENSITIVE_TOOLS = frozenset({"type", "evaluate"})
_ALLOWED_EXECUTE_TOOLS = frozenset(
{
"navigate",
"click",
"type",
"press_key",
"select_option",
"hover",
"scroll",
"wait",
"observe",
"screenshot",
"evaluate",
}
)
MAX_EXECUTE_STEPS = 20
@dataclass
class ExecuteStep:
tool: str
params: dict[str, Any] = field(default_factory=dict)
@dataclass
class StepResult:
step: int
tool: str
ok: bool
wall_ms: int = 0
data: dict[str, Any] | None = None
error: str | None = None
@dataclass
class ExecuteResult:
steps_completed: int
steps_total: int
results: list[StepResult]
error_step: int | None
async def do_execute(
dispatch_fn: Any,
steps: list[ExecuteStep],
stop_on_error: bool = True,
) -> ExecuteResult:
"""Execute a sequence of deterministic browser operations in one batch.
dispatch_fn: async callable(step, ref_map) -> dict with tool result
"""
results: list[StepResult] = []
ref_map: dict[str, dict[str, Any]] = {}
nav_failed = False
for i, step in enumerate(steps):
# DESIGN-3: Block sensitive ops after failed navigate
if nav_failed and not stop_on_error and step.tool in _SENSITIVE_TOOLS:
results.append(
StepResult(
step=i,
tool=step.tool,
ok=False,
error="blocked_by_failed_navigate: refusing to execute sensitive "
"operation after navigation failure",
)
)
continue
t0 = time.monotonic()
try:
result = await dispatch_fn(step, ref_map)
wall_ms = int((time.monotonic() - t0) * 1000)
results.append(StepResult(step=i, tool=step.tool, ok=True, wall_ms=wall_ms, data=result))
# DESIGN-4: Each observe REPLACES the entire ref_map (not merges)
if step.tool == "observe" and result and "elements" in result:
ref_map = {elem["ref"]: elem for elem in result["elements"]}
except Exception as e:
wall_ms = int((time.monotonic() - t0) * 1000)
results.append(StepResult(step=i, tool=step.tool, ok=False, wall_ms=wall_ms, error=str(e)))
if step.tool == "navigate":
nav_failed = True
if stop_on_error:
break
return ExecuteResult(
steps_completed=len(results),
steps_total=len(steps),
results=results,
error_step=next((r.step for r in results if not r.ok), None),
)